From 58a0b5a984578f72ef7e229c02c7b91ae0e3dffb Mon Sep 17 00:00:00 2001 From: Dillon Walls Date: Tue, 10 Sep 2024 16:35:55 -0400 Subject: [PATCH] fixup! add support for mongodb Client Side Field Level Encryption (CSFLE) --- ming/config.py | 2 + ming/datastore.py | 18 ++++- ming/encryption.py | 133 +++++++++++++++++++++++++++++++++ ming/metadata.py | 3 +- ming/tests/test_datastore.py | 24 ++++++ ming/tests/test_declarative.py | 39 ++++++++++ ming/validators.py | 57 ++++++++++++++ setup.py | 1 + 8 files changed, 272 insertions(+), 5 deletions(-) create mode 100644 ming/encryption.py create mode 100644 ming/validators.py diff --git a/ming/config.py b/ming/config.py index f4a00fd..c266aeb 100644 --- a/ming/config.py +++ b/ming/config.py @@ -24,6 +24,7 @@ def configure(**kwargs): def configure_from_nested_dict(config): try: from formencode import schema, validators + import ming.validators as ming_validators except ImportError: raise MingConfigError("Need to install FormEncode to use ``ming.configure``") @@ -36,6 +37,7 @@ class DatastoreSchema(schema.Schema): auto_ensure_indexes = validators.StringBool(if_missing=True) # pymongo tz_aware = validators.Bool(if_missing=False) + encryption = ming_validators.EncryptionConfigValidator(if_missing=None) datastores = {} for name, datastore in config.items(): diff --git a/ming/datastore.py b/ming/datastore.py index 1117ac8..d690406 100644 --- a/ming/datastore.py +++ b/ming/datastore.py @@ -1,7 +1,7 @@ import time import logging from threading import Lock -from typing import Union +from typing import Union, TYPE_CHECKING import urllib from pymongo import MongoClient from pymongo.database import Database @@ -11,6 +11,9 @@ from . import mim from . import exc +if TYPE_CHECKING: + from . import encryption + Conn = Union[mim.Connection, MongoClient] @@ -74,6 +77,8 @@ def create_datastore(uri, **kwargs): if database.startswith("/"): database = database[1:] + encryption_config: encryption.EncryptionConfig = kwargs.pop('encryption', None) + if uri: # User provided a valid connection URL. if bind: @@ -85,7 +90,7 @@ def create_datastore(uri, **kwargs): # Create engine without connection. bind = create_engine(**kwargs) - return DataStore(bind, database) + return DataStore(bind, database, encryption_config) class Engine: @@ -135,6 +140,7 @@ def connect(self): try: with self._lock: if self._conn is None: + # NOTE: Runs MongoClient/EncryptionClient self._conn = self._Connection( *self._conn_args, **self._conn_kwargs) else: @@ -159,10 +165,10 @@ class DataStore: :func:`.create_datastore` function. """ - def __init__(self, bind, name, authenticate=None): + def __init__(self, bind, name, encryption_config: encryption.EncryptionConfig = None): self.bind = bind self.name = name - self._authenticate = authenticate + self._encryption_config = encryption_config self._db = None def __repr__(self): # pragma no cover @@ -191,3 +197,7 @@ def db(self) -> Database: self._db = self.bind[self.name] return self._db + + @property + def encryption(self) -> encryption.EncryptionConfig | None: + return self._encryption_config diff --git a/ming/encryption.py b/ming/encryption.py new file mode 100644 index 0000000..1673023 --- /dev/null +++ b/ming/encryption.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar, Generic + +from cachetools import cached, RRCache + +from pymongo.encryption import ClientEncryption, Algorithm +from pymongo.errors import EncryptionError + + +if TYPE_CHECKING: + from pymongo import MongoClient + from ming import Document + import ming.datastore + + +class EncryptionConfig: + + def __init__(self, encryption_config): + self._encryption_config = encryption_config + + @property + def kms_providers(self) -> str: + return self._encryption_config.get('kms_providers') + + @property + def provider_options(self) -> str: + return self._encryption_config.get('provider_options') + + @property + def key_vault_namespace(self) -> str: + return self._encryption_config.get('key_vault_namespace') + + +T = TypeVar('T') + + +class DecryptedField(Generic[T]): + + def __init__(self, field_type: type[T], encrypted_field: str): + self.field_type = field_type + self.encrypted_field = encrypted_field + + def __get__(self, instance: EncryptedDocumentMixin, owner) -> T: + return instance.decr(getattr(instance, self.encrypted_field)) + + def __set__(self, instance: EncryptedDocumentMixin, value: T): + if not isinstance(value, self.field_type): + raise TypeError(f'not {self.field_type}, got {value!r}') + setattr(instance, self.encrypted_field, instance.encr(value)) + + +class EncryptedDocumentMixin: + + @classmethod + @cached(RRCache(maxsize=99)) # needs to be per datastore, so we pass that as a param + def encryptor(cls, ming_ds: ming.datastore.DataStore): + conn: MongoClient = ming_ds.conn + kms_providers = {"local": {"key": ming_ds.encryption_key}} + encryption = ClientEncryption(kms_providers, ming_ds.encr_data_key_vault, + conn, conn.codec_options) + return encryption + + @classmethod + def make_data_key(cls): + ming_ds: ming.datastore.DataStore = cls.m.session.bind + # index recommended by mongodb docs: + key_vault_db_name, key_vault_coll_name = ming_ds.encr_data_key_vault.split('.') + key_vault_coll = ming_ds.conn[key_vault_db_name][key_vault_coll_name] + key_vault_coll.create_index("keyAltNames", unique=True, + partialFilterExpression={"keyAltNames": {"$exists": True}}) + cls.encryptor(ming_ds).create_data_key('local', key_alt_names=[ming_ds.encr_data_key_name]) + + # cls.encryptor(ming_ds).create_data_key('local', **ming_ds['provider_options']['local']) + # cls.encryptor(ming_ds).create_data_key('aws', **ming_ds['provider_options']['aws']) + + @classmethod + def encr(cls, s: str | None, _first_attempt=True) -> bytes | None: + if s is None: + return None + try: + ming_ds: ming.datastore.DataStore = cls.m.session.bind + return cls.encryptor(ming_ds).encrypt(s, + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, + key_alt_name=cls.ming_ds.encr_data_key_name) + except EncryptionError as e: + if _first_attempt and 'not all keys requested were satisfied' in str(e): + cls.make_data_key() + return cls.encr(s, _first_attempt=False) + else: + raise + + @classmethod + def decr(cls, b: bytes | None) -> str | None: + if b is None: + return None + return cls.encryptor(cls.m.session.bind).decrypt(b) + + @classmethod + def decrypted_field_names(cls) -> list[str]: + return [fld.replace('_encrypted', '') + for fld in cls.encrypted_field_names()] + + @classmethod + def encrypted_field_names(cls) -> list[str]: + return [fld for fld in dir(cls) + if fld.endswith('_encrypted')] + + @classmethod + def encrypt_some_fields(cls, data: dict) -> dict: + encrypted_data = data.copy() + for fld in cls.decrypted_field_names(): + if fld in encrypted_data: + val = encrypted_data.pop(fld) + encrypted_data[f'{fld}_encrypted'] = cls.encr(val) + return encrypted_data + + def decrypt_some_fields(self) -> dict: + # useful for json, removes encrypted fields and uses decrypted forms + decrypted_data = dict(self) + for k, v in self.items(): + if k.endswith('_encrypted'): + del decrypted_data[k] + k_decrypted = k.replace('_encrypted', '') + decrypted_data[k_decrypted] = getattr(self, k_decrypted) + return decrypted_data + + @classmethod + def make_encr(cls, data: dict) -> Document: + # wrapper around regular ming .make() + data_with_encryption = cls.encrypt_some_fields(data) + return cls.make(data_with_encryption) + diff --git a/ming/metadata.py b/ming/metadata.py index 9e6476e..b0c0a3f 100644 --- a/ming/metadata.py +++ b/ming/metadata.py @@ -10,6 +10,7 @@ from .base import Object from .utils import fixup_index, LazyProperty from .exc import MongoGone +from .encryption import EncryptedDocumentMixin log = logging.getLogger(__name__) @@ -401,7 +402,7 @@ def __delete__(self, inst): del inst[self.name] -class _Document(Object): +class _Document(Object, EncryptedDocumentMixin): def __init__(self, data=None, skip_from_bson=False): if data is None: diff --git a/ming/tests/test_datastore.py b/ming/tests/test_datastore.py index ba38ad6..265ef65 100644 --- a/ming/tests/test_datastore.py +++ b/ming/tests/test_datastore.py @@ -175,6 +175,30 @@ def test_configure_optional_params(self): assert session.bind.conn is not None assert session.bind.db is not None + def test_configure_encryption(self): + # Generate a base64 encoded random string: + # ```python + # from base64 import b64encode; import secrets; + # b64encode(secrets.token_hex().encode('utf8')).decode('utf8') + # ``` + encryption_key = 'ODdlZGMzNjZlZWFmYTVlMDhhYWM0ZTBhNTQ5ZTE2YzQ3OWZmMzA5MDUxMDhhOTVlN2UyYTMzNzBkZDE5OGRhMg==' + ming.configure(**{ + 'ming.main.uri': 'mongodb://localhost:27017/test_db', + 'ming.main.replicaSet': 'foobar', + 'ming.main.foo.bar': 'foobar', + 'ming.main.encryption.kms_providers.local.key': encryption_key, + 'ming.main.encryption.key_vault_namespace': 'encryption.collectionName', + 'ming.main.encryption.provider_options.local.key_alt_names': ['datakeyName'], + # 'ming.main.encryption.provider_options.aws.key_alt_names': ['datakeyName'], + # 'ming.main.encryption.provider_options.aws.master_key': ['datakeyName'], + }) + session = Session.by_name('main') + assert session.bind.conn is not None, session.bind.conn + assert session.bind.db is not None, session.bind.db + assert session.bind.encryption.kms_providers == {'local': {'key': encryption_key}}, session.bind.encryption.kms_providers + assert session.bind.encryption.key_vault_namespace == 'encryption.collectionName', session.bind.encryption.key_vault_namespace + assert session.bind.encryption.provider_options == {'local': {'key_alt_names': ['datakeyName']}}, session.bind.encryption.provider_options + def test_no_kwargs_with_bind(self): self.assertRaises( ming.exc.MingConfigError, diff --git a/ming/tests/test_declarative.py b/ming/tests/test_declarative.py index 7be426d..cc50aa7 100644 --- a/ming/tests/test_declarative.py +++ b/ming/tests/test_declarative.py @@ -9,6 +9,7 @@ from ming.base import Cursor from ming.datastore import create_datastore from ming.declarative import Document +from ming.encryption import DecryptedField from ming.metadata import Field, Index from ming import schema as S from ming.odm.odmsession import ODMSession, ThreadLocalODMSession @@ -170,6 +171,44 @@ def test_field(self): self.assertRaises(AttributeError, getattr, doc, 'a') self.assertEqual(self.session.count(self.TestDoc), 1) + +class TestDocumentEncryptionReal(TestCase): + DATASTORE = f"mongodb://localhost/test_ming_TestDocumentReal_{os.getpid()}?serverSelectionTimeoutMS=100" + + def setUp(self): + self.datastore = create_datastore(self.DATASTORE) + self.session = Session(bind=self.datastore) + + class TestDoc(Document): + class __mongometa__: + name='test_doc' + session = self.session + indexes = [ ('a',) ] + _id = Field(S.Anything) + name = DecryptedField(str, 'name_encrypted') + name_encrypted = Field(S.Binary) + + self.TestDoc = TestDoc + + def tearDown(self): + self.TestDoc.m.remove() + # FIXME: teardown/ remove the encryption collection. likely in a different database + + def test_field(self): + doc = self.TestDoc(dict(_id=1, a=1, b=dict(a=5))) + doc.m.save() + + self.assertEqual(doc.a, 1) + self.assertEqual(doc.b, dict(a=5)) + doc.a = 5 + self.assertEqual(doc, dict(_id=1, a=5, b=dict(a=5))) + del doc.a + self.assertEqual(doc, dict(_id=1, b=dict(a=5))) + self.assertRaises(AttributeError, getattr, doc, 'c') + self.assertRaises(AttributeError, getattr, doc, 'a') + self.assertEqual(self.session.count(self.TestDoc), 1) + + class TestIndexes(TestCase): def setUp(self): diff --git a/ming/validators.py b/ming/validators.py new file mode 100644 index 0000000..003205f --- /dev/null +++ b/ming/validators.py @@ -0,0 +1,57 @@ + +from .exc import MingConfigError +from .encryption import EncryptionConfig + +try: + from formencode import schema, validators +except ImportError: + raise MingConfigError("Need to install FormEncode to use ``ming.encryption`` package") + + +class EncryptionConfigValidator(validators.FancyValidator): + """ + Password string validation to refrain from usage of name, username or email id (full or partial) as its substrings + """ + accept_iterator = True + + VALID_KMS_PROVIDERS = ('local', 'aws', 'azure', 'gcp', 'kmip') + + messages = dict( + InvalidKMSProvider=f"Invalid KMS Provider %(provider). Valid options are: {', '.join(VALID_KMS_PROVIDERS)}" + ) + + def _convert_to_python(self, field_dict, state): + if not field_dict: + return None + return EncryptionConfig(field_dict) + + def _validate_python(self, config: EncryptionConfig, state): + if not config: + return # no encryption settings + + def validate_inner(): + error_dict = {} + if config.kms_providers: + if not config.key_vault_namespace: + error_dict['key_vault_namespace'] = validators.Invalid(self.message('InvalidKMSProvider', state), config, state) + if not config.provider_options: + error_dict['provider_options'] = validators.Invalid(self.message('InvalidKMSProvider', state), config, state) + + used_kms_providers = set() + for key, settings in config.kms_providers.items(): + if key not in self.VALID_KMS_PROVIDERS: + error_dict['kms_providers'] = validators.Invalid(self.message('InvalidKMSProvider', state, provider=key), config, state) + break + used_kms_providers.add(key) + + if error_dict: + return error_dict + + if 'local' in used_kms_providers: + if 'key' not in config.kms_providers['local']: + error_dict['kms_providers'] = validators.Invalid(f'Local KMS Provider requires a key', config, state) + + error_dict = validate_inner() + + if error_dict: + raise validators.Invalid(f'Invalid Encryption Settings', config, state, error_dict=error_dict) diff --git a/setup.py b/setup.py index 29a3757..8fe58f0 100644 --- a/setup.py +++ b/setup.py @@ -34,6 +34,7 @@ include_package_data=True, zip_safe=True, install_requires=[ + "cachetools", "pymongo[encryption]", "pytz", ],