Skip to content

Commit

Permalink
fixup! add support for mongodb Client Side Field Level Encryption (CS…
Browse files Browse the repository at this point in the history
…FLE)
  • Loading branch information
dill0wn committed Sep 10, 2024
1 parent 52b60d6 commit 58a0b5a
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 5 deletions.
2 changes: 2 additions & 0 deletions ming/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def configure(**kwargs):
def configure_from_nested_dict(config):
try:
from formencode import schema, validators
import ming.validators as ming_validators
except ImportError:
raise MingConfigError("Need to install FormEncode to use ``ming.configure``")

Expand All @@ -36,6 +37,7 @@ class DatastoreSchema(schema.Schema):
auto_ensure_indexes = validators.StringBool(if_missing=True)
# pymongo
tz_aware = validators.Bool(if_missing=False)
encryption = ming_validators.EncryptionConfigValidator(if_missing=None)

datastores = {}
for name, datastore in config.items():
Expand Down
18 changes: 14 additions & 4 deletions ming/datastore.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
import logging
from threading import Lock
from typing import Union
from typing import Union, TYPE_CHECKING
import urllib
from pymongo import MongoClient
from pymongo.database import Database
Expand All @@ -11,6 +11,9 @@
from . import mim
from . import exc

if TYPE_CHECKING:
from . import encryption

Conn = Union[mim.Connection, MongoClient]


Expand Down Expand Up @@ -74,6 +77,8 @@ def create_datastore(uri, **kwargs):
if database.startswith("/"):
database = database[1:]

encryption_config: encryption.EncryptionConfig = kwargs.pop('encryption', None)

if uri:
# User provided a valid connection URL.
if bind:
Expand All @@ -85,7 +90,7 @@ def create_datastore(uri, **kwargs):
# Create engine without connection.
bind = create_engine(**kwargs)

return DataStore(bind, database)
return DataStore(bind, database, encryption_config)


