diff --git a/configs/config.yaml b/configs/config.yaml index 4666b6f..adca223 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -3,4 +3,7 @@ FLOWKIT_PYTHON_ENDPOINT: http://0.0.0.0:50052 FLOWKIT_PYTHON_WORKERS: 2 USE_SSL: False #SSL_CERT_PUBLIC_KEY_FILE: -#SSL_CERT_PRIVATE_KEY_FILE: \ No newline at end of file +#SSL_CERT_PRIVATE_KEY_FILE: +EXTRACT_CONFIG_FROM_AZURE_KEY_VAULT: False +# AZURE_MANAGED_IDENTITY_ID: +# AZURE_KEY_VAULT_NAME: \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 73174eb..24c658b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,8 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] dependencies = [ + "azure-identity >= 1.17.1,<2", + "azure-keyvault-secrets >= 4.8.0,<5", "fastapi >= 0.111.1,<1", "langchain >= 0.2.11,<1", "pydantic >= 2.8.2,<3", diff --git a/src/allie/flowkit/__main__.py b/src/allie/flowkit/__main__.py index 64d5f45..f54dc70 100644 --- a/src/allie/flowkit/__main__.py +++ b/src/allie/flowkit/__main__.py @@ -47,29 +47,34 @@ def substitute_empty_values(args): """Substitute the empty values with configuration values.""" host = args.host or urlparse(CONFIG.flowkit_python_endpoint).hostname port = args.port or urlparse(CONFIG.flowkit_python_endpoint).port - workers = args.workers or CONFIG.flowkit_python_workers - use_ssl = (args.use_ssl.lower() == "true") if args.use_ssl is not None else CONFIG.use_ssl - ssl_keyfile = args.ssl_keyfile or CONFIG.ssl_cert_private_key_file - ssl_certfile = args.ssl_certfile or CONFIG.ssl_cert_public_key_file - return host, port, workers, use_ssl, ssl_keyfile, ssl_certfile + CONFIG.flowkit_python_endpoint = f"http://{host}:{port}" + CONFIG.flowkit_python_workers = args.workers or CONFIG.flowkit_python_workers + CONFIG.use_ssl = (args.use_ssl.lower() == "true") if args.use_ssl is not None else CONFIG.use_ssl + CONFIG.ssl_cert_private_key_file = args.ssl_keyfile or CONFIG.ssl_cert_private_key_file + CONFIG.ssl_cert_public_key_file = args.ssl_certfile or CONFIG.ssl_cert_public_key_file + return def main(): """Run entrypoint for the FlowKit service.""" - # Parse the command line arguments - args = parse_cli_args() + if not CONFIG.extract_config_from_azure_key_vault: + # Parse the command line arguments + args = parse_cli_args() - # Substitute the empty values with configuration values - host, port, workers, use_ssl, ssl_keyfile, ssl_certfile = substitute_empty_values(args) + # Substitute the empty values with configuration values + substitute_empty_values(args) + + host = urlparse(CONFIG.flowkit_python_endpoint).hostname + port = urlparse(CONFIG.flowkit_python_endpoint).port # Run the service uvicorn.run( "allie.flowkit.flowkit_service:flowkit_service", host=host, port=port, - workers=workers, - ssl_keyfile=ssl_keyfile if use_ssl else None, - ssl_certfile=ssl_certfile if use_ssl else None, + workers=CONFIG.flowkit_python_workers, + ssl_keyfile=CONFIG.ssl_cert_private_key_file if CONFIG.use_ssl else None, + ssl_certfile=CONFIG.ssl_cert_public_key_file if CONFIG.use_ssl else None, ) diff --git a/src/allie/flowkit/config/_config.py b/src/allie/flowkit/config/_config.py index 81e9637..090e08f 100644 --- a/src/allie/flowkit/config/_config.py +++ b/src/allie/flowkit/config/_config.py @@ -22,9 +22,12 @@ """Module for reading the configuration settings from a YAML file.""" +import json import os from pathlib import Path +from azure.identity import ManagedIdentityCredential +from azure.keyvault.secrets import SecretClient import yaml @@ -67,6 +70,13 @@ def __init__(self): self.use_ssl = self._yaml.get("USE_SSL", False) self.ssl_cert_public_key_file = self._yaml.get("SSL_CERT_PUBLIC_KEY_FILE") self.ssl_cert_private_key_file = self._yaml.get("SSL_CERT_PRIVATE_KEY_FILE") + self.extract_config_from_azure_key_vault = self._yaml.get("EXTRACT_CONFIG_FROM_AZURE_KEY_VAULT", False) + self.azure_managed_identity_id = self._yaml.get("AZURE_MANAGED_IDENTITY_ID") + self.azure_key_vault_name = self._yaml.get("AZURE_KEY_VAULT_NAME") + + # If azure key vault configured, read values from vault + if self.extract_config_from_azure_key_vault: + self._get_config_from_azure_key_vault() # Check the mandatory configuration variables if not self.flowkit_python_api_key: @@ -106,6 +116,60 @@ def _load_config(self, config_path: str) -> dict: except FileNotFoundError: raise FileNotFoundError("Configuration file not found at the default location.") + def _get_config_from_azure_key_vault(self): + """Extract configuration from Azure Key Vault and set attributes.""" + # Check if all required environment variables are set + if not self.azure_managed_identity_id: + raise ValueError(f"Environment variable for {self.azure_managed_identity_id} is not set") + if not self.azure_key_vault_name: + raise ValueError(f"Environment variable for {self.azure_key_vault_name} is not set") + + # Create Key Vault URL + key_vault_url = f"https://{self.azure_key_vault_name}.vault.azure.net/" + + # Create Managed Identity credential + credential = ManagedIdentityCredential(client_id=self.azure_managed_identity_id) + + # Test the managed identity by getting a token + scope = "https://vault.azure.net/.default" + token = credential.get_token(scope) + if not token: + raise ValueError("Failed to get token from managed ID") + + # Create Azure Key Vault SecretClient + client = SecretClient(vault_url=key_vault_url, credential=credential) + + # List all secrets + secret_properties = client.list_properties_of_secrets() + + # Reflect on the fields of the Config class + global_config_fields = {field: value for field, value in self.__dict__.items() if not field.startswith("_")} + + # Iterate over all secrets + for secret_property in secret_properties: + secret_name = secret_property.name + secret_value = client.get_secret(secret_name).value + + # Format the field name to match the secret name format + formatted_field_name = secret_name.replace("_", "").upper() + + # Match secret names to Config class fields and set values + for field_name in global_config_fields: + # Remove underscores and convert to uppercase for matching + if field_name.replace("_", "").upper() == formatted_field_name: + # Handle different field types + field_type = type(getattr(self, field_name)) + if field_type is str: + setattr(self, field_name, secret_value) + elif field_type is bool: + setattr(self, field_name, secret_value.lower() == "true") + elif field_type is int: + setattr(self, field_name, int(secret_value)) + elif field_type is list: + setattr(self, field_name, json.loads(secret_value)) + else: + raise ValueError(f"Unsupported field type: {field_type}") + # Initialize the config object CONFIG = Config()