Skip to content

Commit

Permalink
Always create map attributes when setting a dict (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
garrettheel authored and danielhochman committed Feb 4, 2017
1 parent d16e2bb commit 874d18e
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 54 deletions.
80 changes: 45 additions & 35 deletions pynamodb/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def __init__(self, hash_key=False, range_key=False, null=None, default=None, att
null=null,
default=default,
attr_name=attr_name)
self._get_attributes() # Ensure attributes are always inited
self.attribute_values = {}
self._set_defaults()
self._set_attributes(**attrs)
Expand All @@ -429,38 +430,33 @@ def __iter__(self):
def __getitem__(self, item):
return self.attribute_values[item]

def __getattr__(self, attr):
return self.attribute_values[attr]

def __set__(self, instance, value):
if isinstance(value, collections.Mapping):
value = type(self)(**value)
return super(MapAttribute, self).__set__(instance, value)

def _set_attributes(self, **attrs):
"""
Sets the attributes for this object
"""
for attr_name, attr in self._get_attributes().items():
if attr.attr_name in attrs:
value = attrs.get(attr_name)
if not isinstance(value, collections.Mapping) or type(attr) == MapAttribute:
setattr(self, attr_name, attrs.get(attr.attr_name))
else:
# it's a sub model which means we need to instantiate that type first
# pass in the attributes of that model, then set the field on this object to point to that model
sub_model = value
instance = type(attr)(**sub_model)
setattr(self, attr_name, instance)

elif attr_name in attrs:
setattr(self, attr_name, attrs.get(attr_name))

def get_values(self):
attributes = self._get_attributes()
result = {}
for k, v in six.iteritems(attributes):
result[k] = getattr(self, k)
return result
for attr_name, value in six.iteritems(attrs):
attribute = self._get_attributes().get(attr_name)
if self.is_raw():
self.attribute_values[attr_name] = value
elif not attribute:
raise AttributeError("Attribute {0} specified does not exist".format(attr_name))
else:
setattr(self, attr_name, value)

def is_correctly_typed(self, key, attr):
can_be_null = attr.null
value = getattr(self, key)
if can_be_null and value is None:
return True
if value is None:
if getattr(self, key) is None:
raise ValueError("Attribute '{0}' cannot be None".format(key))
return True # TODO: check that the actual type of `value` meets requirements of `attr`

Expand All @@ -486,16 +482,37 @@ def serialize(self, values):

def deserialize(self, values):
"""
Decode numbers from list of AttributeValue types.
Decode as a dict.
"""
deserialized_dict = dict()
for k in values:
v = values[k]
attr_class = _get_class_for_deserialize(v)
attr_value = _get_value_for_deserialize(v)
deserialized_dict[k] = attr_class.deserialize(attr_value)
key = self._dynamo_to_python_attr(k)
attr_class = self._get_deserialize_class(key, v)

deserialized_dict[key] = attr_class.deserialize(attr_value)

# If this is a subclass of a MapAttribute (i.e typed), instantiate an instance
if not self.is_raw():
return type(self)(**deserialized_dict)
return deserialized_dict

@classmethod
def is_raw(cls):
return cls == MapAttribute

def as_dict(self):
result = {}
for key, value in six.iteritems(self.attribute_values):
result[key] = value.as_dict() if isinstance(value, MapAttribute) else value
return result

@classmethod
def _get_deserialize_class(cls, key, value):
if not cls.is_raw():
return cls._get_attributes().get(key)
return _get_class_for_deserialize(value)

def _get_value_for_deserialize(value):
return value[list(value.keys())[0]]
Expand Down Expand Up @@ -562,17 +579,10 @@ def deserialize(self, values):
"""
deserialized_lst = []
for v in values:
attr_class = _get_class_for_deserialize(v)
class_for_deserialize = self.element_type() if self.element_type else _get_class_for_deserialize(v)
attr_value = _get_value_for_deserialize(v)
deserialized_lst.append(attr_class.deserialize(attr_value))

if not self.element_type:
return deserialized_lst

lst_of_type = []
for item in deserialized_lst:
lst_of_type.append(self.element_type(**item))
return lst_of_type
deserialized_lst.append(class_for_deserialize.deserialize(attr_value))
return deserialized_lst

DESERIALIZE_CLASS_MAP = {
LIST_SHORT: ListAttribute(),
Expand Down
6 changes: 1 addition & 5 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,7 @@ def from_raw_data(cls, data):
attr_name = cls._dynamo_to_python_attr(name)
attr = cls._get_attributes().get(attr_name, None)
if attr:
deserialized_attr = attr.deserialize(attr.get_value(value))
if isinstance(attr, MapAttribute) and not type(attr) == MapAttribute:
deserialized_attr = type(attr)(**deserialized_attr)
kwargs[attr_name] = deserialized_attr
kwargs[attr_name] = attr.deserialize(attr.get_value(value))
return cls(*args, **kwargs)

@classmethod
Expand Down Expand Up @@ -1295,7 +1292,6 @@ def _serialize(self, attr_map=False, null_check=True):
if isinstance(value, MapAttribute):
if not value.validate():
raise ValueError("Attribute '{0}' is not correctly typed".format(attr.attr_name))
value = value.get_values()

serialized = self._serialize_value(attr, value, null_check)
if NULL in serialized:
Expand Down
57 changes: 57 additions & 0 deletions pynamodb/tests/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Meta:
datetime_attr = UTCDateTimeAttribute()
bool_attr = BooleanAttribute()
json_attr = JSONAttribute()
map_attr = MapAttribute()


class CustomAttrMap(MapAttribute):
Expand Down Expand Up @@ -576,6 +577,62 @@ def test_defaults(self):
}
})

def test_raw_map_from_dict(self):
item = AttributeTestModel(
map_attr={
"foo": "bar",
"num": 3,
"nested": {
"nestedfoo": "nestedbar"
}
}
)

self.assertEqual(item.map_attr['foo'], 'bar')
self.assertEqual(item.map_attr['num'], 3)

def test_raw_map_access(self):
raw = {
"foo": "bar",
"num": 3,
"nested": {
"nestedfoo": "nestedbar"
}
}
attr = MapAttribute(**raw)

for k, v in six.iteritems(raw):
self.assertEquals(attr[k], v)

def test_raw_map_json_serialize(self):
raw = {
"foo": "bar",
"num": 3,
"nested": {
"nestedfoo": "nestedbar"
}
}

serialized_raw = json.dumps(raw)
self.assertEqual(json.dumps(AttributeTestModel(map_attr=raw).map_attr.as_dict()),
serialized_raw)
self.assertEqual(json.dumps(AttributeTestModel(map_attr=MapAttribute(**raw)).map_attr.as_dict()),
serialized_raw)

def test_typed_and_raw_map_json_serialize(self):
class TypedMap(MapAttribute):
map_attr = MapAttribute()

class SomeModel(Model):
typed_map = TypedMap()

item = SomeModel(
typed_map=TypedMap(map_attr={'foo': 'bar'})
)

self.assertEqual(json.dumps({'map_attr': {'foo': 'bar'}}),
json.dumps(item.typed_map.as_dict()))


class MapAndListAttributeTestCase(TestCase):

Expand Down
127 changes: 113 additions & 14 deletions pynamodb/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,11 @@ class ModelTestCase(TestCase):
"""
Tests for the models API
"""
@staticmethod
def init_table_meta(model_clz, table_data):
with patch(PATCH_METHOD) as req:
req.return_value = table_data
model_clz._get_meta_data()

def assert_dict_lists_equal(self, list1, list2):
"""
Expand Down Expand Up @@ -3347,7 +3352,7 @@ def test_raw_map_deserializes(self):
instance = ExplicitRawMapModel(map_attr=map_native)
instance._deserialize(map_serialized)
actual = instance.map_attr
for k,v in six.iteritems(map_native):
for k, v in six.iteritems(map_native):
self.assertEqual(v, actual[k])

def test_raw_map_from_raw_data_works(self):
Expand All @@ -3361,10 +3366,10 @@ def test_raw_map_from_raw_data_works(self):
EXPLICIT_RAW_MAP_MODEL_ITEM_DATA,
'map_id', 'N',
'123')
with patch(PATCH_METHOD, new=fake_db) as req:
with patch(PATCH_METHOD, new=fake_db):
item = ExplicitRawMapModel.get(123)
actual = item.map_attr
self.assertEqual(map_native.get('listy')[2], actual.get('listy')[2])
self.assertEqual(map_native.get('listy')[2], actual['listy'][2])
for k, v in six.iteritems(map_native):
self.assertEqual(v, actual[k])