class Engine:
Expand Down Expand Up @@ -135,6 +140,7 @@ def connect(self):
try:
with self._lock:
if self._conn is None:
# NOTE: Runs MongoClient/EncryptionClient
self._conn = self._Connection(
*self._conn_args, **self._conn_kwargs)
else:
Expand All @@ -159,10 +165,10 @@ class DataStore:
:func:`.create_datastore` function.
"""

def __init__(self, bind, name,
             encryption_config: "encryption.EncryptionConfig | None" = None):
    """Store the connection source, database name and encryption settings.

    :param bind: object that is indexed by database name to obtain the
        database handle (see the ``db`` property).
    :param name: name of the database.
    :param encryption_config: optional encryption settings; returned
        unchanged by the ``encryption`` property.
    """
    # The annotation above is quoted on purpose: ``encryption`` is imported
    # only under ``TYPE_CHECKING`` and this module has no
    # ``from __future__ import annotations``, so an unquoted annotation is
    # evaluated when the method is defined and raises NameError.
    self.bind = bind
    self.name = name
    self._encryption_config = encryption_config
    self._db = None

def __repr__(self): # pragma no cover
Expand Down Expand Up @@ -191,3 +197,7 @@ def db(self) -> Database:

self._db = self.bind[self.name]
return self._db

@property
def encryption(self) -> "encryption.EncryptionConfig | None":
    """Return the encryption settings passed to the constructor, or ``None``."""
    # Quoted annotation: ``encryption`` is imported only under
    # ``TYPE_CHECKING``, so evaluating it at definition time would raise
    # NameError.
    return self._encryption_config
133 changes: 133 additions & 0 deletions ming/encryption.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Generic

from cachetools import cached, RRCache

from pymongo.encryption import ClientEncryption, Algorithm
from pymongo.errors import EncryptionError


if TYPE_CHECKING:
from pymongo import MongoClient
from ming import Document
import ming.datastore


class EncryptionConfig:
    """Read-only accessor over the ``encryption`` section of a datastore config.

    Wraps a mapping and exposes three of its keys as properties. Each
    property returns ``None`` when its key is absent.
    """

    def __init__(self, encryption_config):
        """
        :param encryption_config: mapping holding the encryption settings;
            only its ``.get()`` method is used.
        """
        self._encryption_config = encryption_config

    @property
    def kms_providers(self) -> dict | None:
        """Value stored under ``'kms_providers'`` (a nested mapping keyed by
        provider name, e.g. ``{'local': {'key': ...}}``), or ``None``."""
        return self._encryption_config.get('kms_providers')

    @property
    def provider_options(self) -> dict | None:
        """Value stored under ``'provider_options'`` (a nested mapping keyed
        by provider name), or ``None``."""
        return self._encryption_config.get('provider_options')

    @property
    def key_vault_namespace(self) -> str | None:
        """Value stored under ``'key_vault_namespace'``, or ``None``."""
        return self._encryption_config.get('key_vault_namespace')


T = TypeVar('T')


class DecryptedField(Generic[T]):
    """Descriptor exposing the plaintext view of an encrypted attribute.

    Reading calls ``instance.decr`` on the attribute named by
    ``encrypted_field``; writing type-checks the value, calls
    ``instance.encr`` on it and stores the result in that attribute.
    """

    def __init__(self, field_type: type[T], encrypted_field: str):
        """
        :param field_type: type that assigned (non-``None``) values must be
            an instance of.
        :param encrypted_field: name of the attribute holding the encrypted
            value.
        """
        self.field_type = field_type
        self.encrypted_field = encrypted_field

    def __get__(self, instance: EncryptedDocumentMixin | None, owner=None):
        """Return the decrypted value, or the descriptor itself when
        accessed on the class rather than an instance."""
        if instance is None:
            # Class-level access: there is no document to decrypt from.
            return self
        return instance.decr(getattr(instance, self.encrypted_field))

    def __set__(self, instance: EncryptedDocumentMixin, value: T | None):
        """Encrypt ``value`` and store it in the encrypted attribute.

        ``None`` is accepted and handed to ``instance.encr`` unchanged, so a
        field can be cleared.

        :raises TypeError: if ``value`` is neither ``None`` nor an instance
            of ``field_type``.
        """
        if value is not None and not isinstance(value, self.field_type):
            raise TypeError(f'not {self.field_type}, got {value!r}')
        setattr(instance, self.encrypted_field, instance.encr(value))


class EncryptedDocumentMixin:
    """Helpers for documents that keep some fields encrypted.

    Naming convention used throughout: the encrypted value of ``<field>``
    lives under ``<field>_encrypted``. The datastore is taken from
    ``cls.m.session.bind``.
    """

    @classmethod
    @cached(RRCache(maxsize=99))  # needs to be per datastore, so we pass that as a param
    def encryptor(cls, ming_ds: ming.datastore.DataStore):
        """Return a ``ClientEncryption`` for ``ming_ds``; results are memoized
        by call arguments."""
        # NOTE(review): ``encryption_key`` and ``encr_data_key_vault`` are read
        # straight off the datastore; confirm DataStore provides them, or
        # switch to the values held by ``ming_ds.encryption``.
        conn: MongoClient = ming_ds.conn
        kms_providers = {"local": {"key": ming_ds.encryption_key}}
        return ClientEncryption(kms_providers, ming_ds.encr_data_key_vault,
                                conn, conn.codec_options)

    @classmethod
    def make_data_key(cls):
        """Create the key-vault index and a ``local`` data key for this
        document's datastore."""
        ming_ds: ming.datastore.DataStore = cls.m.session.bind
        # Split on the first dot only: the part before it is the database
        # name and the remainder (which may itself contain dots) is the
        # collection name.
        key_vault_db_name, key_vault_coll_name = ming_ds.encr_data_key_vault.split('.', 1)
        key_vault_coll = ming_ds.conn[key_vault_db_name][key_vault_coll_name]
        # index recommended by mongodb docs:
        key_vault_coll.create_index("keyAltNames", unique=True,
                                    partialFilterExpression={"keyAltNames": {"$exists": True}})
        cls.encryptor(ming_ds).create_data_key('local', key_alt_names=[ming_ds.encr_data_key_name])
        # TODO: take the create_data_key arguments from the configured
        # provider options (e.g. provider_options['local'] / ['aws'])
        # instead of hard-coding the 'local' provider.

    @classmethod
    def encr(cls, s: str | None, _first_attempt=True) -> bytes | None:
        """Encrypt ``s`` with the deterministic AEAD algorithm.

        ``None`` is passed through. If the first attempt fails because the
        data key is missing, the key is created and the call retried once.

        :raises EncryptionError: on any other failure, or if the retry fails.
        """
        if s is None:
            return None
        ming_ds: ming.datastore.DataStore = cls.m.session.bind
        try:
            # The key name comes from the local ``ming_ds``; the class has no
            # ``ming_ds`` attribute.
            return cls.encryptor(ming_ds).encrypt(
                s,
                Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic,
                key_alt_name=ming_ds.encr_data_key_name)
        except EncryptionError as e:
            if _first_attempt and 'not all keys requested were satisfied' in str(e):
                cls.make_data_key()
                return cls.encr(s, _first_attempt=False)
            raise

    @classmethod
    def decr(cls, b: bytes | None) -> str | None:
        """Decrypt ``b``; ``None`` is passed through."""
        if b is None:
            return None
        return cls.encryptor(cls.m.session.bind).decrypt(b)

    @classmethod
    def decrypted_field_names(cls) -> list[str]:
        """Names from :meth:`encrypted_field_names` with the trailing
        ``_encrypted`` removed."""
        suffix_len = len('_encrypted')
        # Slice off only the suffix; str.replace would also remove an
        # ``_encrypted`` occurring earlier in the name.
        return [fld[:-suffix_len] for fld in cls.encrypted_field_names()]

    @classmethod
    def encrypted_field_names(cls) -> list[str]:
        """Names of all class attributes ending in ``_encrypted``."""
        return [fld for fld in dir(cls)
                if fld.endswith('_encrypted')]

    @classmethod
    def encrypt_some_fields(cls, data: dict) -> dict:
        """Return a copy of ``data`` in which every key that names a
        decrypted field is replaced by ``<key>_encrypted`` holding the
        encrypted value. ``data`` itself is not modified."""
        encrypted_data = data.copy()
        for fld in cls.decrypted_field_names():
            if fld in encrypted_data:
                val = encrypted_data.pop(fld)
                encrypted_data[f'{fld}_encrypted'] = cls.encr(val)
        return encrypted_data

    def decrypt_some_fields(self) -> dict:
        """Return a plain dict of this document in which each
        ``<field>_encrypted`` key is replaced by ``<field>`` holding
        ``getattr(self, '<field>')``."""
        # useful for json, removes encrypted fields and uses decrypted forms
        suffix = '_encrypted'
        decrypted_data = dict(self)
        # Iterate over a snapshot of the keys because the dict is modified
        # inside the loop.
        for k in list(decrypted_data):
            if k.endswith(suffix):
                del decrypted_data[k]
                k_decrypted = k[:-len(suffix)]
                decrypted_data[k_decrypted] = getattr(self, k_decrypted)
        return decrypted_data

    @classmethod
    def make_encr(cls, data: dict) -> Document:
        """Wrapper around the regular ming ``.make()`` that encrypts the
        relevant fields of ``data`` first."""
        data_with_encryption = cls.encrypt_some_fields(data)
        return cls.make(data_with_encryption)

