diff --git a/rest_flex_fields/filter_backends.py b/rest_flex_fields/filter_backends.py index 148d99c..9f4066e 100644 --- a/rest_flex_fields/filter_backends.py +++ b/rest_flex_fields/filter_backends.py @@ -34,7 +34,8 @@ def filter_queryset( context=view.get_serializer_context() ) - serializer.apply_flex_fields() + serializer.apply_flex_fields(serializer.fields, serializer._flex_options_rep_only) + serializer._flex_fields_rep_applied = True model_fields = [ self._get_field(field.source, queryset.model) diff --git a/rest_flex_fields/serializers.py b/rest_flex_fields/serializers.py index 93de6fb..5f3880f 100644 --- a/rest_flex_fields/serializers.py +++ b/rest_flex_fields/serializers.py @@ -33,32 +33,50 @@ def __init__(self, *args, **kwargs): self.parent = parent self.expanded_fields = [] - self._flex_fields_applied = False - - self._flex_options = { - "expand": ( - expand - if len(expand) > 0 - else self._get_permitted_expands_from_query_param(EXPAND_PARAM) - ), - "fields": (fields if len(fields) > 0 else self._get_query_param_value(FIELDS_PARAM)), - "omit": omit if len(omit) > 0 else self._get_query_param_value(OMIT_PARAM), + self._flex_fields_rep_applied = False + + self._flex_options_base = { + "expand": expand, + "fields": fields, + "omit": omit, + } + self._flex_options_rep_only = { + "expand": (self._get_permitted_expands_from_query_param(EXPAND_PARAM) + if not expand else + []), + "fields": (self._get_query_param_value(FIELDS_PARAM) + if not fields else + []), + "omit": (self._get_query_param_value(OMIT_PARAM) + if not omit else + []), + } + self._flex_options_all = { + "expand": self._flex_options_base["expand"] + self._flex_options_rep_only["expand"], + "fields": self._flex_options_base["fields"] + self._flex_options_rep_only["fields"], + "omit": self._flex_options_base["omit"] + self._flex_options_rep_only["omit"], } - def to_representation(self, *args, **kwargs): - if self._flex_fields_applied is False: - self.apply_flex_fields() - return super().to_representation(*args, **kwargs) + def to_representation(self, instance): + if not self._flex_fields_rep_applied: + self.apply_flex_fields(self.fields, self._flex_options_rep_only) + self._flex_fields_rep_applied = True + return super().to_representation(instance) - def apply_flex_fields(self): - expand_fields, next_expand_fields = split_levels(self._flex_options["expand"]) - sparse_fields, next_sparse_fields = split_levels(self._flex_options["fields"]) - omit_fields, next_omit_fields = split_levels(self._flex_options["omit"]) + def get_fields(self): + fields = super().get_fields() + self.apply_flex_fields(fields, self._flex_options_base) + return fields - to_remove = self._get_fields_names_to_remove(omit_fields, sparse_fields, next_omit_fields) + def apply_flex_fields(self, fields, flex_options): + expand_fields, next_expand_fields = split_levels(flex_options["expand"]) + sparse_fields, next_sparse_fields = split_levels(flex_options["fields"]) + omit_fields, next_omit_fields = split_levels(flex_options["omit"]) + + to_remove = self._get_fields_names_to_remove(omit_fields, sparse_fields, next_omit_fields, fields) for field_name in to_remove: - self.fields.pop(field_name) + fields.pop(field_name) expanded_field_names = self._get_expanded_field_names( expand_fields, omit_fields, sparse_fields, next_omit_fields @@ -67,11 +85,10 @@ def apply_flex_fields(self): for name in expanded_field_names: self.expanded_fields.append(name) - self.fields[name] = self._make_expanded_field_serializer( + fields[name] = self._make_expanded_field_serializer( name, next_expand_fields, next_sparse_fields, next_omit_fields ) - - self._flex_fields_applied = True + return fields def _make_expanded_field_serializer(self, name, nested_expand, nested_fields, nested_omit): """ @@ -143,6 +160,7 @@ def _get_fields_names_to_remove( omit_fields: List[str], sparse_fields: List[str], next_level_omits: List[str], + fields, ) -> List[str]: """ Remove fields that are found in omit list, and if sparse names @@ -154,7 +172,7 @@ def _get_fields_names_to_remove( if not sparse and len(omit_fields) == 0: return to_remove - for field_name in self.fields: + for field_name in fields: should_exist = self._should_field_exist( field_name, omit_fields, sparse_fields, next_level_omits ) diff --git a/tests/test_flex_fields_model_serializer.py b/tests/test_flex_fields_model_serializer.py index ed69bd7..1b6160c 100644 --- a/tests/test_flex_fields_model_serializer.py +++ b/tests/test_flex_fields_model_serializer.py @@ -32,8 +32,8 @@ def test_field_should_exist_if_ommitted_but_is_parent_of_omit(self): def test_clean_fields(self): serializer = FlexFieldsModelSerializer() - serializer.fields = {"cat": 1, "dog": 2, "zebra": 3} - result = serializer._get_fields_names_to_remove(["cat"], [], {}) + fields = {"cat": 1, "dog": 2, "zebra": 3} + result = serializer._get_fields_names_to_remove(["cat"], [], {}, fields) self.assertEqual(result, ["cat"]) def test_get_expanded_field_names_if_all(self): @@ -83,7 +83,7 @@ def test_get_omit_input_from_explicit_settings(self): }, ) - self.assertEqual(serializer._flex_options["omit"], ["fish"]) + self.assertEqual(serializer._flex_options_all["omit"], ["fish"]) def test_set_omit_input_from_query_param(self): serializer = FlexFieldsModelSerializer( @@ -93,7 +93,7 @@ def test_set_omit_input_from_query_param(self): ) } ) - self.assertEqual(serializer._flex_options["omit"], ["cat", "dog"]) + self.assertEqual(serializer._flex_options_all["omit"], ["cat", "dog"]) def test_set_fields_input_from_explicit_settings(self): serializer = FlexFieldsModelSerializer( @@ -105,7 +105,7 @@ def test_set_fields_input_from_explicit_settings(self): }, ) - self.assertEqual(serializer._flex_options["fields"], ["fish"]) + self.assertEqual(serializer._flex_options_all["fields"], ["fish"]) def test_set_fields_input_from_query_param(self): serializer = FlexFieldsModelSerializer( @@ -116,7 +116,7 @@ def test_set_fields_input_from_query_param(self): } ) - self.assertEqual(serializer._flex_options["fields"], ["cat", "dog"]) + self.assertEqual(serializer._flex_options_all["fields"], ["cat", "dog"]) def test_set_expand_input_from_explicit_setting(self): serializer = FlexFieldsModelSerializer( @@ -128,7 +128,7 @@ def test_set_expand_input_from_explicit_setting(self): }, ) - self.assertEqual(serializer._flex_options["fields"], ["cat"]) + self.assertEqual(serializer._flex_options_all["fields"], ["cat"]) def test_set_expand_input_from_query_param(self): serializer = FlexFieldsModelSerializer( @@ -139,7 +139,7 @@ def test_set_expand_input_from_query_param(self): } ) - self.assertEqual(serializer._flex_options["expand"], ["cat", "dog"]) + self.assertEqual(serializer._flex_options_all["expand"], ["cat", "dog"]) def test_get_expand_input_from_query_param_limit_to_list_permitted(self): serializer = FlexFieldsModelSerializer( @@ -151,7 +151,7 @@ def test_get_expand_input_from_query_param_limit_to_list_permitted(self): } ) - self.assertEqual(serializer._flex_options["expand"], ["cat"]) + self.assertEqual(serializer._flex_options_all["expand"], ["cat"]) def test_parse_request_list_value(self): test_params = [