From 7c0847aa2ae15b4442ab0625d8a780ed684c275e Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Fri, 9 Aug 2024 20:00:01 +0000 Subject: [PATCH] fix: deserialization of struct unions that implement json.Unmarshaler (#11) --- internal/apijson/decoder.go | 4 ++- internal/apijson/json_test.go | 62 +++++++++++++++++++++++++++++++++++ internal/apijson/registry.go | 4 +++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/internal/apijson/decoder.go b/internal/apijson/decoder.go index e1b21b7..68b7ed6 100644 --- a/internal/apijson/decoder.go +++ b/internal/apijson/decoder.go @@ -143,7 +143,9 @@ func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc { return unmarshalerDecoder } if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { - return indirectUnmarshalerDecoder + if _, ok := unionVariants[t]; !ok { + return indirectUnmarshalerDecoder + } } d.root = false diff --git a/internal/apijson/json_test.go b/internal/apijson/json_test.go index 72bc4c2..85cd2b5 100644 --- a/internal/apijson/json_test.go +++ b/internal/apijson/json_test.go @@ -261,6 +261,59 @@ func init() { ) } +type MarshallingUnionStruct struct { + Union MarshallingUnion +} + +func (r *MarshallingUnionStruct) UnmarshalJSON(data []byte) (err error) { + *r = MarshallingUnionStruct{} + err = UnmarshalRoot(data, &r.Union) + return +} + +func (r MarshallingUnionStruct) MarshalJSON() (data []byte, err error) { + return MarshalRoot(r.Union) +} + +type MarshallingUnion interface { + marshallingUnion() +} + +type MarshallingUnionA struct { + Boo string `json:"boo"` +} + +func (MarshallingUnionA) marshallingUnion() {} + +func (r *MarshallingUnionA) UnmarshalJSON(data []byte) (err error) { + return UnmarshalRoot(data, r) +} + +type MarshallingUnionB struct { + Foo string `json:"foo"` +} + +func (MarshallingUnionB) marshallingUnion() {} + +func (r *MarshallingUnionB) UnmarshalJSON(data []byte) (err error) { + return UnmarshalRoot(data, r) +} + +func init() { + RegisterUnion( + reflect.TypeOf((*MarshallingUnion)(nil)).Elem(), + "", + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(MarshallingUnionA{}), + }, + UnionVariant{ + TypeFilter: gjson.JSON, + Type: reflect.TypeOf(MarshallingUnionB{}), + }, + ) +} + var tests = map[string]struct { buf string val interface{} @@ -489,6 +542,15 @@ var tests = map[string]struct { ComplexUnionStruct{Union: ComplexUnionTypeB{Baz: 12, Type: TypeB("b")}}, }, + "marshalling_union_a": { + `{"boo":"hello"}`, + MarshallingUnionStruct{Union: MarshallingUnionA{Boo: "hello"}}, + }, + "marshalling_union_b": { + `{"foo":"hi"}`, + MarshallingUnionStruct{Union: MarshallingUnionB{Foo: "hi"}}, + }, + "unmarshal": { `{"foo":"hello"}`, &UnmarshalStruct{Foo: "hello", prop: true}, diff --git a/internal/apijson/registry.go b/internal/apijson/registry.go index fcc518b..2ea00ae 100644 --- a/internal/apijson/registry.go +++ b/internal/apijson/registry.go @@ -13,6 +13,7 @@ type UnionVariant struct { } var unionRegistry = map[reflect.Type]unionEntry{} +var unionVariants = map[reflect.Type]interface{}{} type unionEntry struct { discriminatorKey string @@ -24,4 +25,7 @@ func RegisterUnion(typ reflect.Type, discriminator string, variants ...UnionVari discriminatorKey: discriminator, variants: variants, } + for _, variant := range variants { + unionVariants[variant.Type] = typ + } }