Skip to content

Commit

Permalink
feat: azure key vault (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurasgkadri98 authored Sep 2, 2024
1 parent 30fb28e commit f1dbd35
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 13 deletions.
5 changes: 4 additions & 1 deletion configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@ FLOWKIT_PYTHON_ENDPOINT: http://0.0.0.0:50052
FLOWKIT_PYTHON_WORKERS: 2
USE_SSL: False
#SSL_CERT_PUBLIC_KEY_FILE:
#SSL_CERT_PRIVATE_KEY_FILE:
#SSL_CERT_PRIVATE_KEY_FILE:
EXTRACT_CONFIG_FROM_AZURE_KEY_VAULT: False
# AZURE_MANAGED_IDENTITY_ID:
# AZURE_KEY_VAULT_NAME:
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ classifiers = [
"Programming Language :: Python :: 3.12",
]
dependencies = [
"azure-identity >= 1.17.1,<2",
"azure-keyvault-secrets >= 4.8.0,<5",
"fastapi >= 0.111.1,<1",
"langchain >= 0.2.11,<1",
"pydantic >= 2.8.2,<3",
Expand Down
29 changes: 17 additions & 12 deletions src/allie/flowkit/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,34 @@ def substitute_empty_values(args):
"""Substitute the empty values with configuration values."""
host = args.host or urlparse(CONFIG.flowkit_python_endpoint).hostname
port = args.port or urlparse(CONFIG.flowkit_python_endpoint).port
workers = args.workers or CONFIG.flowkit_python_workers
use_ssl = (args.use_ssl.lower() == "true") if args.use_ssl is not None else CONFIG.use_ssl
ssl_keyfile = args.ssl_keyfile or CONFIG.ssl_cert_private_key_file
ssl_certfile = args.ssl_certfile or CONFIG.ssl_cert_public_key_file
return host, port, workers, use_ssl, ssl_keyfile, ssl_certfile
CONFIG.flowkit_python_endpoint = f"http://{host}:{port}"
CONFIG.flowkit_python_workers = args.workers or CONFIG.flowkit_python_workers
CONFIG.use_ssl = (args.use_ssl.lower() == "true") if args.use_ssl is not None else CONFIG.use_ssl
CONFIG.ssl_cert_private_key_file = args.ssl_keyfile or CONFIG.ssl_cert_private_key_file
CONFIG.ssl_cert_public_key_file = args.ssl_certfile or CONFIG.ssl_cert_public_key_file
return


def main():
"""Run entrypoint for the FlowKit service."""
# Parse the command line arguments
args = parse_cli_args()
if not CONFIG.extract_config_from_azure_key_vault:
# Parse the command line arguments
args = parse_cli_args()

# Substitute the empty values with configuration values
host, port, workers, use_ssl, ssl_keyfile, ssl_certfile = substitute_empty_values(args)
# Substitute the empty values with configuration values
substitute_empty_values(args)

host = urlparse(CONFIG.flowkit_python_endpoint).hostname
port = urlparse(CONFIG.flowkit_python_endpoint).port

# Run the service
uvicorn.run(
"allie.flowkit.flowkit_service:flowkit_service",
host=host,
port=port,
workers=workers,
ssl_keyfile=ssl_keyfile if use_ssl else None,
ssl_certfile=ssl_certfile if use_ssl else None,
workers=CONFIG.flowkit_python_workers,
ssl_keyfile=CONFIG.ssl_cert_private_key_file if CONFIG.use_ssl else None,
ssl_certfile=CONFIG.ssl_cert_public_key_file if CONFIG.use_ssl else None,
)


Expand Down
64 changes: 64 additions & 0 deletions src/allie/flowkit/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@

"""Module for reading the configuration settings from a YAML file."""

import json
import os
from pathlib import Path

from azure.identity import ManagedIdentityCredential
from azure.keyvault.secrets import SecretClient
import yaml


Expand Down Expand Up @@ -67,6 +70,13 @@ def __init__(self):
self.use_ssl = self._yaml.get("USE_SSL", False)
self.ssl_cert_public_key_file = self._yaml.get("SSL_CERT_PUBLIC_KEY_FILE")
self.ssl_cert_private_key_file = self._yaml.get("SSL_CERT_PRIVATE_KEY_FILE")
self.extract_config_from_azure_key_vault = self._yaml.get("EXTRACT_CONFIG_FROM_AZURE_KEY_VAULT", False)
self.azure_managed_identity_id = self._yaml.get("AZURE_MANAGED_IDENTITY_ID")
self.azure_key_vault_name = self._yaml.get("AZURE_KEY_VAULT_NAME")

# If azure key vault configured, read values from vault
if self.extract_config_from_azure_key_vault:
self._get_config_from_azure_key_vault()

# Check the mandatory configuration variables
if not self.flowkit_python_api_key:
Expand Down Expand Up @@ -106,6 +116,60 @@ def _load_config(self, config_path: str) -> dict:
except FileNotFoundError:
raise FileNotFoundError("Configuration file not found at the default location.")

def _get_config_from_azure_key_vault(self):
"""Extract configuration from Azure Key Vault and set attributes."""
# Check if all required environment variables are set
if not self.azure_managed_identity_id:
raise ValueError(f"Environment variable for {self.azure_managed_identity_id} is not set")
if not self.azure_key_vault_name:
raise ValueError(f"Environment variable for {self.azure_key_vault_name} is not set")

# Create Key Vault URL
key_vault_url = f"https://{self.azure_key_vault_name}.vault.azure.net/"

# Create Managed Identity credential
credential = ManagedIdentityCredential(client_id=self.azure_managed_identity_id)

# Test the managed identity by getting a token
scope = "https://vault.azure.net/.default"
token = credential.get_token(scope)
if not token:
raise ValueError("Failed to get token from managed ID")

# Create Azure Key Vault SecretClient
client = SecretClient(vault_url=key_vault_url, credential=credential)

# List all secrets
secret_properties = client.list_properties_of_secrets()

# Reflect on the fields of the Config class
global_config_fields = {field: value for field, value in self.__dict__.items() if not field.startswith("_")}

# Iterate over all secrets
for secret_property in secret_properties:
secret_name = secret_property.name
secret_value = client.get_secret(secret_name).value

# Format the field name to match the secret name format
formatted_field_name = secret_name.replace("_", "").upper()

# Match secret names to Config class fields and set values
for field_name in global_config_fields:
# Remove underscores and convert to uppercase for matching
if field_name.replace("_", "").upper() == formatted_field_name:
# Handle different field types
field_type = type(getattr(self, field_name))
if field_type is str:
setattr(self, field_name, secret_value)
elif field_type is bool:
setattr(self, field_name, secret_value.lower() == "true")
elif field_type is int:
setattr(self, field_name, int(secret_value))
elif field_type is list:
setattr(self, field_name, json.loads(secret_value))
else:
raise ValueError(f"Unsupported field type: {field_type}")


# Initialize the config object
CONFIG = Config()

0 comments on commit f1dbd35

Please sign in to comment.