diff --git a/ming/encryption.py b/ming/encryption.py index 3d1819c..4f39b80 100644 --- a/ming/encryption.py +++ b/ming/encryption.py @@ -32,8 +32,11 @@ def clean_config(cls, config: dict) -> dict: if config.get('provider_options', None): for provider, values in list((config['provider_options'] or dict()).items()): if 'key_alt_names' in values and not isinstance(values['key_alt_names'], list): - key_alt_names = [s.strip() for s in values['key_alt_names'].split(',') if s] - config['provider_options'][provider]['key_alt_names'] = key_alt_names + try: + config['provider_options'][provider]['key_alt_names'] = json.loads(values['key_alt_names']) + except json.JSONDecodeError: + key_alt_names = [s.strip() for s in values['key_alt_names'].split(',') if s] + config['provider_options'][provider]['key_alt_names'] = key_alt_names return config diff --git a/ming/tests/test_encryption.py b/ming/tests/test_encryption.py index b62c888..948212f 100644 --- a/ming/tests/test_encryption.py +++ b/ming/tests/test_encryption.py @@ -72,6 +72,42 @@ def test_validation_key_alt_names(self): encryption = ming.Session.by_name('maindb').bind.encryption self.assertEqual(encryption.provider_options, {'local': {'key_alt_names': ['datakey_test1']}}) + def test_validation_key_alt_names2(self): + config_str = f""" + ming.maindb.uri = mim://host/maindb + ming.maindb.encryption.kms_providers.local.key = {self.LOCAL_KEY} + ming.maindb.encryption.key_vault_namespace = encryption_test.dataKeyVault + ming.maindb.encryption.provider_options.local.key_alt_names = ["datakey_test1", "datakey_test2"] + """ + + self._parse_config(config_str) + encryption = ming.Session.by_name('maindb').bind.encryption + self.assertEqual(encryption.provider_options, {'local': {'key_alt_names': ['datakey_test1', 'datakey_test2']}}) + + def test_validation_key_alt_names3(self): + config_str = f""" + ming.maindb.uri = mim://host/maindb + ming.maindb.encryption.kms_providers.local.key = {self.LOCAL_KEY} + ming.maindb.encryption.key_vault_namespace = encryption_test.dataKeyVault + ming.maindb.encryption.provider_options.local.key_alt_names = datakey_test1, datakey_test2 + """ + + self._parse_config(config_str) + encryption = ming.Session.by_name('maindb').bind.encryption + self.assertEqual(encryption.provider_options, {'local': {'key_alt_names': ['datakey_test1', 'datakey_test2']}}) + + def test_validation_key_alt_names4(self): + config_str = f""" + ming.maindb.uri = mim://host/maindb + ming.maindb.encryption.kms_providers.local.key = {self.LOCAL_KEY} + ming.maindb.encryption.key_vault_namespace = encryption_test.dataKeyVault + ming.maindb.encryption.provider_options.local.key_alt_names = "datakey_test1", "datakey_test2" + """ + + self._parse_config(config_str) + encryption = ming.Session.by_name('maindb').bind.encryption + self.assertEqual(encryption.provider_options, {'local': {'key_alt_names': ['datakey_test1', 'datakey_test2']}}) + def test_validation_empty(self): self._parse_config(f'''ming.maindb.uri = mim://host/maindb''') diff --git a/ming/validators.py b/ming/validators.py index 03e78ac..e40233c 100644 --- a/ming/validators.py +++ b/ming/validators.py @@ -1,4 +1,5 @@ +import json from .exc import MingConfigError from .encryption import EncryptionConfig @@ -55,10 +56,23 @@ class EncryptionConfigValidator(validators.FancyValidator): " (https://pymongo.readthedocs.io/en/stable/api/pymongo/encryption.html#pymongo.encryption.ClientEncryption.create_data_key)."), ) - def _convert_to_python(self, field_dict, state): - if not field_dict: + def _convert_to_python(self, config: dict, state): + if not config: return None - return EncryptionConfig(field_dict) + + # ensure key_alt_names is a list + provider_options = config.get('provider_options', None) or dict() + if provider_options: + for provider, options in list(provider_options.items()): + if 'key_alt_names' in options: + if not isinstance(options['key_alt_names'], list): + try: + config['provider_options'][provider]['key_alt_names'] = json.loads(options['key_alt_names']) + except json.JSONDecodeError: + key_alt_names = [s.strip(" ][\"'\t\r\n") for s in options['key_alt_names'].split(',') if s] + config['provider_options'][provider]['key_alt_names'] = key_alt_names + + return EncryptionConfig(config) def _validate_python(self, encryption_config: EncryptionConfig, state): if not encryption_config: