diff --git a/pynamodb/attributes.py b/pynamodb/attributes.py index c0f182b39..ce750bf4b 100644 --- a/pynamodb/attributes.py +++ b/pynamodb/attributes.py @@ -268,7 +268,7 @@ def prepend(self, other: Iterable) -> '_ListAppend': def set( self, value: Union[_T, 'Attribute[_T]', '_Increment', '_Decrement', '_IfNotExists', '_ListAppend'] - ) -> 'SetAction': + ) -> Union['SetAction', 'RemoveAction']: return Path(self).set(value) def remove(self) -> 'RemoveAction': diff --git a/pynamodb/expressions/operand.py b/pynamodb/expressions/operand.py index 9420788fd..8f7987ead 100644 --- a/pynamodb/expressions/operand.py +++ b/pynamodb/expressions/operand.py @@ -298,9 +298,14 @@ def __getitem__(self, item: Union[int, str]) -> 'Path': def __or__(self, other): return _IfNotExists(self, self._to_operand(other)) - def set(self, value: Any) -> SetAction: - # Returns an update action that sets this attribute to the given value - return SetAction(self, self._to_operand(value)) + def set(self, value: Any) -> Union[SetAction, RemoveAction]: + # Returns an update action that sets this attribute to the given value. + # For attributes that may not be empty (e.g. sets), this may result + # in a remove action. + operand = self._to_operand(value) + if isinstance(operand, Value) and next(iter(operand.value.values())) is None: + return RemoveAction(self) + return SetAction(self, operand) def remove(self) -> RemoveAction: # Returns an update action that removes this attribute from the item diff --git a/tests/integration/model_integration_test.py b/tests/integration/model_integration_test.py index 340ca728d..9a54cccc6 100644 --- a/tests/integration/model_integration_test.py +++ b/tests/integration/model_integration_test.py @@ -102,6 +102,10 @@ class Meta: for item in TestModel.view_index.query('foo', TestModel.view > 0): print("Item queried from index: {}".format(item.view)) + query_obj.update([TestModel.scores.set([])]) + query_obj.refresh() + assert query_obj.scores is None + print(query_obj.update([TestModel.view.add(1)], condition=TestModel.forum.exists())) TestModel.delete_table() diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 4d314bb06..698366f50 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -479,6 +479,13 @@ def test_set_action(self): assert self.placeholder_names == {'foo': '#0'} assert self.expression_attribute_values == {':0': {'S': 'bar'}} + def test_set_action_as_remove(self): + action = self.set_attribute.set([]) + expression = action.serialize(self.placeholder_names, self.expression_attribute_values) + assert expression == "#0" + assert self.placeholder_names == {'foo_set': '#0'} + assert self.expression_attribute_values == {} + def test_set_action_attribute_container(self): # Simulate initialization from inside an AttributeContainer my_map_attribute = MapAttribute[str, str](attr_name='foo') @@ -620,6 +627,15 @@ def test_update(self): ':2': {'NS': ['1']} } + def test_update_set_to_empty(self): + update = Update( + self.set_attribute.set([]), + ) + expression = update.serialize(self.placeholder_names, self.expression_attribute_values) + assert expression == "REMOVE #0" + assert self.placeholder_names == {'foo_set': '#0'} + assert self.expression_attribute_values == {} + def test_update_skips_empty_clauses(self): update = Update(self.attribute.remove()) expression = update.serialize(self.placeholder_names, self.expression_attribute_values)