Expand Down Expand Up @@ -3396,11 +3401,11 @@ def _get_raw_map_as_sub_map_test_data(self):
map_serialized = {
'M': {
'foo': {'S': 'bar'},
'num': {'N': 1},
'num': {'N': '1'},
'bool_type': {'BOOL': True},
'other_b_type': {'BOOL': False},
'floaty': {'N': 1.2},
'listy': {'L': [{'N': 1}, {'N': 2}, {'N': 3}]},
'floaty': {'N': '1.2'},
'listy': {'L': [{'N': '1'}, {'N': '2'}, {'N': '3'}]},
'mapy': {'M': {'baz': {'S': 'bongo'}}}
}
}
Expand All @@ -3415,13 +3420,22 @@ def _get_raw_map_as_sub_map_test_data(self):
)
return map_native, map_serialized, sub_attr, instance

def test_raw_map_as_sub_map_deserializes(self):
def test_raw_map_as_sub_map(self):
map_native, map_serialized, sub_attr, instance = self._get_raw_map_as_sub_map_test_data()
instance._deserialize(map_serialized)
actual = instance.sub_attr
self.assertEqual(sub_attr, actual)
self.assertEqual(sub_attr.map_field.get('floaty'), map_native.get('floaty'))
self.assertEqual(sub_attr.map_field.get('mapy', {}).get('baz'), map_native.get('mapy', {}).get('baz'))
self.assertEqual(actual.map_field['floaty'], map_native.get('floaty'))
self.assertEqual(actual.map_field['mapy']['baz'], map_native.get('mapy').get('baz'))

