Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 156 additions & 10 deletions ming/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,30 @@ def _wrap_item(self, item):

def _encrypt_item(self, item):
"""Encrypt or process a single item for writing."""
encrypted_iem = _encrypt_value_recursive(
encrypted_item = _encrypt_value_recursive(
item,
self._item_schema,
self._instance.encr,
field_name=ENCRYPTED_SUFFIX if self._items_encrypted else None,
force_encrypt=self._items_encrypted,
)
return encrypted_iem
return encrypted_item

def _mark_dirty(self):
"""Mark the list as modified for dirty tracking."""
if self._tracker is not None:
self._tracker.added_item(self._doc)

def _plain_list(self):
return list(self)

def _coerce_list_comparison(self, other):
if isinstance(other, EncryptedListWrapper):
return other._plain_list()
if isinstance(other, list):
return other
return NotImplemented

def __getitem__(self, index):
if isinstance(index, slice):
return [self._wrap_item(item) for item in self._doc[index]]
Expand Down Expand Up @@ -283,6 +293,42 @@ def __imul__(self, n):
self.extend(current)
return self

def __eq__(self, other):
other_list = self._coerce_list_comparison(other)
if other_list is NotImplemented:
return False
return self._plain_list() == other_list

def __ne__(self, other):
return not self == other

def __lt__(self, other):
other_list = self._coerce_list_comparison(other)
if other_list is NotImplemented:
return NotImplemented
return self._plain_list() < other_list

def __le__(self, other):
other_list = self._coerce_list_comparison(other)
if other_list is NotImplemented:
return NotImplemented
return self._plain_list() <= other_list

def __gt__(self, other):
other_list = self._coerce_list_comparison(other)
if other_list is NotImplemented:
return NotImplemented
return self._plain_list() > other_list

def __ge__(self, other):
other_list = self._coerce_list_comparison(other)
if other_list is NotImplemented:
return NotImplemented
return self._plain_list() >= other_list

def __reversed__(self):
return reversed(self._plain_list())

def append(self, value):
self._doc.append(self._encrypt_item(value))
self._mark_dirty()
Expand All @@ -302,14 +348,27 @@ def pop(self, index=-1):
return self._wrap_item(item)

def remove(self, value):
# Need to find and remove the encrypted version
encrypted_value = self._encrypt_item(value)
self._doc.remove(encrypted_value)
del self._doc[self.index(value)]
self._mark_dirty()

def index(self, value, *args):
encrypted_value = self._encrypt_item(value)
return self._doc.index(encrypted_value, *args)
return self._plain_list().index(value, *args)

def count(self, value):
return self._plain_list().count(value)

def copy(self):
return self._plain_list().copy()

def reverse(self):
self._doc.reverse()
self._mark_dirty()

def sort(self, *, key=None, reverse=False):
values = self._plain_list()
values.sort(key=key, reverse=reverse)
self._doc[:] = [self._encrypt_item(value) for value in values]
self._mark_dirty()

def replace(self, values):
self[:] = values
Expand All @@ -334,7 +393,6 @@ def __contains__(self, value):
def __repr__(self):
return f"EncryptedListWrapper({list(self)})"


class EncryptedDictWrapper:
"""Generic dict wrapper that transparently encrypts/decrypts specified fields.

Expand Down Expand Up @@ -654,6 +712,87 @@ def __set__(self, instance: EncryptedMixin, value: T):
setattr(instance, self.encrypted_field, instance.encr(value))


class DecryptedListField:
"""Virtual Document field for lists stored in a sibling encrypted field.

``DecryptedListField('emails_encrypted')`` exposes plaintext list
operations while storing encrypted values in ``emails_encrypted``.
"""

def __init__(self, encrypted_field: str):
self.encrypted_field = encrypted_field

def _encrypt_list(self, encr_func, value):
if value is None:
return []
return _encrypt_list_recursive(value, [S.Binary], encr_func, self.encrypted_field, force_encrypt=True)

def __get__(self, instance: EncryptedMixin, owner):
if instance is None:
return self

doc = instance.get(self.encrypted_field)
if doc is None:
doc = []
instance[self.encrypted_field] = doc
return EncryptedListWrapper(
doc=doc,
tracker=None,
item_schema=S.Binary,
instance=instance,
items_encrypted=True,
)

def __set__(self, instance: EncryptedMixin, value):
instance[self.encrypted_field] = self._encrypt_list(instance.encr, value)

def __delete__(self, instance):
del instance[self.encrypted_field]


class DecryptedListProperty:
"""Virtual ODM property for lists stored in a sibling encrypted field.

``DecryptedListProperty('emails_encrypted')`` exposes plaintext list
operations while storing encrypted values in ``emails_encrypted``.
"""

def __init__(self, encrypted_field: str):
self.encrypted_field = encrypted_field

def _encrypt_list(self, encr_func, value):
if value is None:
return []
return _encrypt_list_recursive(value, [S.Binary], encr_func, self.encrypted_field, force_encrypt=True)

def __get__(self, instance: EncryptedMixin, owner):
if instance is None:
return self

from ming.odm.base import state

st = state(instance)
doc = st.document.get(self.encrypted_field)
if doc is None:
doc = []
st.document[self.encrypted_field] = doc
return EncryptedListWrapper(
doc=doc,
tracker=st.tracker,
item_schema=S.Binary,
instance=instance,
items_encrypted=True,
)

def __set__(self, instance: EncryptedMixin, value):
setattr(instance, self.encrypted_field, self._encrypt_list(instance.encr, value))

def __delete__(self, instance):
from ming.odm.base import state

state(instance).delete(self.encrypted_field)


class EncryptedMixin:
"""A mixin intended to be used with :class:`~ming.declarative.Document`
or :class:`~ming.odm.declarative.MappedClass` to provide encryption.
Expand Down Expand Up @@ -763,7 +902,11 @@ def encrypt_some_fields(cls, data: dict) -> dict:
for fld in cls.decrypted_field_names():
if fld in encrypted_data:
val = encrypted_data.pop(fld)
encrypted_data[f'{fld}_encrypted'] = cls.encr(val)
prop = getattr(cls, fld, None)
if isinstance(prop, (DecryptedListField, DecryptedListProperty)):
encrypted_data[prop.encrypted_field] = prop._encrypt_list(cls.encr, val)
else:
encrypted_data[f'{fld}_encrypted'] = cls.encr(val)

# Handle nested encrypted field/property instances.
for field_name, field in cls._encrypted_field_index().items():
Expand All @@ -787,7 +930,10 @@ def decrypt_some_fields(self) -> dict:
for k in self._field_names:
if k.endswith('_encrypted'):
k_decrypted = k.replace('_encrypted', '')
decrypted_data[k_decrypted] = getattr(self, k_decrypted)
value = getattr(self, k_decrypted)
if isinstance(getattr(type(self), k_decrypted, None), (DecryptedListField, DecryptedListProperty)):
value = list(value)
decrypted_data[k_decrypted] = value
else:
decrypted_data[k] = getattr(self, k)
return decrypted_data
Expand Down
4 changes: 2 additions & 2 deletions ming/odm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ming.odm.mapper import mapper, Mapper, MapperExtension

from ming.odm.property import RelationProperty, ForeignIdProperty
from ming.odm.property import FieldProperty, FieldPropertyWithMissingNone, DecryptedProperty
from ming.odm.property import FieldProperty, FieldPropertyWithMissingNone, DecryptedProperty, DecryptedListProperty

from ming.odm.odmsession import ODMSession, ThreadLocalODMSession, SessionExtension
from ming.odm.odmsession import ContextualODMSession
Expand All @@ -15,5 +15,5 @@

__all__ = ('state', 'session', 'mapper', 'Mapper', 'MapperExtension',
'RelationProperty', 'ForeignIdProperty', 'FieldProperty', 'DecryptedProperty',
'FieldPropertyWithMissingNone', 'ODMSession', 'ThreadLocalODMSession',
'DecryptedListProperty', 'FieldPropertyWithMissingNone', 'ODMSession', 'ThreadLocalODMSession',
'SessionExtension', 'MappedClass', 'ContextualODMSession')
3 changes: 3 additions & 0 deletions ming/odm/property.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __repr__(self):
class DecryptedProperty(ming.encryption.DecryptedField):
pass

class DecryptedListProperty(ming.encryption.DecryptedListProperty):
pass

class FieldProperty(ORMProperty):
"""Declares property for a value stored in a MongoDB Document.

Expand Down
5 changes: 4 additions & 1 deletion ming/odm/property.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,8 @@ class RelationProperty:
@overload
def __new__(self, related: Type[MC], via: str=None, fetch=True) -> Iterable[MC]:...

class DecryptedListProperty:
def __init__(self, encrypted_field: str) -> None: ...

def __getattr__(name) -> Any: ... # marks file as incomplete

def __getattr__(name) -> Any: ... # marks file as incomplete
Loading