From 5e451c31917470d6c5dcfa706942d54fc684ac76 Mon Sep 17 00:00:00 2001 From: Christoph Zwerschke Date: Fri, 5 Apr 2024 15:48:07 +0200 Subject: [PATCH] Remove defer/stream support from subscriptions Replicates graphql/graphql-js@1bf71eeec71d26b532a3722c54d0552ec1706af5 --- src/graphql/execution/__init__.py | 5 +- src/graphql/execution/async_iterables.py | 17 +- src/graphql/execution/collect_fields.py | 29 +- src/graphql/execution/execute.py | 170 ++------ src/graphql/validation/__init__.py | 6 + ...ream_directive_on_valid_operations_rule.py | 83 ++++ .../rules/single_field_subscriptions.py | 2 +- src/graphql/validation/specified_rules.py | 6 + .../execution/test_flatten_async_iterable.py | 210 ---------- ...er_stream_directive_on_valid_operations.py | 395 ++++++++++++++++++ 10 files changed, 541 insertions(+), 382 deletions(-) create mode 100644 src/graphql/validation/rules/defer_stream_directive_on_valid_operations_rule.py delete mode 100644 tests/execution/test_flatten_async_iterable.py create mode 100644 tests/validation/test_defer_stream_directive_on_valid_operations.py diff --git a/src/graphql/execution/__init__.py b/src/graphql/execution/__init__.py index 29aa1594..e33d4ce7 100644 --- a/src/graphql/execution/__init__.py +++ b/src/graphql/execution/__init__.py @@ -13,7 +13,6 @@ default_field_resolver, default_type_resolver, subscribe, - experimental_subscribe_incrementally, ExecutionContext, ExecutionResult, ExperimentalIncrementalExecutionResults, @@ -30,7 +29,7 @@ FormattedIncrementalResult, Middleware, ) -from .async_iterables import flatten_async_iterable, map_async_iterable +from .async_iterables import map_async_iterable from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -43,7 +42,6 @@ "default_field_resolver", "default_type_resolver", "subscribe", - "experimental_subscribe_incrementally", "ExecutionContext", "ExecutionResult", "ExperimentalIncrementalExecutionResults", @@ -58,7 +56,6 @@ "FormattedIncrementalDeferResult", "FormattedIncrementalStreamResult", "FormattedIncrementalResult", - "flatten_async_iterable", "map_async_iterable", "Middleware", "MiddlewareManager", diff --git a/src/graphql/execution/async_iterables.py b/src/graphql/execution/async_iterables.py index 7b7f6340..305b495f 100644 --- a/src/graphql/execution/async_iterables.py +++ b/src/graphql/execution/async_iterables.py @@ -12,7 +12,7 @@ Union, ) -__all__ = ["aclosing", "flatten_async_iterable", "map_async_iterable"] +__all__ = ["aclosing", "map_async_iterable"] T = TypeVar("T") V = TypeVar("V") @@ -42,21 +42,6 @@ async def __aexit__(self, *_exc_info: object) -> None: await aclose() -async def flatten_async_iterable( - iterable: AsyncIterableOrGenerator[AsyncIterableOrGenerator[T]], -) -> AsyncGenerator[T, None]: - """Flatten async iterables. - - Given an AsyncIterable of AsyncIterables, flatten all yielded results into a - single AsyncIterable. - """ - async with aclosing(iterable) as sub_iterators: # type: ignore - async for sub_iterator in sub_iterators: - async with aclosing(sub_iterator) as items: # type: ignore - async for item in items: - yield item - - async def map_async_iterable( iterable: AsyncIterableOrGenerator[T], callback: Callable[[T], Awaitable[V]] ) -> AsyncGenerator[V, None]: diff --git a/src/graphql/execution/collect_fields.py b/src/graphql/execution/collect_fields.py index 260e10ae..e7d64fe8 100644 --- a/src/graphql/execution/collect_fields.py +++ b/src/graphql/execution/collect_fields.py @@ -8,6 +8,8 @@ FragmentDefinitionNode, FragmentSpreadNode, InlineFragmentNode, + OperationDefinitionNode, + OperationType, SelectionSetNode, ) from ..type import ( @@ -43,7 +45,7 @@ def collect_fields( fragments: Dict[str, FragmentDefinitionNode], variable_values: Dict[str, Any], runtime_type: GraphQLObjectType, - selection_set: SelectionSetNode, + operation: OperationDefinitionNode, ) -> FieldsAndPatches: """Collect fields. @@ -61,8 +63,9 @@ def collect_fields( schema, fragments, variable_values, + operation, runtime_type, - selection_set, + operation.selection_set, fields, patches, set(), @@ -74,6 +77,7 @@ def collect_subfields( schema: GraphQLSchema, fragments: Dict[str, FragmentDefinitionNode], variable_values: Dict[str, Any], + operation: OperationDefinitionNode, return_type: GraphQLObjectType, field_nodes: List[FieldNode], ) -> FieldsAndPatches: @@ -100,6 +104,7 @@ def collect_subfields( schema, fragments, variable_values, + operation, return_type, node.selection_set, sub_field_nodes, @@ -113,6 +118,7 @@ def collect_fields_impl( schema: GraphQLSchema, fragments: Dict[str, FragmentDefinitionNode], variable_values: Dict[str, Any], + operation: OperationDefinitionNode, runtime_type: GraphQLObjectType, selection_set: SelectionSetNode, fields: Dict[str, List[FieldNode]], @@ -133,13 +139,14 @@ def collect_fields_impl( ) or not does_fragment_condition_match(schema, selection, runtime_type): continue - defer = get_defer_values(variable_values, selection) + defer = get_defer_values(operation, variable_values, selection) if defer: patch_fields = defaultdict(list) collect_fields_impl( schema, fragments, variable_values, + operation, runtime_type, selection.selection_set, patch_fields, @@ -152,6 +159,7 @@ def collect_fields_impl( schema, fragments, variable_values, + operation, runtime_type, selection.selection_set, fields, @@ -164,7 +172,7 @@ def collect_fields_impl( if not should_include_node(variable_values, selection): continue - defer = get_defer_values(variable_values, selection) + defer = get_defer_values(operation, variable_values, selection) if frag_name in visited_fragment_names and not defer: continue @@ -183,6 +191,7 @@ def collect_fields_impl( schema, fragments, variable_values, + operation, runtime_type, fragment.selection_set, patch_fields, @@ -195,6 +204,7 @@ def collect_fields_impl( schema, fragments, variable_values, + operation, runtime_type, fragment.selection_set, fields, @@ -210,7 +220,9 @@ class DeferValues(NamedTuple): def get_defer_values( - variable_values: Dict[str, Any], node: Union[FragmentSpreadNode, InlineFragmentNode] + operation: OperationDefinitionNode, + variable_values: Dict[str, Any], + node: Union[FragmentSpreadNode, InlineFragmentNode], ) -> Optional[DeferValues]: """Get values of defer directive if active. @@ -223,6 +235,13 @@ def get_defer_values( if not defer or defer.get("if") is False: return None + if operation.operation == OperationType.SUBSCRIPTION: + msg = ( + "`@defer` directive not supported on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`." + ) + raise TypeError(msg) + return DeferValues(defer.get("label")) diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 74ead0af..6310d33b 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -77,7 +77,7 @@ is_non_null_type, is_object_type, ) -from .async_iterables import flatten_async_iterable, map_async_iterable +from .async_iterables import map_async_iterable from .collect_fields import FieldsAndPatches, collect_fields, collect_subfields from .middleware import MiddlewareManager from .values import get_argument_values, get_directive_values, get_variable_values @@ -101,7 +101,6 @@ async def anext(iterator: AsyncIterator) -> Any: # noqa: A001 "execute", "execute_sync", "experimental_execute_incrementally", - "experimental_subscribe_incrementally", "subscribe", "AsyncPayloadRecord", "DeferredFragmentRecord", @@ -817,7 +816,7 @@ def execute_operation(self) -> AwaitableOrValue[Dict[str, Any]]: self.fragments, self.variable_values, root_type, - operation.selection_set, + operation, ) root_value = self.root_value @@ -1173,6 +1172,13 @@ def get_stream_values( msg = "initialCount must be a positive integer" raise ValueError(msg) + if self.operation.operation == OperationType.SUBSCRIPTION: + msg = ( + "`@stream` directive not supported on subscription operations." + " Disable `@stream` by setting the `if` argument to `false`." + ) + raise TypeError(msg) + label = stream.get("label") return StreamArguments(initial_count=initial_count, label=label) @@ -1644,6 +1650,7 @@ def collect_subfields( self.schema, self.fragments, self.variable_values, + self.operation, return_type, field_nodes, ) @@ -1652,17 +1659,7 @@ def collect_subfields( def map_source_to_response( self, result_or_stream: Union[ExecutionResult, AsyncIterable[Any]] - ) -> Union[ - AsyncGenerator[ - Union[ - ExecutionResult, - InitialIncrementalExecutionResult, - SubsequentIncrementalExecutionResult, - ], - None, - ], - ExecutionResult, - ]: + ) -> Union[AsyncGenerator[ExecutionResult, None], ExecutionResult]: """Map source result to response. For each payload yielded from a subscription, @@ -1678,13 +1675,17 @@ def map_source_to_response( if not isinstance(result_or_stream, AsyncIterable): return result_or_stream # pragma: no cover - async def callback(payload: Any) -> AsyncGenerator: + async def callback(payload: Any) -> ExecutionResult: result = execute_impl(self.build_per_event_execution_context(payload)) - return ensure_async_iterable( - await result if self.is_awaitable(result) else result # type: ignore + # typecast to ExecutionResult, not possible to return + # ExperimentalIncrementalExecutionResults when operation is 'subscription'. + return ( + await cast(Awaitable[ExecutionResult], result) + if self.is_awaitable(result) + else cast(ExecutionResult, result) ) - return flatten_async_iterable(map_async_iterable(result_or_stream, callback)) + return map_async_iterable(result_or_stream, callback) def execute_deferred_fragment( self, @@ -2015,8 +2016,8 @@ def execute( a GraphQLError will be thrown immediately explaining the invalid input. This function does not support incremental delivery (`@defer` and `@stream`). - If an operation which would defer or stream data is executed with this - function, it will throw or resolve to an object containing an error instead. + If an operation that defers or streams data is executed with this function, + it will throw or resolve to an object containing an error instead. Use `experimental_execute_incrementally` if you want to support incremental delivery. """ @@ -2362,111 +2363,8 @@ def subscribe( a stream of ExecutionResults representing the response stream. This function does not support incremental delivery (`@defer` and `@stream`). - If an operation which would defer or stream data is executed with this function, - each :class:`InitialIncrementalExecutionResult` and - :class:`SubsequentIncrementalExecutionResult` - in the result stream will be replaced with an :class:`ExecutionResult` - with a single error stating that defer/stream is not supported. - Use :func:`experimental_subscribe_incrementally` if you want to support - incremental delivery. - """ - result = experimental_subscribe_incrementally( - schema, - document, - root_value, - context_value, - variable_values, - operation_name, - field_resolver, - type_resolver, - subscribe_field_resolver, - execution_context_class, - ) - - if isinstance(result, ExecutionResult): - return result - if isinstance(result, AsyncIterable): - return map_async_iterable(result, ensure_single_execution_result) - - async def await_result() -> Union[AsyncIterator[ExecutionResult], ExecutionResult]: - result_or_iterable = await result - if isinstance(result_or_iterable, AsyncIterable): - return map_async_iterable( - result_or_iterable, ensure_single_execution_result - ) - return result_or_iterable - - return await_result() - - -async def ensure_single_execution_result( - result: Union[ - ExecutionResult, - InitialIncrementalExecutionResult, - SubsequentIncrementalExecutionResult, - ], -) -> ExecutionResult: - """Ensure that the given result does not use incremental delivery.""" - if not isinstance(result, ExecutionResult): - return ExecutionResult( - None, errors=[GraphQLError(UNEXPECTED_MULTIPLE_PAYLOADS)] - ) - return result - - -def experimental_subscribe_incrementally( - schema: GraphQLSchema, - document: DocumentNode, - root_value: Any = None, - context_value: Any = None, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - field_resolver: Optional[GraphQLFieldResolver] = None, - type_resolver: Optional[GraphQLTypeResolver] = None, - subscribe_field_resolver: Optional[GraphQLFieldResolver] = None, - execution_context_class: Optional[Type[ExecutionContext]] = None, -) -> AwaitableOrValue[ - Union[ - AsyncGenerator[ - Union[ - ExecutionResult, - InitialIncrementalExecutionResult, - SubsequentIncrementalExecutionResult, - ], - None, - ], - ExecutionResult, - ] -]: - """Create a GraphQL subscription. - - Implements the "Subscribe" algorithm described in the GraphQL spec. - - Returns a coroutine object which yields either an AsyncIterator (if successful) or - an ExecutionResult (client error). The coroutine will raise an exception if a server - error occurs. - - If the client-provided arguments to this function do not result in a compliant - subscription, a GraphQL Response (ExecutionResult) with descriptive errors and no - data will be returned. - - If the source stream could not be created due to faulty subscription resolver logic - or underlying systems, the coroutine object will yield a single ExecutionResult - containing ``errors`` and no ``data``. - - If the operation succeeded, the coroutine will yield an AsyncIterator, which yields - a stream of ExecutionResults representing the response stream. - - Each result may be an ExecutionResult with no ``has_next`` attribute (if executing - the event did not use `@defer` or `@stream`), or an - :class:`InitialIncrementalExecutionResult` or - :class:`SubsequentIncrementalExecutionResult` - (if executing the event used `@defer` or `@stream`). In the case of - incremental execution results, each event produces a single - :class:`InitialIncrementalExecutionResult` followed by one or more - :class:`SubsequentIncrementalExecutionResult`; all but the last have - ``has_next == true``, and the last has ``has_next == False``. - There is no interleaving between results generated from the same original event. + If an operation that defers or streams data is executed with this function, + a field error will be raised at the location of the `@defer` or `@stream` directive. """ if execution_context_class is None: execution_context_class = ExecutionContext @@ -2507,26 +2405,6 @@ async def await_result() -> Any: return context.map_source_to_response(result_or_stream) # type: ignore -async def ensure_async_iterable( - some_execution_result: Union[ - ExecutionResult, ExperimentalIncrementalExecutionResults - ], -) -> AsyncGenerator[ - Union[ - ExecutionResult, - InitialIncrementalExecutionResult, - SubsequentIncrementalExecutionResult, - ], - None, -]: - if isinstance(some_execution_result, ExecutionResult): - yield some_execution_result - else: - yield some_execution_result.initial_result - async for result in some_execution_result.subsequent_results: - yield result - - def create_source_event_stream( schema: GraphQLSchema, document: DocumentNode, @@ -2622,7 +2500,7 @@ def execute_subscription( context.fragments, context.variable_values, root_type, - context.operation.selection_set, + context.operation, ).fields first_root_field = next(iter(root_fields.items())) response_name, field_nodes = first_root_field diff --git a/src/graphql/validation/__init__.py b/src/graphql/validation/__init__.py index 270eed06..8f67f9b7 100644 --- a/src/graphql/validation/__init__.py +++ b/src/graphql/validation/__init__.py @@ -23,6 +23,11 @@ # Spec Section: "Defer And Stream Directives Are Used On Valid Root Field" from .rules.defer_stream_directive_on_root_field import DeferStreamDirectiveOnRootField +# Spec Section: "Defer And Stream Directives Are Used On Valid Operations" +from .rules.defer_stream_directive_on_valid_operations_rule import ( + DeferStreamDirectiveOnValidOperationsRule, +) + # Spec Section: "Executable Definitions" from .rules.executable_definitions import ExecutableDefinitionsRule @@ -129,6 +134,7 @@ "specified_rules", "DeferStreamDirectiveLabel", "DeferStreamDirectiveOnRootField", + "DeferStreamDirectiveOnValidOperationsRule", "ExecutableDefinitionsRule", "FieldsOnCorrectTypeRule", "FragmentsOnCompositeTypesRule", diff --git a/src/graphql/validation/rules/defer_stream_directive_on_valid_operations_rule.py b/src/graphql/validation/rules/defer_stream_directive_on_valid_operations_rule.py new file mode 100644 index 00000000..391c8932 --- /dev/null +++ b/src/graphql/validation/rules/defer_stream_directive_on_valid_operations_rule.py @@ -0,0 +1,83 @@ +"""Defer stream directive on valid operations rule""" + +from typing import Any, List, Set + +from ...error import GraphQLError +from ...language import ( + BooleanValueNode, + DirectiveNode, + FragmentDefinitionNode, + Node, + OperationDefinitionNode, + OperationType, + VariableNode, +) +from ...type import GraphQLDeferDirective, GraphQLStreamDirective +from . import ASTValidationRule, ValidationContext + +__all__ = ["DeferStreamDirectiveOnValidOperationsRule"] + + +def if_argument_can_be_false(node: DirectiveNode) -> bool: + for argument in node.arguments: + if argument.name.value == "if": + if isinstance(argument.value, BooleanValueNode): + if argument.value.value: + return False + elif not isinstance(argument.value, VariableNode): + return False + return True + return False + + +class DeferStreamDirectiveOnValidOperationsRule(ASTValidationRule): + """Defer and stream directives are used on valid root field + + A GraphQL document is only valid if defer directives are not used on root + mutation or subscription types. + """ + + def __init__(self, context: ValidationContext) -> None: + super().__init__(context) + self.fragments_used_on_subscriptions: Set[str] = set() + + def enter_operation_definition( + self, operation: OperationDefinitionNode, *_args: Any + ) -> None: + if operation.operation == OperationType.SUBSCRIPTION: + fragments = self.context.get_recursively_referenced_fragments(operation) + for fragment in fragments: + self.fragments_used_on_subscriptions.add(fragment.name.value) + + def enter_directive( + self, + node: DirectiveNode, + _key: Any, + _parent: Any, + _path: Any, + ancestors: List[Node], + ) -> None: + try: + definition_node = ancestors[2] + except IndexError: # pragma: no cover + return + if ( + isinstance(definition_node, FragmentDefinitionNode) + and definition_node.name.value in self.fragments_used_on_subscriptions + or isinstance(definition_node, OperationDefinitionNode) + and definition_node.operation == OperationType.SUBSCRIPTION + ): + if node.name.value == GraphQLDeferDirective.name: + if not if_argument_can_be_false(node): + msg = ( + "Defer directive not supported on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`." + ) + self.report_error(GraphQLError(msg, node)) + elif node.name.value == GraphQLStreamDirective.name: # noqa: SIM102 + if not if_argument_can_be_false(node): + msg = ( + "Stream directive not supported on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`." + ) + self.report_error(GraphQLError(msg, node)) diff --git a/src/graphql/validation/rules/single_field_subscriptions.py b/src/graphql/validation/rules/single_field_subscriptions.py index 40d37eb2..e8ce9ec5 100644 --- a/src/graphql/validation/rules/single_field_subscriptions.py +++ b/src/graphql/validation/rules/single_field_subscriptions.py @@ -45,7 +45,7 @@ def enter_operation_definition( fragments, variable_values, subscription_type, - node.selection_set, + node, ).fields if len(fields) > 1: field_selection_lists = list(fields.values()) diff --git a/src/graphql/validation/specified_rules.py b/src/graphql/validation/specified_rules.py index d8c225d8..e024d0d1 100644 --- a/src/graphql/validation/specified_rules.py +++ b/src/graphql/validation/specified_rules.py @@ -10,6 +10,11 @@ # Spec Section: "Defer And Stream Directives Are Used On Valid Root Field" from .rules.defer_stream_directive_on_root_field import DeferStreamDirectiveOnRootField +# Spec Section: "Defer And Stream Directives Are Used On Valid Operations" +from .rules.defer_stream_directive_on_valid_operations_rule import ( + DeferStreamDirectiveOnValidOperationsRule, +) + # Spec Section: "Executable Definitions" from .rules.executable_definitions import ExecutableDefinitionsRule @@ -136,6 +141,7 @@ KnownDirectivesRule, UniqueDirectivesPerLocationRule, DeferStreamDirectiveOnRootField, + DeferStreamDirectiveOnValidOperationsRule, DeferStreamDirectiveLabel, StreamDirectiveOnListField, KnownArgumentNamesRule, diff --git a/tests/execution/test_flatten_async_iterable.py b/tests/execution/test_flatten_async_iterable.py deleted file mode 100644 index 357e4cd0..00000000 --- a/tests/execution/test_flatten_async_iterable.py +++ /dev/null @@ -1,210 +0,0 @@ -from contextlib import suppress -from typing import AsyncGenerator - -import pytest -from graphql.execution import flatten_async_iterable - -try: # pragma: no cover - anext # noqa: B018 -except NameError: # pragma: no cover (Python < 3.10) - # noinspection PyShadowingBuiltins - async def anext(iterator): # noqa: A001 - """Return the next item from an async iterator.""" - return await iterator.__anext__() - - -def describe_flatten_async_iterable(): - @pytest.mark.asyncio() - async def flattens_nested_async_generators(): - async def source(): - async def nested1() -> AsyncGenerator[float, None]: - yield 1.1 - yield 1.2 - - async def nested2() -> AsyncGenerator[float, None]: - yield 2.1 - yield 2.2 - - yield nested1() - yield nested2() - - doubles = flatten_async_iterable(source()) - - result = [x async for x in doubles] - - assert result == [1.1, 1.2, 2.1, 2.2] - - @pytest.mark.asyncio() - async def allows_returning_early_from_a_nested_async_generator(): - async def source(): - async def nested1() -> AsyncGenerator[float, None]: - yield 1.1 - yield 1.2 - - async def nested2() -> AsyncGenerator[float, None]: - yield 2.1 - # Not reachable, early return - yield 2.2 # pragma: no cover - - # Not reachable, early return - async def nested3() -> AsyncGenerator[float, None]: - yield 3.1 # pragma: no cover - yield 3.2 # pragma: no cover - - yield nested1() - yield nested2() - yield nested3() # pragma: no cover - - doubles = flatten_async_iterable(source()) - - assert await anext(doubles) == 1.1 - assert await anext(doubles) == 1.2 - assert await anext(doubles) == 2.1 - - # early return - with suppress(RuntimeError): # suppress error for Python < 3.8 - await doubles.aclose() - - # subsequent anext calls - with pytest.raises(StopAsyncIteration): - assert await anext(doubles) - with pytest.raises(StopAsyncIteration): - assert await anext(doubles) - - @pytest.mark.asyncio() - async def allows_throwing_errors_from_a_nested_async_generator(): - async def source(): - async def nested1() -> AsyncGenerator[float, None]: - yield 1.1 - yield 1.2 - - async def nested2() -> AsyncGenerator[float, None]: - yield 2.1 - # Not reachable, early return - yield 2.2 # pragma: no cover - - # Not reachable, early return - async def nested3() -> AsyncGenerator[float, None]: - yield 3.1 # pragma: no cover - yield 3.2 # pragma: no cover - - yield nested1() - yield nested2() - yield nested3() # pragma: no cover - - doubles = flatten_async_iterable(source()) - - assert await anext(doubles) == 1.1 - assert await anext(doubles) == 1.2 - assert await anext(doubles) == 2.1 - - # throw error - with pytest.raises(RuntimeError, match="ouch"): - await doubles.athrow(RuntimeError("ouch")) - - @pytest.mark.asyncio() - async def completely_yields_sub_iterables_even_when_anext_called_in_parallel(): - async def source(): - async def nested1() -> AsyncGenerator[float, None]: - yield 1.1 - yield 1.2 - - async def nested2() -> AsyncGenerator[float, None]: - yield 2.1 - yield 2.2 - - yield nested1() - yield nested2() - - doubles = flatten_async_iterable(source()) - - anext1 = anext(doubles) - anext2 = anext(doubles) - assert await anext1 == 1.1 - assert await anext2 == 1.2 - assert await anext(doubles) == 2.1 - assert await anext(doubles) == 2.2 - with pytest.raises(StopAsyncIteration): - assert await anext(doubles) - - @pytest.mark.asyncio() - async def closes_nested_async_iterators(): - closed = [] - - class Source: - def __init__(self): - self.counter = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - if self.counter == 2: - raise StopAsyncIteration - self.counter += 1 - return Nested(self.counter) - - async def aclose(self): - nonlocal closed - closed.append(self.counter) - - class Nested: - def __init__(self, value): - self.value = value - self.counter = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - if self.counter == 2: - raise StopAsyncIteration - self.counter += 1 - return self.value + self.counter / 10 - - async def aclose(self): - nonlocal closed - closed.append(self.value + self.counter / 10) - - doubles = flatten_async_iterable(Source()) - - result = [x async for x in doubles] - - assert result == [1.1, 1.2, 2.1, 2.2] - - assert closed == [1.2, 2.2, 2] - - @pytest.mark.asyncio() - async def works_with_nested_async_iterators_that_have_no_close_method(): - class Source: - def __init__(self): - self.counter = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - if self.counter == 2: - raise StopAsyncIteration - self.counter += 1 - return Nested(self.counter) - - class Nested: - def __init__(self, value): - self.value = value - self.counter = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - if self.counter == 2: - raise StopAsyncIteration - self.counter += 1 - return self.value + self.counter / 10 - - doubles = flatten_async_iterable(Source()) - - result = [x async for x in doubles] - - assert result == [1.1, 1.2, 2.1, 2.2] diff --git a/tests/validation/test_defer_stream_directive_on_valid_operations.py b/tests/validation/test_defer_stream_directive_on_valid_operations.py new file mode 100644 index 00000000..7d33fd2b --- /dev/null +++ b/tests/validation/test_defer_stream_directive_on_valid_operations.py @@ -0,0 +1,395 @@ +from functools import partial + +from graphql.utilities import build_schema +from graphql.validation import DeferStreamDirectiveOnValidOperationsRule + +from .harness import assert_validation_errors + +schema = build_schema( + """ + type Message { + body: String + sender: String + } + + type SubscriptionRoot { + subscriptionField: Message + subscriptionListField: [Message] + } + + type MutationRoot { + mutationField: Message + mutationListField: [Message] + } + + type QueryRoot { + message: Message + messages: [Message] + } + + schema { + query: QueryRoot + mutation: MutationRoot + subscription: SubscriptionRoot + } + """ +) + +assert_errors = partial( + assert_validation_errors, DeferStreamDirectiveOnValidOperationsRule, schema=schema +) + +assert_valid = partial(assert_errors, errors=[]) + + +def describe_defer_stream_directive_on_valid_operations(): + def defer_fragment_spread_nested_in_query_operation(): + assert_valid( + """ + { + message { + ...myFragment @defer + } + } + fragment myFragment on Message { + message { + body + } + } + """ + ) + + def defer_inline_fragment_spread_in_query_operation(): + assert_valid( + """ + { + ... @defer { + message { + body + } + } + } + """ + ) + + def defer_fragment_spread_on_mutation_field(): + assert_valid( + """ + mutation { + mutationField { + ...myFragment @defer + } + } + fragment myFragment on Message { + body + } + """ + ) + + def defer_inline_fragment_spread_on_mutation_field(): + assert_valid( + """ + mutation { + mutationField { + ... @defer { + body + } + } + } + """ + ) + + def defer_fragment_spread_on_subscription_field(): + assert_errors( + """ + subscription { + subscriptionField { + ...myFragment @defer + } + } + fragment myFragment on Message { + body + } + """, + [ + { + "message": "Defer directive not supported" + " on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`.", + "locations": [(4, 31)], + }, + ], + ) + + def defer_fragment_spread_with_boolean_true_if_argument(): + assert_errors( + """ + subscription { + subscriptionField { + ...myFragment @defer(if: true) + } + } + fragment myFragment on Message { + body + } + """, + [ + { + "message": "Defer directive not supported" + " on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`.", + "locations": [(4, 31)], + }, + ], + ) + + def defer_fragment_spread_with_boolean_false_if_argument(): + assert_valid( + """ + subscription { + subscriptionField { + ...myFragment @defer(if: false) + } + } + fragment myFragment on Message { + body + } + """ + ) + + def defer_fragment_spread_on_query_in_multi_operation_document(): + assert_valid( + """ + subscription MySubscription { + subscriptionField { + ...myFragment + } + } + query MyQuery { + message { + ...myFragment @defer + } + } + fragment myFragment on Message { + body + } + """ + ) + + def defer_fragment_spread_on_subscription_in_multi_operation_document(): + assert_errors( + """ + subscription MySubscription { + subscriptionField { + ...myFragment @defer + } + } + query MyQuery { + message { + ...myFragment @defer + } + } + fragment myFragment on Message { + body + } + """, + [ + { + "message": "Defer directive not supported" + " on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`.", + "locations": [(4, 31)], + }, + ], + ) + + def defer_fragment_spread_with_invalid_if_argument(): + assert_errors( + """ + subscription MySubscription { + subscriptionField { + ...myFragment @defer(if: "Oops") + } + } + fragment myFragment on Message { + body + } + """, + [ + { + "message": "Defer directive not supported" + " on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`.", + "locations": [(4, 31)], + }, + ], + ) + + def stream_on_query_field(): + assert_valid( + """ + { + messages @stream { + name + } + } + """ + ) + + def stream_on_mutation_field(): + assert_valid( + """ + mutation { + mutationField { + messages @stream + } + } + """ + ) + + def stream_on_fragment_on_mutation_field(): + assert_valid( + """ + mutation { + mutationField { + ...myFragment + } + } + fragment myFragment on Message { + messages @stream + } + """ + ) + + def stream_on_subscription_field(): + assert_errors( + """ + subscription { + subscriptionField { + messages @stream + } + } + """, + [ + { + "message": "Stream directive not supported" + " on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`.", + "locations": [(4, 26)], + }, + ], + ) + + def stream_on_fragment_on_subscription_field(): + assert_errors( + """ + subscription { + subscriptionField { + ...myFragment + } + } + fragment myFragment on Message { + messages @stream + } + """, + [ + { + "message": "Stream directive not supported" + " on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`.", + "locations": [(8, 24)], + }, + ], + ) + + def stream_on_fragment_on_query_in_multi_operation_document(): + assert_valid( + """ + subscription MySubscription { + subscriptionField { + message + } + } + query MyQuery { + message { + ...myFragment + } + } + fragment myFragment on Message { + messages @stream + } + """ + ) + + def stream_on_subscription_in_multi_operation_document(): + assert_errors( + """ + query MyQuery { + message { + ...myFragment + } + } + subscription MySubscription { + subscriptionField { + message { + ...myFragment + } + } + } + fragment myFragment on Message { + messages @stream + } + """, + [ + { + "message": "Stream directive not supported" + " on subscription operations." + " Disable `@defer` by setting the `if` argument to `false`.", + "locations": [(15, 24)], + }, + ], + ) + + def stream_with_boolean_false_if_argument(): + assert_valid( + """ + subscription { + subscriptionField { + ...myFragment @stream(if:false) + } + } + """ + ) + + def stream_with_two_arguments(): + assert_valid( + """ + subscription { + subscriptionField { + ...myFragment @stream(foo:false,if:false) + } + } + """ + ) + + def stream_with_variable_argument(): + assert_valid( + """ + subscription ($stream: boolean!) { + subscriptionField { + ...myFragment @stream(if:$stream) + } + } + """ + ) + + def other_directive_on_subscription_field(): + assert_valid( + """ + subscription { + subscriptionField { + ...myFragment @foo + } + } + """ + )