Skip to content

Commit

Permalink
Merge pull request #161 from danielgtaylor/resolver-path-fix
Browse files Browse the repository at this point in the history
fix: better paths for resolvers, prevent subtle mistakes
  • Loading branch information
danielgtaylor authored Nov 1, 2023
2 parents ba0e5c6 + 33f7b64 commit 8ed9f46
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 15 deletions.
2 changes: 2 additions & 0 deletions chain.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package huma

// Middlewares is a list of middleware functions that can be attached to an
// API and will be called for all incoming requests.
type Middlewares []func(ctx Context, next func(Context))

// Handler builds and returns a handler func from the chain of middlewares,
Expand Down
48 changes: 46 additions & 2 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
return nil
}

if f.Type.Kind() == reflect.Pointer {
// TODO: support pointers? The problem is that when we dynamically
// create an instance of the input struct the `params.Every(...)`
// call cannot set them as the value is `reflect.Invalid` unless
// dynamically allocated, but we don't know when to allocate until
// after the `Every` callback has run. Doable, but a bigger change.
panic("pointers are not supported for path/query/header parameters")
}

pfi := &paramFieldInfo{
Type: f.Type,
Schema: SchemaFromField(registry, f, ""),
Expand Down Expand Up @@ -171,6 +180,9 @@ func findResolvers(resolverType, t reflect.Type) *findResult[bool] {
func findDefaults(t reflect.Type) *findResult[any] {
return findInType(t, nil, func(sf reflect.StructField, i []int) any {
if d := sf.Tag.Get("default"); d != "" {
if sf.Type.Kind() == reflect.Pointer {
panic("pointers cannot have default values")
}
return jsonTagValue(sf, sf.Type, d)
}
return nil
Expand Down Expand Up @@ -210,6 +222,12 @@ type findResult[T comparable] struct {
}

func (r *findResult[T]) every(current reflect.Value, path []int, v T, f func(reflect.Value, T)) {
if current.Kind() == reflect.Invalid {
// Indirect from below may have resulted in no value, for example
// an optional field may have been omitted; just ignore it.
return
}

if len(path) == 0 {
f(current, v)
return
Expand Down Expand Up @@ -246,19 +264,45 @@ func jsonName(field reflect.StructField) string {
}

func (r *findResult[T]) everyPB(current reflect.Value, path []int, pb *PathBuffer, v T, f func(reflect.Value, T)) {
if current.Kind() == reflect.Invalid {
// Indirect from below may have resulted in no value, for example
// an optional field may have been omitted; just ignore it.
return
}
switch current.Kind() {
case reflect.Struct:
if len(path) == 0 {
f(current, v)
return
}
field := current.Type().Field(path[0])
pops := 0
if !field.Anonymous {
// The path name can come from one of four places: path parameter,
// query parameter, header parameter, or body field.
// TODO: pre-compute type/field names? Could save a few allocations.
pb.Push(jsonName(field))
pops++
if path := field.Tag.Get("path"); path != "" && pb.Len() == 0 {
pb.Push("path")
pb.Push(path)
pops++
} else if query := field.Tag.Get("query"); query != "" && pb.Len() == 0 {
pb.Push("query")
pb.Push(query)
pops++
} else if header := field.Tag.Get("header"); header != "" && pb.Len() == 0 {
pb.Push("header")
pb.Push(header)
pops++
} else {
// The body is _always_ in a field called "Body", which turns into
// `body` in the path buffer, so we don't need to push it separately
// like the the params fields above.
pb.Push(jsonName(field))
}
}
r.everyPB(reflect.Indirect(current.Field(path[0])), path[1:], pb, v, f)
if !field.Anonymous {
for i := 0; i < pops; i++ {
pb.Pop()
}
case reflect.Slice:
Expand Down
101 changes: 88 additions & 13 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,18 +612,39 @@ func TestOpenAPI(t *testing.T) {
}
}

type IntNot3 int

func (i IntNot3) Resolve(ctx huma.Context, prefix *huma.PathBuffer) []error {
if i != 0 && i%3 == 0 {
return []error{&huma.ErrorDetail{
Location: prefix.String(),
Message: "Value cannot be a multiple of three",
Value: i,
}}
}
return nil
}

var _ huma.ResolverWithPath = (*IntNot3)(nil)

type ExhaustiveErrorsInputBody struct {
Name string `json:"name" maxLength:"10"`
Count int `json:"count" minimum:"1"`
Name string `json:"name" maxLength:"10"`
Count IntNot3 `json:"count" minimum:"1"`

// Having a pointer which is never loaded should not cause
// the tests to fail when running resolvers.
Ptr *IntNot3 `json:"ptr,omitempty" minimum:"1"`
}

func (b *ExhaustiveErrorsInputBody) Resolve(ctx huma.Context) []error {
return []error{fmt.Errorf("body resolver error")}
}

type ExhaustiveErrorsInput struct {
ID string `path:"id" maxLength:"5"`
Body ExhaustiveErrorsInputBody `json:"body"`
ID IntNot3 `path:"id" maximum:"10"`
Query IntNot3 `query:"query"`
Header IntNot3 `header:"header"`
Body ExhaustiveErrorsInputBody `json:"body"`
}

func (i *ExhaustiveErrorsInput) Resolve(ctx huma.Context) []error {
Expand All @@ -634,21 +655,21 @@ func (i *ExhaustiveErrorsInput) Resolve(ctx huma.Context) []error {
}}
}

type ExhaustiveErrorsOutput struct {
}
var _ huma.Resolver = (*ExhaustiveErrorsInput)(nil)

func TestExhaustiveErrors(t *testing.T) {
r, app := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0"))
huma.Register(app, huma.Operation{
OperationID: "test",
Method: http.MethodPut,
Path: "/errors/{id}",
}, func(ctx context.Context, input *ExhaustiveErrorsInput) (*ExhaustiveErrorsOutput, error) {
return &ExhaustiveErrorsOutput{}, nil
}, func(ctx context.Context, input *ExhaustiveErrorsInput) (*struct{}, error) {
return nil, nil
})

req, _ := http.NewRequest(http.MethodPut, "/errors/123456", strings.NewReader(`{"name": "12345678901", "count": 0}`))
req, _ := http.NewRequest(http.MethodPut, "/errors/15?query=3", strings.NewReader(`{"name": "12345678901", "count": -6}`))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Header", "3")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
assert.Equal(t, http.StatusUnprocessableEntity, w.Code)
Expand All @@ -659,23 +680,39 @@ func TestExhaustiveErrors(t *testing.T) {
"detail": "validation failed",
"errors": [
{
"message": "expected length <= 5",
"message": "expected number <= 10",
"location": "path.id",
"value": "123456"
"value": 15
}, {
"message": "expected length <= 10",
"location": "body.name",
"value": "12345678901"
}, {
"message": "expected number >= 1",
"location": "body.count",
"value": 0
"value": -6
}, {
"message": "input resolver error",
"location": "path.id",
"value": "123456"
"value": 15
}, {
"message": "Value cannot be a multiple of three",
"location": "path.id",
"value": 15
}, {
"message": "Value cannot be a multiple of three",
"location": "query.query",
"value": 3
}, {
"message": "Value cannot be a multiple of three",
"location": "header.header",
"value": 3
}, {
"message": "body resolver error"
}, {
"message": "Value cannot be a multiple of three",
"location": "body.count",
"value": -6
}
]
}`, w.Body.String())
Expand Down Expand Up @@ -745,6 +782,44 @@ func TestResolverCustomStatus(t *testing.T) {
assert.Contains(t, w.Body.String(), "nope")
}

func TestParamPointerPanics(t *testing.T) {
// For now we don't support these, so we panic rather than have subtle
// bugs that are hard to track down.
_, app := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0"))

assert.Panics(t, func() {
huma.Register(app, huma.Operation{
OperationID: "bug",
Method: http.MethodGet,
Path: "/bug",
}, func(ctx context.Context, input *struct {
Param *string `query:"param"`
}) (*struct{}, error) {
return nil, nil
})
})
}

func TestPointerDefaultPanics(t *testing.T) {
// For now we don't support these, so we panic rather than have subtle
// bugs that are hard to track down.
_, app := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0"))

assert.Panics(t, func() {
huma.Register(app, huma.Operation{
OperationID: "bug",
Method: http.MethodGet,
Path: "/bug",
}, func(ctx context.Context, input *struct {
Body struct {
Value *string `json:"value,omitempty" default:"foo"`
}
}) (*struct{}, error) {
return nil, nil
})
})
}

func BenchmarkSecondDecode(b *testing.B) {
type MediumSized struct {
ID int `json:"id"`
Expand Down

0 comments on commit 8ed9f46

Please sign in to comment.