3 changes: 2 additions & 1 deletion ming/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .base import Object
from .utils import fixup_index, LazyProperty
from .exc import MongoGone
from .encryption import EncryptedDocumentMixin

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -401,7 +402,7 @@ def __delete__(self, inst):
del inst[self.name]


class _Document(Object):
class _Document(Object, EncryptedDocumentMixin):

def __init__(self, data=None, skip_from_bson=False):
if data is None:
Expand Down
24 changes: 24 additions & 0 deletions ming/tests/test_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,30 @@ def test_configure_optional_params(self):
assert session.bind.conn is not None
assert session.bind.db is not None

def test_configure_encryption(self):
    """Configure a datastore with dotted ``encryption.*`` keys and check
    that they are exposed as nested values on ``session.bind.encryption``."""
    # Generate a base64 encoded random string:
    # ```python
    # from base64 import b64encode; import secrets;
    # b64encode(secrets.token_hex().encode('utf8')).decode('utf8')
    # ```
    encryption_key = 'ODdlZGMzNjZlZWFmYTVlMDhhYWM0ZTBhNTQ5ZTE2YzQ3OWZmMzA5MDUxMDhhOTVlN2UyYTMzNzBkZDE5OGRhMg=='
    ming.configure(**{
        'ming.main.uri': 'mongodb://localhost:27017/test_db',
        'ming.main.replicaSet': 'foobar',
        'ming.main.foo.bar': 'foobar',
        'ming.main.encryption.kms_providers.local.key': encryption_key,
        'ming.main.encryption.key_vault_namespace': 'encryption.collectionName',
        'ming.main.encryption.provider_options.local.key_alt_names': ['datakeyName'],
        # 'ming.main.encryption.provider_options.aws.key_alt_names': ['datakeyName'],
        # 'ming.main.encryption.provider_options.aws.master_key': ['datakeyName'],
    })
    session = Session.by_name('main')
    assert session.bind.conn is not None, session.bind.conn
    assert session.bind.db is not None, session.bind.db
    # The dotted config keys are expected to arrive as nested dicts.
    assert session.bind.encryption.kms_providers == {'local': {'key': encryption_key}}, session.bind.encryption.kms_providers
    assert session.bind.encryption.key_vault_namespace == 'encryption.collectionName', session.bind.encryption.key_vault_namespace
    assert session.bind.encryption.provider_options == {'local': {'key_alt_names': ['datakeyName']}}, session.bind.encryption.provider_options

