Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for generic unions #3515

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Attempt to merge union types during schema conversion.
21 changes: 15 additions & 6 deletions strawberry/schema/name_converter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Optional, Union, cast
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
from typing_extensions import Protocol

from strawberry.directive import StrawberryDirective
Expand Down Expand Up @@ -107,8 +107,14 @@ def from_union(self, union: StrawberryUnion) -> str:
return union.graphql_name

name = ""
types: Tuple[StrawberryType, ...] = union.types

for type_ in union.types:
if union.concrete_of and union.concrete_of.graphql_name:
concrete_of_types = set(union.concrete_of.types)

types = tuple(type_ for type_ in types if type_ not in concrete_of_types)

for type_ in types:
if isinstance(type_, LazyType):
type_ = cast("StrawberryType", type_.resolve_type()) # noqa: PLW2901

Expand All @@ -121,6 +127,9 @@ def from_union(self, union: StrawberryUnion) -> str:

name += type_name

if union.concrete_of and union.concrete_of.graphql_name:
name += union.concrete_of.graphql_name

return name

def from_generic(
Expand All @@ -133,12 +142,12 @@ def from_generic(
names: List[str] = []

for type_ in types:
name = self.get_from_type(type_)
name = self.get_name_from_type(type_)
names.append(name)

return "".join(names) + generic_type_name

def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
def get_name_from_type(self, type_: Union[StrawberryType, type]) -> str:
type_ = eval_type(type_)

if isinstance(type_, LazyType):
Expand All @@ -148,9 +157,9 @@ def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
elif isinstance(type_, StrawberryUnion):
name = type_.graphql_name if type_.graphql_name else self.from_union(type_)
elif isinstance(type_, StrawberryList):
name = self.get_from_type(type_.of_type) + "List"
name = self.get_name_from_type(type_.of_type) + "List"
elif isinstance(type_, StrawberryOptional):
name = self.get_from_type(type_.of_type) + "Optional"
name = self.get_name_from_type(type_.of_type) + "Optional"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the name of this function 😊

elif hasattr(type_, "_scalar_definition"):
strawberry_type = type_._scalar_definition

Expand Down
10 changes: 8 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,14 +865,20 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType:
return graphql_union

graphql_types: List[GraphQLObjectType] = []

for type_ in union.types:
graphql_type = self.from_type(type_)

if isinstance(graphql_type, GraphQLInputObjectType):
raise InvalidTypeInputForUnion(graphql_type)
assert isinstance(graphql_type, GraphQLObjectType)
assert isinstance(graphql_type, (GraphQLObjectType, GraphQLUnionType))

graphql_types.append(graphql_type)
# If the graphql_type is a GraphQLUnionType, merge its child types
if isinstance(graphql_type, GraphQLUnionType):
enoua5 marked this conversation as resolved.
Show resolved Hide resolved
# Add the child types of the GraphQLUnionType to the list of graphql_types
graphql_types.extend(graphql_type.types)
else:
graphql_types.append(graphql_type)

graphql_union = GraphQLUnionType(
name=union_name,
Expand Down
7 changes: 6 additions & 1 deletion strawberry/types/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.directives = directives
self._source_file = None
self._source_line = None
self.concrete_of: Optional[StrawberryUnion] = None

def __eq__(self, other: object) -> bool:
if isinstance(other, StrawberryType):
Expand Down Expand Up @@ -139,6 +140,7 @@ def copy_with(
return self

new_types = []

for type_ in self.types:
new_type: Union[StrawberryType, type]

Expand All @@ -154,10 +156,13 @@ def copy_with(

new_types.append(new_type)

return StrawberryUnion(
new_union = StrawberryUnion(
type_annotations=tuple(map(StrawberryAnnotation, new_types)),
description=self.description,
)
new_union.concrete_of = self

return new_union

def __call__(self, *args: str, **kwargs: Any) -> NoReturn:
"""Do not use.
Expand Down
175 changes: 175 additions & 0 deletions tests/schema/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,178 @@ class Query:

assert not result.errors
assert result.data["something"] == {"__typename": "A", "a": 5}


def test_generic_union_with_annotated():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Annotated[Union[T, NotFoundError], strawberry.union("ByIdResult")]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to clarify, this is quite different from the use case below, as this creates a generic union 😊


@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
raise NotImplementedError()

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}

type SomeType {
id: ID!
name: String!
}

union SomeTypeByIdResult = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_generic_union_with_annotated_inside():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Union[T, Annotated[NotFoundError, strawberry.union("ByIdResult")]]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]: ...

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}

type SomeType {
id: ID!
name: String!
}

union SomeTypeByIdResult = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_annoted_union_with_two_generics():
@strawberry.type
class SomeType:
a: str

@strawberry.type
class OtherType:
b: str

@strawberry.type
class NotFoundError:
message: str

T = TypeVar("T")
U = TypeVar("U")

@strawberry.type
class UnionObjectQueries(Generic[T, U]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Union[
T, Annotated[Union[U, NotFoundError], strawberry.union("ByIdResult")]
]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(
self, id: strawberry.ID
) -> UnionObjectQueries[SomeType, OtherType]: ...

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
message: String!
}

type OtherType {
b: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeOtherTypeUnionObjectQueries!
}

type SomeType {
a: String!
}

union SomeTypeOtherTypeByIdResult = SomeType | OtherType | NotFoundError

type SomeTypeOtherTypeUnionObjectQueries {
byId(id: ID!): SomeTypeOtherTypeByIdResult!
}
"""
).strip()
)
Loading