From 162941e9e1ddfcc63c782969d3c9c981e2e2e68e Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Mon, 23 Oct 2023 16:48:45 -0700 Subject: [PATCH] fix: backport Huma v2 example pointer fix from #148 --- schema/schema.go | 21 ++++++++++++++++++--- schema/schema_test.go | 3 +++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/schema/schema.go b/schema/schema.go index 2e25ad93..d49a2e36 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -67,11 +67,18 @@ func F(value float64) *float64 { return &value } +func deref(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + // getTagValue returns a value of the schema's type for the given tag string. // Uses JSON parsing if the schema is not a string. func getTagValue(s *Schema, t reflect.Type, value string) (interface{}, error) { // Special case: strings don't need quotes. - if s.Type == TypeString { + if s.Type == TypeString || (t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.String) { return value, nil } @@ -105,11 +112,19 @@ func getTagValue(s *Schema, t reflect.Type, value string) (interface{}, error) { tmp = reflect.Append(tmp, vv.Index(i).Elem().Convert(t.Elem())) } v = tmp.Interface() - } else if !tv.ConvertibleTo(t) { + } else if !tv.ConvertibleTo(deref(t)) { return nil, fmt.Errorf("unable to convert %v to %v: %w", tv, t, ErrSchemaInvalid) } - v = reflect.ValueOf(v).Convert(t).Interface() + converted := reflect.ValueOf(v).Convert(deref(t)) + if t.Kind() == reflect.Ptr { + // Special case: if the field is a pointer, we need to get a pointer + // to the converted value. + tmp := reflect.New(t.Elem()) + tmp.Elem().Set(converted) + converted = tmp + } + v = converted.Interface() } return v, nil diff --git a/schema/schema_test.go b/schema/schema_test.go index 3f7fa167..97243060 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -145,11 +145,14 @@ func TestSchemaDefault(t *testing.T) { func TestSchemaExample(t *testing.T) { type Example struct { Foo string `json:"foo" example:"ex"` + Bar *int64 `json:"bar" example:"5"` } s, err := Generate(reflect.ValueOf(Example{}).Type()) assert.NoError(t, err) assert.Equal(t, "ex", s.Properties["foo"].Example) + ex := int64(5) + assert.Equal(t, &ex, s.Properties["bar"].Example) } func TestSchemaNullable(t *testing.T) {