def test_no_kwargs_with_bind(self):
self.assertRaises(
ming.exc.MingConfigError,
Expand Down
39 changes: 39 additions & 0 deletions ming/tests/test_declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ming.base import Cursor
from ming.datastore import create_datastore
from ming.declarative import Document
from ming.encryption import DecryptedField
from ming.metadata import Field, Index
from ming import schema as S
from ming.odm.odmsession import ODMSession, ThreadLocalODMSession
Expand Down Expand Up @@ -170,6 +171,44 @@ def test_field(self):
self.assertRaises(AttributeError, getattr, doc, 'a')
self.assertEqual(self.session.count(self.TestDoc), 1)


class TestDocumentEncryptionReal(TestCase):
    """Document tests against a datastore created from a ``mongodb://`` URI,
    using a document class that declares an encrypted field pair
    (``name`` / ``name_encrypted``)."""
    # The database name embeds the process id; the URI also sets a 100 ms
    # server selection timeout.
    DATASTORE = f"mongodb://localhost/test_ming_TestDocumentReal_{os.getpid()}?serverSelectionTimeoutMS=100"

    def setUp(self):
        """Create the datastore, a session bound to it, and the test document class."""
        self.datastore = create_datastore(self.DATASTORE)
        self.session = Session(bind=self.datastore)

        class TestDoc(Document):
            class __mongometa__:
                name='test_doc'
                session = self.session
                indexes = [ ('a',) ]
            _id = Field(S.Anything)
            # ``name`` is the plaintext view of the value stored in ``name_encrypted``.
            name = DecryptedField(str, 'name_encrypted')
            name_encrypted = Field(S.Binary)

        self.TestDoc = TestDoc

    def tearDown(self):
        """Remove the documents created by the test."""
        self.TestDoc.m.remove()
        # FIXME: teardown/ remove the encryption collection. likely in a different database

    def test_field(self):
        """Exercise plain attribute get/set/delete on a saved document."""
        # NOTE(review): this body never sets or reads ``name`` /
        # ``name_encrypted``, so no encryption code path is exercised yet.
        doc = self.TestDoc(dict(_id=1, a=1, b=dict(a=5)))
        doc.m.save()

        self.assertEqual(doc.a, 1)
        self.assertEqual(doc.b, dict(a=5))
        doc.a = 5
        self.assertEqual(doc, dict(_id=1, a=5, b=dict(a=5)))
        del doc.a
        self.assertEqual(doc, dict(_id=1, b=dict(a=5)))
        self.assertRaises(AttributeError, getattr, doc, 'c')
        self.assertRaises(AttributeError, getattr, doc, 'a')
        self.assertEqual(self.session.count(self.TestDoc), 1)


class TestIndexes(TestCase):

def setUp(self):
Expand Down
57 changes: 57 additions & 0 deletions ming/validators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

from .exc import MingConfigError
from .encryption import EncryptionConfig

try:
from formencode import schema, validators
except ImportError:
raise MingConfigError("Need to install FormEncode to use ``ming.encryption`` package")


class EncryptionConfigValidator(validators.FancyValidator):
    """
    Convert the ``encryption`` section of a datastore configuration into an
    :class:`EncryptionConfig` and validate it.

    When ``kms_providers`` is set, ``key_vault_namespace`` and
    ``provider_options`` are required, every provider name must be one of
    ``VALID_KMS_PROVIDERS``, and the ``local`` provider must supply a ``key``.
    """
    accept_iterator = True

    VALID_KMS_PROVIDERS = ('local', 'aws', 'azure', 'gcp', 'kmip')

    messages = dict(
        InvalidKMSProvider=f"Invalid KMS Provider %(provider)s. Valid options are: {', '.join(VALID_KMS_PROVIDERS)}",
        MissingKeyVaultNamespace="key_vault_namespace is required when kms_providers is configured",
        MissingProviderOptions="provider_options is required when kms_providers is configured",
        MissingLocalKey="Local KMS Provider requires a key",
    )

    def _convert_to_python(self, field_dict, state):
        """Wrap a non-empty settings mapping in ``EncryptionConfig``; return
        ``None`` for empty or missing settings."""
        if not field_dict:
            return None
        return EncryptionConfig(field_dict)

    def _validate_python(self, config: EncryptionConfig, state):
        """Raise ``Invalid`` with a per-setting ``error_dict`` if ``config``
        is inconsistent; do nothing when there are no encryption settings."""
        if not config:
            return  # no encryption settings

        error_dict = self._collect_errors(config, state)
        if error_dict:
            raise validators.Invalid('Invalid Encryption Settings', config, state, error_dict=error_dict)

    def _collect_errors(self, config: EncryptionConfig, state) -> dict:
        """Return a dict mapping setting name to ``Invalid``; empty when valid."""
        error_dict = {}
        kms_providers = config.kms_providers
        if not kms_providers:
            return error_dict

        if not config.key_vault_namespace:
            error_dict['key_vault_namespace'] = validators.Invalid(
                self.message('MissingKeyVaultNamespace', state), config, state)
        if not config.provider_options:
            error_dict['provider_options'] = validators.Invalid(
                self.message('MissingProviderOptions', state), config, state)

        # Report only the first unknown provider name.
        for key in kms_providers:
            if key not in self.VALID_KMS_PROVIDERS:
                error_dict['kms_providers'] = validators.Invalid(
                    self.message('InvalidKMSProvider', state, provider=key), config, state)
                break

        if error_dict:
            # Provider-specific checks only run once the structure is sound.
            return error_dict

        if 'local' in kms_providers:
            local_settings = kms_providers['local']
            if not local_settings or 'key' not in local_settings:
                error_dict['kms_providers'] = validators.Invalid(
                    self.message('MissingLocalKey', state), config, state)

        return error_dict
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
include_package_data=True,
zip_safe=True,
install_requires=[
"cachetools",
"pymongo[encryption]",
"pytz",
],
Expand Down

0 comments on commit 58a0b5a

Please sign in to comment.