Skip to content

Commit

Permalink
Add support for Graphene3 and update lint with black (#72)
Browse files Browse the repository at this point in the history
Co-authored-by: Sebastian Hernandez <sebastian@rhinoafrica.com>
  • Loading branch information
ulgens and Sebastian Hernandez authored Aug 18, 2021
1 parent 101b767 commit e5c57fc
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 85 deletions.
4 changes: 2 additions & 2 deletions dev-env-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
-r requirements.txt
graphene==2.1.8
graphene-django==2.7.1
graphene==3.0b7
graphene-django==3.0.0b7
pytest==4.6.3
pytest-django==3.5.0
pytest-cov==2.7.1
Expand Down
9 changes: 5 additions & 4 deletions graphene_django_optimizer/field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import types
from graphene.types.field import Field
from graphene.types.unmountedtype import UnmountedType

Expand All @@ -9,12 +10,12 @@ def field(field_type, *args, **kwargs):
field_type = Field.mounted(field_type)

optimization_hints = OptimizationHints(*args, **kwargs)
get_resolver = field_type.get_resolver
wrap_resolve = field_type.wrap_resolve

def get_optimized_resolver(parent_resolver):
resolver = get_resolver(parent_resolver)
def get_optimized_resolver(self, parent_resolver):
resolver = wrap_resolve(parent_resolver)
resolver.optimization_hints = optimization_hints
return resolver

field_type.get_resolver = get_optimized_resolver
field_type.wrap_resolve = types.MethodType(get_optimized_resolver, field_type)
return field_type
48 changes: 30 additions & 18 deletions graphene_django_optimizer/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,20 @@
from graphene.types.generic import GenericScalar
from graphene.types.resolver import default_resolver
from graphene_django import DjangoObjectType
from graphql import ResolveInfo
from graphql.execution.base import (
get_field_def,
)
from graphql import GraphQLResolveInfo, GraphQLSchema
from graphql.execution.execute import get_field_def
from graphql.language.ast import (
FragmentSpread,
InlineFragment,
Variable,
FragmentSpreadNode,
InlineFragmentNode,
VariableNode,
)
from graphql.type.definition import (
GraphQLInterfaceType,
GraphQLUnionType,
)

from graphql.pyutils import Path

from .utils import is_iterable


Expand All @@ -31,7 +31,7 @@ def query(queryset, info, **options):
Arguments:
- queryset (Django QuerySet object) - The queryset to be optimized
- info (GraphQL ResolveInfo object) - This is passed by the graphene-django resolve methods
- info (GraphQL GraphQLResolveInfo object) - This is passed by the graphene-django resolve methods
- **options - optimization options/settings
- disable_abort_only (boolean) - in case the objecttype contains any extra fields,
then this will keep the "only" optimization enabled.
Expand All @@ -54,7 +54,7 @@ def optimize(self, queryset):
field_def = get_field_def(info.schema, info.parent_type, info.field_name)
store = self._optimize_gql_selections(
self._get_type(field_def),
info.field_asts[0],
info.field_nodes[0],
# info.parent_type,
)
return store.optimize_queryset(queryset)
Expand All @@ -65,9 +65,16 @@ def _get_type(self, field_def):
a_type = a_type.of_type
return a_type

def _get_graphql_schema(self, schema):
if isinstance(schema, GraphQLSchema):
return schema
else:
return schema.graphql_schema

def _get_possible_types(self, graphql_type):
if isinstance(graphql_type, (GraphQLInterfaceType, GraphQLUnionType)):
return self.root_info.schema.get_possible_types(graphql_type)
graphql_schema = self._get_graphql_schema(self.root_info.schema)
return graphql_schema.get_possible_types(graphql_type)
else:
return (graphql_type,)

Expand All @@ -80,7 +87,8 @@ def _get_base_model(self, graphql_types):

def handle_inline_fragment(self, selection, schema, possible_types, store):
fragment_type_name = selection.type_condition.name.value
fragment_type = schema.get_type(fragment_type_name)
graphql_schema = self._get_graphql_schema(schema)
fragment_type = graphql_schema.get_type(fragment_type_name)
fragment_possible_types = self._get_possible_types(fragment_type)
for fragment_possible_type in fragment_possible_types:
fragment_model = fragment_possible_type.graphene_type._meta.model
Expand Down Expand Up @@ -120,14 +128,16 @@ def _optimize_gql_selections(self, field_type, field_ast):
return store
optimized_fields_by_model = {}
schema = self.root_info.schema
graphql_type = schema.get_graphql_type(field_type.graphene_type)
graphql_schema = self._get_graphql_schema(schema)
graphql_type = graphql_schema.get_type(field_type.name)

possible_types = self._get_possible_types(graphql_type)
for selection in selection_set.selections:
if isinstance(selection, InlineFragment):
if isinstance(selection, InlineFragmentNode):
self.handle_inline_fragment(selection, schema, possible_types, store)
else:
name = selection.name.value
if isinstance(selection, FragmentSpread):
if isinstance(selection, FragmentSpreadNode):
self.handle_fragment_spread(store, name, field_type)
else:
for possible_type in possible_types:
Expand Down Expand Up @@ -176,7 +186,7 @@ def _optimize_field(self, store, model, selection, field_def, parent_type):
store.abort_only_optimization()

def _optimize_field_by_name(self, store, model, selection, field_def):
name = self._get_name_from_resolver(field_def.resolver)
name = self._get_name_from_resolver(field_def.resolve)
if not name:
return False
model_field = self._get_model_field_from_name(model, name)
Expand Down Expand Up @@ -215,7 +225,7 @@ def _get_optimization_hints(self, resolver):
return getattr(resolver, "optimization_hints", None)

def _get_value(self, info, value):
if isinstance(value, Variable):
if isinstance(value, VariableNode):
var_name = value.name.value
value = info.variable_values.get(var_name)
return value
Expand All @@ -225,7 +235,7 @@ def _get_value(self, info, value):
return GenericScalar.parse_literal(value)

def _optimize_field_by_hints(self, store, selection, field_def, parent_type):
optimization_hints = self._get_optimization_hints(field_def.resolver)
optimization_hints = self._get_optimization_hints(field_def.resolve)
if not optimization_hints:
return False
info = self._create_resolve_info(
Expand Down Expand Up @@ -316,17 +326,19 @@ def _is_foreign_key_id(self, model_field, name):
)

def _create_resolve_info(self, field_name, field_asts, return_type, parent_type):
return ResolveInfo(
return GraphQLResolveInfo(
field_name,
field_asts,
return_type,
parent_type,
Path(None, 0, None),
schema=self.root_info.schema,
fragments=self.root_info.fragments,
root_value=self.root_info.root_value,
operation=self.root_info.operation,
variable_values=self.root_info.variable_values,
context=self.root_info.context,
is_awaitable=self.root_info.is_awaitable,
)


Expand Down
50 changes: 25 additions & 25 deletions tests/graphql_utils.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,73 @@
from graphql import (
ResolveInfo,
GraphQLResolveInfo,
Source,
Undefined,
parse,
)
from graphql.execution.base import (
from graphql.execution.execute import (
ExecutionContext,
collect_fields,
get_field_def,
get_operation_root_type,
)
from graphql.pyutils.default_ordered_dict import DefaultOrderedDict
from graphql.utilities import get_operation_root_type
from collections import defaultdict

from graphql.pyutils import Path


def create_execution_context(schema, request_string, variables=None):
source = Source(request_string, "GraphQL request")
document_ast = parse(source)
return ExecutionContext(
return ExecutionContext.build(
schema,
document_ast,
root_value=None,
context_value=None,
variable_values=variables,
raw_variable_values=variables,
operation_name=None,
executor=None,
middleware=None,
allow_subscriptions=False,
)


def get_field_asts_from_execution_context(exe_context):
fields = collect_fields(
exe_context,
fields = exe_context.collect_fields(
type,
exe_context.operation.selection_set,
DefaultOrderedDict(list),
defaultdict(list),
set(),
)
# field_asts = next(iter(fields.values()))
field_asts = tuple(fields.values())[0]
return field_asts


def create_resolve_info(schema, request_string, variables=None):
def create_resolve_info(schema, request_string, variables=None, return_type=None):
exe_context = create_execution_context(schema, request_string, variables)
parent_type = get_operation_root_type(schema, exe_context.operation)
field_asts = get_field_asts_from_execution_context(exe_context)

field_ast = field_asts[0]
field_name = field_ast.name.value

field_def = get_field_def(schema, parent_type, field_name)
if not field_def:
return Undefined
return_type = field_def.type
if return_type is None:
field_def = get_field_def(schema, parent_type, field_name)
if not field_def:
return Undefined
return_type = field_def.type

# The resolve function's optional third argument is a context value that
# is provided to every resolve function within an execution. It is commonly
# used to represent an authenticated user, or request-specific caches.
context = exe_context.context_value
return ResolveInfo(
return GraphQLResolveInfo(
field_name,
field_asts,
return_type,
parent_type,
schema=schema,
fragments=exe_context.fragments,
root_value=exe_context.root_value,
operation=exe_context.operation,
variable_values=exe_context.variable_values,
context=context,
Path(None, 0, None),
schema,
exe_context.fragments,
exe_context.root_value,
exe_context.operation,
exe_context.variable_values,
exe_context.context_value,
exe_context.is_awaitable,
)
32 changes: 29 additions & 3 deletions tests/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class BaseItemType(OptimizedDjangoObjectType):

class Meta:
model = Item
fields = "__all__"

@gql_optimizer.resolver_hints(
model_field="children",
Expand All @@ -110,6 +111,8 @@ def resolve_relay_all_children(root, info, **kwargs):
class ItemNode(BaseItemType):
class Meta:
model = Item
fields = "__all__"

interfaces = (
graphene.relay.Node,
ItemInterface,
Expand All @@ -119,16 +122,19 @@ class Meta:
class SomeOtherItemType(OptimizedDjangoObjectType):
class Meta:
model = SomeOtherItem
fields = "__all__"


class OtherItemType(OptimizedDjangoObjectType):
class Meta:
model = OtherItem
fields = "__all__"


class ItemType(BaseItemType):
class Meta:
model = Item
fields = "__all__"
interfaces = (ItemInterface,)


Expand All @@ -144,29 +150,34 @@ class DetailedInterface(graphene.Interface):
class DetailedItemType(ItemType):
class Meta:
model = DetailedItem
fields = "__all__"
interfaces = (ItemInterface, DetailedInterface)


class RelatedItemType(ItemType):
class Meta:
model = RelatedItem
fields = "__all__"
interfaces = (ItemInterface,)


class ExtraDetailedItemType(DetailedItemType):
class Meta:
model = ExtraDetailedItem
fields = "__all__"
interfaces = (ItemInterface,)


class RelatedOneToManyItemType(OptimizedDjangoObjectType):
class Meta:
model = RelatedOneToManyItem
fields = "__all__"


class UnrelatedModelType(OptimizedDjangoObjectType):
class Meta:
model = UnrelatedModel
fields = "__all__"
interfaces = (DetailedInterface,)


Expand Down Expand Up @@ -200,6 +211,21 @@ def resolve_other_items(root, info):
return gql_optimizer.query(OtherItemType.objects.all(), info)


schema = graphene.Schema(
query=Query, types=(UnrelatedModelType,), mutation=DummyItemMutation
)
class Schema(graphene.Schema):
@property
def query_type(self):
return self.graphql_schema.get_type("Query")

@property
def mutation_type(self):
return self.graphql_schema.get_type("Mutation")

@property
def subscription_type(self):
return self.graphql_schema.get_type("Subscription")

def get_type(self, _type):
return self.graphql_schema.get_type(_type)


schema = Schema(query=Query, types=(UnrelatedModelType,), mutation=DummyItemMutation)
7 changes: 4 additions & 3 deletions tests/test_field.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import pytest
import graphene_django_optimizer as gql_optimizer

from .graphql_utils import create_resolve_info
from .models import (
Item,
)
from .models import Item
from .schema import schema
from .test_utils import assert_query_equality


@pytest.mark.django_db
def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_field():
info = create_resolve_info(
schema,
Expand All @@ -29,6 +29,7 @@ def test_should_optimize_non_django_field_if_it_has_an_optimization_hint_in_the_
assert_query_equality(items, optimized_items)


@pytest.mark.django_db
def test_should_optimize_with_only_hint():
info = create_resolve_info(
schema,
Expand Down
Loading

0 comments on commit e5c57fc

Please sign in to comment.