Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for generic unions #3515

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Attempt to merge union types during schema conversion.
Comment on lines +1 to +3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the last change this is probably a minor now, since it's a brand new feature (plus a fix) 😊

19 changes: 14 additions & 5 deletions strawberry/schema/name_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,14 @@ def from_union(self, union: StrawberryUnion) -> str:
return union.graphql_name

name = ""
types = union.types

for type_ in union.types:
if union.concrete_of and union.concrete_of.graphql_name:
concrete_of_types = set(union.concrete_of.types)

types = [type_ for type_ in types if type_ not in concrete_of_types]

for type_ in types:
if isinstance(type_, LazyType):
type_ = cast("StrawberryType", type_.resolve_type()) # noqa: PLW2901

Expand All @@ -121,6 +127,9 @@ def from_union(self, union: StrawberryUnion) -> str:

name += type_name

if union.concrete_of and union.concrete_of.graphql_name:
name += union.concrete_of.graphql_name

return name

def from_generic(
Expand All @@ -133,12 +142,12 @@ def from_generic(
names: List[str] = []

for type_ in types:
name = self.get_from_type(type_)
name = self.get_name_from_type(type_)
names.append(name)

return "".join(names) + generic_type_name

def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
def get_name_from_type(self, type_: Union[StrawberryType, type]) -> str:
type_ = eval_type(type_)

if isinstance(type_, LazyType):
Expand All @@ -148,9 +157,9 @@ def get_from_type(self, type_: Union[StrawberryType, type]) -> str:
elif isinstance(type_, StrawberryUnion):
name = type_.graphql_name if type_.graphql_name else self.from_union(type_)
elif isinstance(type_, StrawberryList):
name = self.get_from_type(type_.of_type) + "List"
name = self.get_name_from_type(type_.of_type) + "List"
elif isinstance(type_, StrawberryOptional):
name = self.get_from_type(type_.of_type) + "Optional"
name = self.get_name_from_type(type_.of_type) + "Optional"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed the name of this function 😊

elif hasattr(type_, "_scalar_definition"):
strawberry_type = type_._scalar_definition

Expand Down
10 changes: 8 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,14 +851,20 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType:
return graphql_union

graphql_types: List[GraphQLObjectType] = []

for type_ in union.types:
graphql_type = self.from_type(type_)

if isinstance(graphql_type, GraphQLInputObjectType):
raise InvalidTypeInputForUnion(graphql_type)
assert isinstance(graphql_type, GraphQLObjectType)
assert isinstance(graphql_type, (GraphQLObjectType, GraphQLUnionType))

graphql_types.append(graphql_type)
# If the graphql_type is a GraphQLUnionType, merge its child types
if isinstance(graphql_type, GraphQLUnionType):
enoua5 marked this conversation as resolved.
Show resolved Hide resolved
# Add the child types of the GraphQLUnionType to the list of graphql_types
graphql_types.extend(graphql_type.types)
else:
graphql_types.append(graphql_type)

graphql_union = GraphQLUnionType(
name=union_name,
Expand Down
7 changes: 6 additions & 1 deletion strawberry/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(
self.directives = directives
self._source_file = None
self._source_line = None
self.concrete_of: Optional[StrawberryUnion] = None

def __eq__(self, other: object) -> bool:
if isinstance(other, StrawberryType):
Expand Down Expand Up @@ -139,6 +140,7 @@ def copy_with(
return self

new_types = []

for type_ in self.types:
new_type: Union[StrawberryType, type]

Expand All @@ -154,10 +156,13 @@ def copy_with(

new_types.append(new_type)

return StrawberryUnion(
new_union = StrawberryUnion(
type_annotations=tuple(map(StrawberryAnnotation, new_types)),
description=self.description,
)
new_union.concrete_of = self

return new_union

def __call__(self, *args: str, **kwargs: Any) -> NoReturn:
"""Do not use.
Expand Down
112 changes: 112 additions & 0 deletions tests/schema/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,115 @@

assert not result.errors
assert result.data["something"] == {"__typename": "A", "a": 5}


def test_generic_union_with_annotated():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Annotated[Union[T, NotFoundError], strawberry.union("ByIdResult")]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to clarify, this is quite different from the use case below, as this creates a generic union 😊


@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
raise NotImplementedError()

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}

type SomeType {
id: ID!
name: String!
}

union SomeTypeByIdResult = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_generic_union_with_annotated_inside():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Union[T, Annotated[NotFoundError, strawberry.union("ByIdResult")]]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
return ObjectQueries(SomeType)

Check warning on line 941 in tests/schema/test_union.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/test_union.py#L941

Added line #L941 was not covered by tests

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}

type SomeType {
id: ID!
name: String!
}

union SomeTypeByIdResult = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)
Loading