Skip to content

Commit

Permalink
Migrate MSI Server to MSAL (#1460)
Browse files Browse the repository at this point in the history
* Migrate MSI Server to MSAL

* Use MSAL for access token

* Pip install msal

* Include /.default with resourceId
  • Loading branch information
Ed-Maeng authored Aug 30, 2024
1 parent dcf8a20 commit 44a6d4b
Show file tree
Hide file tree
Showing 4 changed files with 165 additions and 4 deletions.
123 changes: 123 additions & 0 deletions DeploymentCloud/Deployment.Common/scripts/msalmsiserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

#!/usr/bin/env python

import msal
import json

from urlparse import urlparse
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
import hdinsight_common.ClusterManifestParser as ClusterManifestParser

"""
This script exposes a local http endpoint which the spark jobs can call to get the MSI access token associated with the HDI cluster.
Note that since it's a local endpoint, it's accessible only from within the cluster and not from outside.
Usage:
http://localhost:40382/managed/identity/oauth2/token?resource=<resourceid>&api-version=2018-11-01
Example:
curl -H "Metadata: true" -X GET "http://localhost:40382/managed/identity/oauth2/token?resource=https://vault.azure.net&api-version=2018-11-01"
"""

class Constants(object):
    """Fixed settings used by the local managed-identity token server."""

    # Where the HTTP server listens; also the only client address accepted.
    server_port = 40382
    loopback_address = '127.0.0.1'

    # Shape of an accepted request: URL path, required header name and the
    # name of the query parameter carrying the target resource.
    header_metadata = 'Metadata'
    query_resource = 'resource'
    token_url_path = '/managed/identity/oauth2/token'

    # Templates: '{0}' is filled with the tenant id and with the certificate
    # thumbprint respectively.
    aad_login_endpoint = 'https://login.microsoftonline.com/{0}'
    cert_location = '/var/lib/waagent/{0}.prv'

class ManagedIdentityTokenResponse(object):
    """Plain container for the token fields returned to the caller.

    The handler serializes the instance ``__dict__`` as the JSON response
    body, so the attribute names are the JSON keys.
    """

    def __init__(self):
        # Every field starts unset; the handler fills them in after a
        # successful token acquisition.
        self.access_token = self.token_type = self.resource = None

class ManagedIdentityHandler(BaseHTTPRequestHandler):
    """Serve access tokens for the cluster's managed identity over local HTTP.

    A GET request is accepted only when it carries the ``Metadata: true``
    header, originates from the loopback address and targets the token URL
    path. The ``resource`` query parameter selects the audience of the token.
    Responses:

    * 200 with a JSON body (``access_token``, ``token_type``, ``resource``)
    * 400 with a plain-text explanation when the request is invalid
    * 500 when token acquisition fails (details go to the server log)
    """

    def _add_to_query_dict(self, query_dict, query):
        """Store one raw ``key=value`` query segment in ``query_dict``.

        The segment is split at the first ``=`` only, so values that contain
        ``=`` are kept intact. Segments without ``=`` or with an empty key
        are ignored instead of raising ``IndexError``. No URL-decoding is
        done here; ``do_GET`` uses ``_parse_query`` for that.
        """
        key, separator, value = query.partition('=')
        if separator and key:
            query_dict[key] = value

    @staticmethod
    def _parse_query(query_string):
        """Return the URL-decoded query parameters as a ``{name: value}`` dict.

        Blank values and segments without ``=`` are dropped; when a name is
        repeated, its first value wins.
        """
        # Local import keeps the block self-contained and works on both
        # Python 2 (urlparse) and Python 3 (urllib.parse).
        try:
            from urllib.parse import parse_qs
        except ImportError:
            from urlparse import parse_qs

        parsed = parse_qs(query_string)
        return dict((name, values[0]) for name, values in parsed.items())

    @staticmethod
    def _to_scope(resource):
        """Return the MSAL scope for ``resource``.

        The client-credential flow expects ``<resource>/.default``. Callers
        that already pass the suffix are left untouched. A trailing slash on
        the resource is preserved on purpose, because it is part of the
        resource identifier.
        """
        suffix = '/.default'
        if resource.endswith(suffix):
            return resource
        return resource + suffix

    def _write_body(self, text):
        """Write ``text`` to the response, encoding it when it is not bytes."""
        if not isinstance(text, bytes):
            text = text.encode('utf-8')
        self.wfile.write(text)

    def _validate_request(self):
        """Check header, client address and path of the current request.

        Returns:
            An empty string when the request is acceptable, otherwise one
            newline-terminated message per failed check.
        """
        problems = []

        # .get() so that a missing header is reported as a validation error
        # rather than raising from the header lookup.
        if self.headers.get(Constants.header_metadata) != 'true':
            problems.append('Metadata header is required\n')

        if self.client_address[0] != Constants.loopback_address:
            problems.append('Only request from loopback address 127.0.0.1 is allowed\n')

        url = urlparse(self.path)
        if url.path != Constants.token_url_path:
            problems.append('Unknown path {0}\n'.format(url.path))

        return ''.join(problems)

    def _get_cluster_manifest(self):
        """Return the parsed local cluster manifest."""
        return ClusterManifestParser.parse_local_manifest()

    def _get_private_key(self, filename):
        """Return the full text content of the private key file ``filename``."""
        with open(filename, 'r') as cert_file:
            return cert_file.read()

    def _acquire_token(self, resource):
        """Acquire a token for ``resource`` using the cluster's identity.

        Args:
            resource: Resource id requested by the caller, with or without
                the ``/.default`` suffix.

        Returns:
            A ``ManagedIdentityTokenResponse`` holding the access token, the
            token type and the resource exactly as requested.

        Raises:
            RuntimeError: If the manifest lists no managed identity or MSAL
                does not return an access token.
        """
        cluster_manifest = self._get_cluster_manifest()
        msi_settings = json.loads(cluster_manifest.settings['managedServiceIdentity'])
        if not msi_settings:
            raise RuntimeError('No managed identity found in the cluster manifest')
        # Assuming there is only 1 MSI associated with the cluster, get the first one
        msi_setting = list(msi_settings.values())[0]

        thumbprint = msi_setting['thumbprint']
        client_id = msi_setting['clientId']
        tenant_id = msi_setting['tenantId']

        authority = Constants.aad_login_endpoint.format(tenant_id)
        key = self._get_private_key(Constants.cert_location.format(thumbprint))

        app = msal.ConfidentialClientApplication(
            client_id,
            authority=authority,
            client_credential={"private_key": key, "thumbprint": thumbprint}
        )

        auth_result = app.acquire_token_for_client(scopes=[self._to_scope(resource)])

        # On failure MSAL returns a dict describing the error instead of
        # raising; surface it so the log shows the real cause.
        if 'access_token' not in auth_result:
            raise RuntimeError('Token acquisition failed: {0}: {1}'.format(
                auth_result.get('error'), auth_result.get('error_description')))

        res = ManagedIdentityTokenResponse()
        res.access_token = auth_result['access_token']
        res.token_type = auth_result['token_type']
        res.resource = resource

        return res

    def do_GET(self):
        """Handle a token request; see the class docstring for the contract."""
        import traceback

        try:
            msg = self._validate_request()

            resource = None
            if not msg:
                queries = self._parse_query(urlparse(self.path).query)
                resource = queries.get(Constants.query_resource)
                if not resource:
                    # A missing parameter is a client error, not a server error.
                    msg = 'Query parameter {0} is required\n'.format(Constants.query_resource)

            if msg:
                self.send_response(400)
                self.end_headers()
                self._write_body(msg)
                return

            res = self._acquire_token(resource)

            self.send_response(200)
            self.send_header('Content-type', 'application/json')
            self.end_headers()
            self._write_body(json.dumps(res.__dict__))
        except Exception:
            # The response points the caller at the server log, so make sure
            # the failure is actually recorded there.
            self.log_error('Failed to serve token request: %s', traceback.format_exc())
            self.send_response(500)
            self.end_headers()
            self._write_body("Internal server error, please see server log")

if __name__ == "__main__":
    # Listen on the loopback interface only, then serve requests until the
    # process is stopped.
    httpd = HTTPServer(
        (Constants.loopback_address, Constants.server_port),
        ManagedIdentityHandler,
    )
    print('Starting http server...')
    httpd.serve_forever()
10 changes: 10 additions & 0 deletions DeploymentCloud/Deployment.Common/scripts/msalmsiserverapp.service
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[Unit]
Description=hdinsight msalmsiserver app

[Service]
Type=idle

# systemd executes this command line directly, not through a shell, and
# supervises the process itself. A trailing "&" would not background anything;
# it would be handed to the script as a literal argument, so it is omitted.
ExecStart=/usr/bin/python /usr/hdinsight/msalmsiserver.py

[Install]
WantedBy=multi-user.target
1 change: 0 additions & 1 deletion DeploymentCloud/Deployment.Common/scripts/msiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import adal
import json
import SocketServer

from urlparse import urlparse
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
Expand Down
35 changes: 32 additions & 3 deletions DeploymentCloud/Deployment.Common/scripts/startmsiserverservice.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,35 @@
sudo hdfs dfs -get wasbs://scripts@$sparkBlobAccountName.blob.core.windows.net/msiserver.py /usr/hdinsight/msiserver.py
sudo hdfs dfs -get wasbs://scripts@$sparkBlobAccountName.blob.core.windows.net/msiserverapp.service /etc/systemd/system/msiserverapp.service
#!/bin/bash
# Installs and starts the two local MSI token services (msiserverapp and
# msalmsiserverapp): installs the msal package, replaces any previously
# deployed script/unit files with copies fetched from blob storage, then
# enables and starts both systemd services.
# $sparkBlobAccountName is not set in this script; it has to be provided by
# the caller/environment.

echo "Install Python Packages"
# NOTE(review): runs without sudo, unlike the other commands here — confirm the
# package ends up visible to the /usr/bin/python used by the service.
pip install msal

echo "Remove the existing files if they exist"
# -f keeps rm quiet when a file is absent (e.g. on first run).
sudo rm -f /usr/hdinsight/msiserver.py
sudo rm -f /etc/systemd/system/msiserverapp.service

sudo rm -f /usr/hdinsight/msalmsiserver.py
sudo rm -f /etc/systemd/system/msalmsiserverapp.service

echo "Download the files from HDFS/Blob storage"
# Server scripts go to /usr/hdinsight, unit files to /etc/systemd/system.
sudo hdfs dfs -copyToLocal wasbs://scriptactions@$sparkBlobAccountName.blob.core.windows.net/msiserver.py /usr/hdinsight/msiserver.py
sudo hdfs dfs -copyToLocal wasbs://scriptactions@$sparkBlobAccountName.blob.core.windows.net/msiserverapp.service /etc/systemd/system/msiserverapp.service

sudo hdfs dfs -copyToLocal wasbs://scriptactions@$sparkBlobAccountName.blob.core.windows.net/msalmsiserver.py /usr/hdinsight/msalmsiserver.py
sudo hdfs dfs -copyToLocal wasbs://scriptactions@$sparkBlobAccountName.blob.core.windows.net/msalmsiserverapp.service /etc/systemd/system/msalmsiserverapp.service

echo "Change the permission of the file"
sudo chmod 644 /etc/systemd/system/msiserverapp.service
sudo chmod 644 /etc/systemd/system/msalmsiserverapp.service

echo "Reload the systemd manager configuration to apply the changes"
# Required so systemd picks up the unit files copied above.
sudo systemctl daemon-reload

echo "Enable the service to start on boot"
sudo systemctl enable msiserverapp.service
# NOTE(review): msiserverapp.service is started again under "Start the service"
# below; this earlier start looks like a leftover line from the diff view —
# confirm against the actual file.
sudo systemctl start msiserverapp.service
sudo systemctl enable msalmsiserverapp.service

echo "Start the service"
sudo systemctl start msiserverapp.service
sudo systemctl start msalmsiserverapp.service

echo "Script execution completed"

0 comments on commit 44a6d4b

Please sign in to comment.