diff --git a/huma.go b/huma.go index 38b6c81b..f3d6c41f 100644 --- a/huma.go +++ b/huma.go @@ -806,6 +806,15 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) }) if len(res.Errors) > 0 { + for i := len(res.Errors) - 1; i >= 0; i-- { + // If there are errors, and they provide a status, then update the + // response status code to match. Otherwise, use the default status + // code is used. Since these run in order, the last error code wins. + if s, ok := res.Errors[i].(StatusError); ok { + errStatus = s.GetStatus() + break + } + } WriteErr(api, ctx, errStatus, "validation failed", res.Errors...) return } diff --git a/huma_test.go b/huma_test.go index 79e9c0c1..83044d90 100644 --- a/huma_test.go +++ b/huma_test.go @@ -725,6 +725,31 @@ func TestNestedResolverWithPath(t *testing.T) { assert.Contains(t, w.Body.String(), `"location":"body.field1.foo[0].field2"`) } +type ResolverCustomStatus struct{} + +func (r *ResolverCustomStatus) Resolve(ctx Context) []error { + return []error{Error403Forbidden("nope")} +} + +func TestResolverCustomStatus(t *testing.T) { + r := chi.NewRouter() + app := NewTestAdapter(r, DefaultConfig("Test API", "1.0.0")) + Register(app, Operation{ + OperationID: "test", + Method: http.MethodPut, + Path: "/test", + }, func(ctx context.Context, input *ResolverCustomStatus) (*struct{}, error) { + return nil, nil + }) + + req, _ := http.NewRequest(http.MethodPut, "/test", strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, http.StatusForbidden, w.Code, w.Body.String()) + assert.Contains(t, w.Body.String(), "nope") +} + func BenchmarkSecondDecode(b *testing.B) { type MediumSized struct { ID int `json:"id"`