def test_raw_map_as_sub_map_deserialize(self):
map_native, map_serialized, _, _ = self._get_raw_map_as_sub_map_test_data()

actual = MapAttrSubClassWithRawMapAttr().deserialize({
"map_field": map_serialized
})

for k, v in six.iteritems(map_native):
self.assertEqual(actual.map_field[k], v)

def test_raw_map_as_sub_map_from_raw_data_works(self):
map_native, map_serialized, sub_attr, instance = self._get_raw_map_as_sub_map_test_data()
Expand All @@ -3430,9 +3444,94 @@ def test_raw_map_as_sub_map_from_raw_data_works(self):
EXPLICIT_RAW_MAP_MODEL_AS_SUB_MAP_IN_TYPED_MAP_ITEM_DATA,
'map_id', 'N',
'123')
with patch(PATCH_METHOD, new=fake_db) as req:
with patch(PATCH_METHOD, new=fake_db):
item = ExplicitRawMapAsMemberOfSubClass.get(123)
self.assertEqual(sub_attr.map_field.get('floaty'),
actual = item.sub_attr
self.assertEqual(sub_attr.map_field['floaty'],
map_native.get('floaty'))
self.assertEqual(sub_attr.map_field.get('mapy', {}).get('baz'),
map_native.get('mapy', {}).get('baz'))
self.assertEqual(actual.map_field['mapy']['baz'],
map_native.get('mapy').get('baz'))


class ModelInitTestCase(TestCase):

def test_raw_map_attribute_with_dict_init(self):
attribute = {
'foo': 123,
'bar': 'baz'
}
actual = ExplicitRawMapModel(map_id=3, map_attr=attribute)
self.assertEquals(actual.map_attr['foo'], attribute['foo'])

def test_raw_map_attribute_with_initialized_instance_init(self):
attribute = {
'foo': 123,
'bar': 'baz'
}
initialized_instance = MapAttribute(**attribute)
actual = ExplicitRawMapModel(map_id=3, map_attr=initialized_instance)
self.assertEquals(actual.map_attr['foo'], initialized_instance['foo'])
self.assertEquals(actual.map_attr['foo'], attribute['foo'])

def test_subclassed_map_attribute_with_dict_init(self):
attribute = {
'make': 'Volkswagen',
'model': 'Super Beetle'
}
expected_model = CarInfoMap(**attribute)
actual = CarModel(car_id=1, car_info=attribute)
self.assertEquals(expected_model.make, actual.car_info.make)
self.assertEquals(expected_model.model, actual.car_info.model)

def test_subclassed_map_attribute_with_initialized_instance_init(self):
attribute = {
'make': 'Volkswagen',
'model': 'Super Beetle'
}
expected_model = CarInfoMap(**attribute)
actual = CarModel(car_id=1, car_info=expected_model)
self.assertEquals(expected_model.make, actual.car_info.make)
self.assertEquals(expected_model.model, actual.car_info.model)

def _get_bin_tree(self, multiplier=1):
return {
'value': 5 * multiplier,
'left': {
'value': 2 * multiplier,
'left': {
'value': 1 * multiplier
},
'right': {
'value': 3 * multiplier
}
},
'right': {
'value': 7 * multiplier,
'left': {
'value': 6 * multiplier
},
'right': {
'value': 8 * multiplier
}
}
}

def test_subclassed_map_attribute_with_map_attributes_member_with_dict_init(self):
left = self._get_bin_tree()
right = self._get_bin_tree(multiplier=2)
actual = TreeModel(tree_key='key', left=left, right=right)
self.assertEquals(actual.left.left.right.value, 3)
self.assertEquals(actual.left.left.value, 2)
self.assertEquals(actual.right.right.left.value, 12)
self.assertEquals(actual.right.right.value, 14)

def test_subclassed_map_attribute_with_map_attribute_member_with_initialized_instance_init(self):
left = self._get_bin_tree()
right = self._get_bin_tree(multiplier=2)
left_instance = TreeLeaf(**left)
right_instance = TreeLeaf(**right)
actual = TreeModel(tree_key='key', left=left_instance, right=right_instance)
self.assertEquals(actual.left.left.right.value, left_instance.left.right.value)
self.assertEquals(actual.left.left.value, left_instance.left.value)
self.assertEquals(actual.right.right.left.value, right_instance.right.left.value)
self.assertEquals(actual.right.right.value, right_instance.right.value)

0 comments on commit 874d18e

Please sign in to comment.