Skip to content

Commit

Permalink
Remove defer/stream support from subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
Cito committed Apr 5, 2024
1 parent 98b44cc commit 5e451c3
Show file tree
Hide file tree
Showing 10 changed files with 541 additions and 382 deletions.
5 changes: 1 addition & 4 deletions src/graphql/execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
default_field_resolver,
default_type_resolver,
subscribe,
experimental_subscribe_incrementally,
ExecutionContext,
ExecutionResult,
ExperimentalIncrementalExecutionResults,
Expand All @@ -30,7 +29,7 @@
FormattedIncrementalResult,
Middleware,
)
from .async_iterables import flatten_async_iterable, map_async_iterable
from .async_iterables import map_async_iterable
from .middleware import MiddlewareManager
from .values import get_argument_values, get_directive_values, get_variable_values

Expand All @@ -43,7 +42,6 @@
"default_field_resolver",
"default_type_resolver",
"subscribe",
"experimental_subscribe_incrementally",
"ExecutionContext",
"ExecutionResult",
"ExperimentalIncrementalExecutionResults",
Expand All @@ -58,7 +56,6 @@
"FormattedIncrementalDeferResult",
"FormattedIncrementalStreamResult",
"FormattedIncrementalResult",
"flatten_async_iterable",
"map_async_iterable",
"Middleware",
"MiddlewareManager",
Expand Down
17 changes: 1 addition & 16 deletions src/graphql/execution/async_iterables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Union,
)

__all__ = ["aclosing", "flatten_async_iterable", "map_async_iterable"]
__all__ = ["aclosing", "map_async_iterable"]

T = TypeVar("T")
V = TypeVar("V")
Expand Down Expand Up @@ -42,21 +42,6 @@ async def __aexit__(self, *_exc_info: object) -> None:
await aclose()


async def flatten_async_iterable(
iterable: AsyncIterableOrGenerator[AsyncIterableOrGenerator[T]],
) -> AsyncGenerator[T, None]:
"""Flatten async iterables.
Given an AsyncIterable of AsyncIterables, flatten all yielded results into a
single AsyncIterable.
"""
async with aclosing(iterable) as sub_iterators: # type: ignore
async for sub_iterator in sub_iterators:
async with aclosing(sub_iterator) as items: # type: ignore
async for item in items:
yield item


async def map_async_iterable(
iterable: AsyncIterableOrGenerator[T], callback: Callable[[T], Awaitable[V]]
) -> AsyncGenerator[V, None]:
Expand Down
29 changes: 24 additions & 5 deletions src/graphql/execution/collect_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
FragmentDefinitionNode,
FragmentSpreadNode,
InlineFragmentNode,
OperationDefinitionNode,
OperationType,
SelectionSetNode,
)
from ..type import (
Expand Down Expand Up @@ -43,7 +45,7 @@ def collect_fields(
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
runtime_type: GraphQLObjectType,
selection_set: SelectionSetNode,
operation: OperationDefinitionNode,
) -> FieldsAndPatches:
"""Collect fields.
Expand All @@ -61,8 +63,9 @@ def collect_fields(
schema,
fragments,
variable_values,
operation,
runtime_type,
selection_set,
operation.selection_set,
fields,
patches,
set(),
Expand All @@ -74,6 +77,7 @@ def collect_subfields(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
operation: OperationDefinitionNode,
return_type: GraphQLObjectType,
field_nodes: List[FieldNode],
) -> FieldsAndPatches:
Expand All @@ -100,6 +104,7 @@ def collect_subfields(
schema,
fragments,
variable_values,
operation,
return_type,
node.selection_set,
sub_field_nodes,
Expand All @@ -113,6 +118,7 @@ def collect_fields_impl(
schema: GraphQLSchema,
fragments: Dict[str, FragmentDefinitionNode],
variable_values: Dict[str, Any],
operation: OperationDefinitionNode,
runtime_type: GraphQLObjectType,
selection_set: SelectionSetNode,
fields: Dict[str, List[FieldNode]],
Expand All @@ -133,13 +139,14 @@ def collect_fields_impl(
) or not does_fragment_condition_match(schema, selection, runtime_type):
continue

defer = get_defer_values(variable_values, selection)
defer = get_defer_values(operation, variable_values, selection)
if defer:
patch_fields = defaultdict(list)
collect_fields_impl(
schema,
fragments,
variable_values,
operation,
runtime_type,
selection.selection_set,
patch_fields,
Expand All @@ -152,6 +159,7 @@ def collect_fields_impl(
schema,
fragments,
variable_values,
operation,
runtime_type,
selection.selection_set,
fields,
Expand All @@ -164,7 +172,7 @@ def collect_fields_impl(
if not should_include_node(variable_values, selection):
continue

defer = get_defer_values(variable_values, selection)
defer = get_defer_values(operation, variable_values, selection)
if frag_name in visited_fragment_names and not defer:
continue

Expand All @@ -183,6 +191,7 @@ def collect_fields_impl(
schema,
fragments,
variable_values,
operation,
runtime_type,
fragment.selection_set,
patch_fields,
Expand All @@ -195,6 +204,7 @@ def collect_fields_impl(
schema,
fragments,
variable_values,
operation,
runtime_type,
fragment.selection_set,
fields,
Expand All @@ -210,7 +220,9 @@ class DeferValues(NamedTuple):


def get_defer_values(
variable_values: Dict[str, Any], node: Union[FragmentSpreadNode, InlineFragmentNode]
operation: OperationDefinitionNode,
variable_values: Dict[str, Any],
node: Union[FragmentSpreadNode, InlineFragmentNode],
) -> Optional[DeferValues]:
"""Get values of defer directive if active.
Expand All @@ -223,6 +235,13 @@ def get_defer_values(
if not defer or defer.get("if") is False:
return None

if operation.operation == OperationType.SUBSCRIPTION:
msg = (
"`@defer` directive not supported on subscription operations."
" Disable `@defer` by setting the `if` argument to `false`."
)
raise TypeError(msg)

return DeferValues(defer.get("label"))


Expand Down
Loading

0 comments on commit 5e451c3

Please sign in to comment.