Skip to content

Commit

Permalink
add ExpandableGeneric
Browse files Browse the repository at this point in the history
  • Loading branch information
dshafer committed Nov 3, 2023
1 parent 8c3dbc2 commit ba4c001
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 5 deletions.
18 changes: 17 additions & 1 deletion django_declarative_apis/machinery/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

from abc import ABC, abstractmethod
from collections import defaultdict
import inspect
import logging
Expand Down Expand Up @@ -45,6 +45,16 @@ def expandable(model_class=None, display_key=None, inst_field_name=None):
return _ExpandableForeignKey(display_key, model_class, inst_field_name)


class ExpandableGeneric(ABC):
@abstractmethod
def get_unexpanded_view(self, inst) -> dict:
raise NotImplementedError()

@abstractmethod
def get_expanded_view(self, inst) -> dict:
raise NotImplementedError()


def _get_unexpanded_field_value(inst, field_name, field_type):
if not field_type.model_class:
return DEFAULT_UNEXPANDED_VALUE
Expand Down Expand Up @@ -141,6 +151,11 @@ def _get_filtered_field_value( # noqa: C901
val = getattr(inst, inst_field_name)
else:
val = _get_unexpanded_field_value(inst, field_name, field_type)
elif isinstance(field_type, ExpandableGeneric):
if expand_this:
val = field_type.get_expanded_view(inst)
else:
val = field_type.get_unexpanded_view(inst)
else:
try:
if isinstance(inst, (models.Model)):
Expand Down Expand Up @@ -184,6 +199,7 @@ def _get_filtered_field_value( # noqa: C901
or isinstance(field_type, types.FunctionType)
or ((field_type == IF_TRUTHY) and val)
or isinstance(field_type, _ExpandableForeignKey)
or isinstance(field_type, ExpandableGeneric)
):
return val
else:
Expand Down
17 changes: 16 additions & 1 deletion tests/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,24 @@
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#

from django_declarative_apis.machinery.filtering import ALWAYS, NEVER, expandable
from django_declarative_apis.machinery.filtering import (
ALWAYS,
NEVER,
expandable,
ExpandableGeneric,
)

from . import models


class TestExpandableGeneric(ExpandableGeneric):
def get_unexpanded_view(self, inst) -> dict:
return {"id": "1234"}

def get_expanded_view(self, inst) -> dict:
return {"id": "1234", "expanded": True}


DEFAULT_FILTERS = {
str: ALWAYS,
int: ALWAYS,
Expand All @@ -18,6 +32,7 @@
"int_field": ALWAYS,
"expandable_dict": expandable(),
"expandable_string": expandable(),
"expandable_generic": TestExpandableGeneric(),
},
models.ChildModel: {
"pk": ALWAYS,
Expand Down
23 changes: 20 additions & 3 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ def setUp(self):
self.p1.root = self.root
self.p1.save()

def test_expandable_generic_field(self):
# test unexpanded
filtered = filtering.apply_filters_to_object(
self.test_model, filters.DEFAULT_FILTERS
)
self.assertIn("expandable_generic", filtered)
self.assertEqual("1234", filtered["expandable_generic"]["id"])
self.assertNotIn("expanded", filtered["expandable_generic"])

# test exexpanded
filtered = filtering.apply_filters_to_object(
self.test_model, filters.DEFAULT_FILTERS, expand_header="expandable_generic"
)
self.assertIn("expandable_generic", filtered)
self.assertEqual("1234", filtered["expandable_generic"]["id"])
self.assertTrue(filtered["expandable_generic"]["expanded"])

def test_expandable_field_not_expanded_by_default(self):
filtered = filtering.apply_filters_to_object(self.root, filters.DEFAULT_FILTERS)
self.assertEqual(4, len(filtered))
Expand Down Expand Up @@ -115,7 +132,7 @@ def test_expand_multi_level_more_than_one_field(self):
self.assertTrue("pk" in filtered["parent_field"]["favorite"])
self.assertTrue("name" in filtered["parent_field"]["favorite"])
self.assertTrue("test" in filtered["parent_field"]["favorite"])
self.assertEqual(3, len(filtered["parent_field"]["favorite"]["test"]))
self.assertEqual(4, len(filtered["parent_field"]["favorite"]["test"]))
self.assertTrue("parent" in filtered["parent_field"]["favorite"])
self.assertEqual(self.p1c1.name, filtered["parent_field"]["favorite"]["name"])
self.assertEqual(2, len(filtered["parent_field"]["children"]))
Expand All @@ -135,7 +152,7 @@ def test_expandable_properties(self):
expand_header="expandable_dict,expandable_string",
)

self.assertEqual(5, len(filtered))
self.assertEqual(6, len(filtered))
self.assertTrue("expandable_dict" in filtered)
self.assertEqual(
filtered["expandable_dict"], models.TestModel.EXPANDABLE_DICT_RETURN
Expand All @@ -157,7 +174,7 @@ def test_expandable_properties(self):
self.test_model, filters.DEFAULT_FILTERS
)

self.assertEqual(3, len(filtered))
self.assertEqual(4, len(filtered))
self.assertFalse("expandable_dict" in filtered)
dict_mock.assert_not_called()
self.assertFalse("expandable_string" in filtered)
Expand Down
2 changes: 2 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ def test_dict_endpoint(self):
"test": {
"pk": 1,
"int_field": 1,
"expandable_generic": {"id": "1234"},
"__expandable__": ["expandable_dict", "expandable_string"],
},
"deep_test": {
"test": {
"pk": 1,
"int_field": 1,
"__expandable__": ["expandable_dict", "expandable_string"],
"expandable_generic": {"id": "1234"},
}
},
},
Expand Down

0 comments on commit ba4c001

Please sign in to comment.