diff --git a/django_declarative_apis/machinery/filtering.py b/django_declarative_apis/machinery/filtering.py index 31c286b..ad53036 100644 --- a/django_declarative_apis/machinery/filtering.py +++ b/django_declarative_apis/machinery/filtering.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # - +from abc import ABC, abstractmethod from collections import defaultdict import inspect import logging @@ -45,6 +45,16 @@ def expandable(model_class=None, display_key=None, inst_field_name=None): return _ExpandableForeignKey(display_key, model_class, inst_field_name) +class ExpandableGeneric(ABC): + @abstractmethod + def get_unexpanded_view(self, inst) -> dict: + raise NotImplementedError() + + @abstractmethod + def get_expanded_view(self, inst) -> dict: + raise NotImplementedError() + + def _get_unexpanded_field_value(inst, field_name, field_type): if not field_type.model_class: return DEFAULT_UNEXPANDED_VALUE @@ -141,6 +151,11 @@ def _get_filtered_field_value( # noqa: C901 val = getattr(inst, inst_field_name) else: val = _get_unexpanded_field_value(inst, field_name, field_type) + elif isinstance(field_type, ExpandableGeneric): + if expand_this: + val = field_type.get_expanded_view(inst) + else: + val = field_type.get_unexpanded_view(inst) else: try: if isinstance(inst, (models.Model)): @@ -184,6 +199,7 @@ def _get_filtered_field_value( # noqa: C901 or isinstance(field_type, types.FunctionType) or ((field_type == IF_TRUTHY) and val) or isinstance(field_type, _ExpandableForeignKey) + or isinstance(field_type, ExpandableGeneric) ): return val else: diff --git a/tests/filters.py b/tests/filters.py index 5b6cff4..212ff05 100644 --- a/tests/filters.py +++ b/tests/filters.py @@ -5,10 +5,24 @@ # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause # -from django_declarative_apis.machinery.filtering import ALWAYS, NEVER, expandable +from django_declarative_apis.machinery.filtering import ( + ALWAYS, + NEVER, + expandable, + ExpandableGeneric, +) from . import models + +class TestExpandableGeneric(ExpandableGeneric): + def get_unexpanded_view(self, inst) -> dict: + return {"id": "1234"} + + def get_expanded_view(self, inst) -> dict: + return {"id": "1234", "expanded": True} + + DEFAULT_FILTERS = { str: ALWAYS, int: ALWAYS, @@ -18,6 +32,7 @@ "int_field": ALWAYS, "expandable_dict": expandable(), "expandable_string": expandable(), + "expandable_generic": TestExpandableGeneric(), }, models.ChildModel: { "pk": ALWAYS, diff --git a/tests/test_filters.py b/tests/test_filters.py index 9b5d37c..d406a9d 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -31,6 +31,23 @@ def setUp(self): self.p1.root = self.root self.p1.save() + def test_expandable_generic_field(self): + # test unexpanded + filtered = filtering.apply_filters_to_object( + self.test_model, filters.DEFAULT_FILTERS + ) + self.assertIn("expandable_generic", filtered) + self.assertEqual("1234", filtered["expandable_generic"]["id"]) + self.assertNotIn("expanded", filtered["expandable_generic"]) + + # test exexpanded + filtered = filtering.apply_filters_to_object( + self.test_model, filters.DEFAULT_FILTERS, expand_header="expandable_generic" + ) + self.assertIn("expandable_generic", filtered) + self.assertEqual("1234", filtered["expandable_generic"]["id"]) + self.assertTrue(filtered["expandable_generic"]["expanded"]) + def test_expandable_field_not_expanded_by_default(self): filtered = filtering.apply_filters_to_object(self.root, filters.DEFAULT_FILTERS) self.assertEqual(4, len(filtered)) @@ -115,7 +132,7 @@ def test_expand_multi_level_more_than_one_field(self): self.assertTrue("pk" in filtered["parent_field"]["favorite"]) self.assertTrue("name" in filtered["parent_field"]["favorite"]) self.assertTrue("test" in filtered["parent_field"]["favorite"]) - self.assertEqual(3, len(filtered["parent_field"]["favorite"]["test"])) + self.assertEqual(4, len(filtered["parent_field"]["favorite"]["test"])) self.assertTrue("parent" in filtered["parent_field"]["favorite"]) self.assertEqual(self.p1c1.name, filtered["parent_field"]["favorite"]["name"]) self.assertEqual(2, len(filtered["parent_field"]["children"])) @@ -135,7 +152,7 @@ def test_expandable_properties(self): expand_header="expandable_dict,expandable_string", ) - self.assertEqual(5, len(filtered)) + self.assertEqual(6, len(filtered)) self.assertTrue("expandable_dict" in filtered) self.assertEqual( filtered["expandable_dict"], models.TestModel.EXPANDABLE_DICT_RETURN @@ -157,7 +174,7 @@ def test_expandable_properties(self): self.test_model, filters.DEFAULT_FILTERS ) - self.assertEqual(3, len(filtered)) + self.assertEqual(4, len(filtered)) self.assertFalse("expandable_dict" in filtered) dict_mock.assert_not_called() self.assertFalse("expandable_string" in filtered) diff --git a/tests/tests.py b/tests/tests.py index 4b085e3..47144a8 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -39,6 +39,7 @@ def test_dict_endpoint(self): "test": { "pk": 1, "int_field": 1, + "expandable_generic": {"id": "1234"}, "__expandable__": ["expandable_dict", "expandable_string"], }, "deep_test": { @@ -46,6 +47,7 @@ def test_dict_endpoint(self): "pk": 1, "int_field": 1, "__expandable__": ["expandable_dict", "expandable_string"], + "expandable_generic": {"id": "1234"}, } }, },