Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for generic unions #3515

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

Attempt to merge union types during schema conversion.
Comment on lines +1 to +3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with the last change this is probably a minor now, since it's a brand new feature (plus a fix) 😊

9 changes: 7 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,9 +856,14 @@ def from_union(self, union: StrawberryUnion) -> GraphQLUnionType:

if isinstance(graphql_type, GraphQLInputObjectType):
raise InvalidTypeInputForUnion(graphql_type)
assert isinstance(graphql_type, GraphQLObjectType)
assert isinstance(graphql_type, (GraphQLObjectType, GraphQLUnionType))

graphql_types.append(graphql_type)
# If the graphql_type is a GraphQLUnionType, merge its child types
if isinstance(graphql_type, GraphQLUnionType):
enoua5 marked this conversation as resolved.
Show resolved Hide resolved
# Add the child types of the GraphQLUnionType to the list of graphql_types
graphql_types.extend(graphql_type.types)
else:
graphql_types.append(graphql_type)

graphql_union = GraphQLUnionType(
name=union_name,
Expand Down
114 changes: 114 additions & 0 deletions tests/schema/test_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,117 @@

assert not result.errors
assert result.data["something"] == {"__typename": "A", "a": 5}


@pytest.mark.xfail(reason="Not supported yet")
def test_generic_union_with_annotated():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Annotated[Union[T, NotFoundError], strawberry.union("ByIdResult")]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to clarify, this is quite different from the use case below, as this creates a generic union 😊


@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
raise NotImplementedError()

schema = strawberry.Schema(Query)

# TODO: check the name
assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeByIdResult!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type here is still wrong, we probably need to expand the annotated

}

type SomeType {
id: ID!
name: String!
}

union SomeTypeNotFoundError = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)


def test_generic_union_with_annotated_inside():
@strawberry.type
class SomeType:
id: strawberry.ID
name: str

@strawberry.type
class NotFoundError:
id: strawberry.ID
message: str

T = TypeVar("T")

@strawberry.type
class ObjectQueries(Generic[T]):
@strawberry.field
def by_id(
self, id: strawberry.ID
) -> Union[T, Annotated[NotFoundError, strawberry.union("ByIdResult")]]: ...

@strawberry.type
class Query:
@strawberry.field
def some_type_queries(self, id: strawberry.ID) -> ObjectQueries[SomeType]:
return ObjectQueries(SomeType)

Check warning on line 943 in tests/schema/test_union.py

View check run for this annotation

Codecov / codecov/patch

tests/schema/test_union.py#L943

Added line #L943 was not covered by tests

schema = strawberry.Schema(Query)

assert (
str(schema)
== textwrap.dedent(
"""
type NotFoundError {
id: ID!
message: String!
}

type Query {
someTypeQueries(id: ID!): SomeTypeObjectQueries!
}

type SomeType {
id: ID!
name: String!
}

union SomeTypeByIdResult = SomeType | NotFoundError

type SomeTypeObjectQueries {
byId(id: ID!): SomeTypeByIdResult!
}
"""
).strip()
)
Loading