Skip to content

Commit

Permalink
fix: deserialization of struct unions that implement json.Unmarshaler (
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] authored and stainless-bot committed Aug 9, 2024
1 parent 00f9455 commit 7c0847a
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
4 changes: 3 additions & 1 deletion internal/apijson/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,9 @@ func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc {
return unmarshalerDecoder
}
if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) {
return indirectUnmarshalerDecoder
if _, ok := unionVariants[t]; !ok {
return indirectUnmarshalerDecoder
}
}
d.root = false

Expand Down
62 changes: 62 additions & 0 deletions internal/apijson/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,59 @@ func init() {
)
}

type MarshallingUnionStruct struct {
Union MarshallingUnion
}

func (r *MarshallingUnionStruct) UnmarshalJSON(data []byte) (err error) {
*r = MarshallingUnionStruct{}
err = UnmarshalRoot(data, &r.Union)
return
}

func (r MarshallingUnionStruct) MarshalJSON() (data []byte, err error) {
return MarshalRoot(r.Union)
}

type MarshallingUnion interface {
marshallingUnion()
}

type MarshallingUnionA struct {
Boo string `json:"boo"`
}

func (MarshallingUnionA) marshallingUnion() {}

func (r *MarshallingUnionA) UnmarshalJSON(data []byte) (err error) {
return UnmarshalRoot(data, r)
}

type MarshallingUnionB struct {
Foo string `json:"foo"`
}

func (MarshallingUnionB) marshallingUnion() {}

func (r *MarshallingUnionB) UnmarshalJSON(data []byte) (err error) {
return UnmarshalRoot(data, r)
}

func init() {
RegisterUnion(
reflect.TypeOf((*MarshallingUnion)(nil)).Elem(),
"",
UnionVariant{
TypeFilter: gjson.JSON,
Type: reflect.TypeOf(MarshallingUnionA{}),
},
UnionVariant{
TypeFilter: gjson.JSON,
Type: reflect.TypeOf(MarshallingUnionB{}),
},
)
}

var tests = map[string]struct {
buf string
val interface{}
Expand Down Expand Up @@ -489,6 +542,15 @@ var tests = map[string]struct {
ComplexUnionStruct{Union: ComplexUnionTypeB{Baz: 12, Type: TypeB("b")}},
},

"marshalling_union_a": {
`{"boo":"hello"}`,
MarshallingUnionStruct{Union: MarshallingUnionA{Boo: "hello"}},
},
"marshalling_union_b": {
`{"foo":"hi"}`,
MarshallingUnionStruct{Union: MarshallingUnionB{Foo: "hi"}},
},

"unmarshal": {
`{"foo":"hello"}`,
&UnmarshalStruct{Foo: "hello", prop: true},
Expand Down
4 changes: 4 additions & 0 deletions internal/apijson/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type UnionVariant struct {
}

var unionRegistry = map[reflect.Type]unionEntry{}
var unionVariants = map[reflect.Type]interface{}{}

type unionEntry struct {
discriminatorKey string
Expand All @@ -24,4 +25,7 @@ func RegisterUnion(typ reflect.Type, discriminator string, variants ...UnionVari
discriminatorKey: discriminator,
variants: variants,
}
for _, variant := range variants {
unionVariants[variant.Type] = typ
}
}

0 comments on commit 7c0847a

Please sign in to comment.