From 2cff828f785d53d5f79d15b4c813335047c238de Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Thu, 19 Sep 2024 08:34:06 -0700 Subject: [PATCH] fix: use status code returned from NewError when writing errors --- error.go | 6 +++++- huma_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/error.go b/error.go index 51f091c5..ce2e91cd 100644 --- a/error.go +++ b/error.go @@ -252,7 +252,11 @@ var NewError = func(status int, msg string, errs ...error) StatusError { // configured error type and with the given status code and message. It is // marshaled using the API's content negotiation methods. func WriteErr(api API, ctx Context, status int, msg string, errs ...error) error { - var err any = NewError(status, msg, errs...) + var err = NewError(status, msg, errs...) + + // NewError may have modified the status code, so update it here if needed. + // If it was not modified then this is a no-op. + status = err.GetStatus() ct, negotiateErr := api.Negotiate(ctx.Header("Accept")) if negotiateErr != nil { diff --git a/huma_test.go b/huma_test.go index cfad795d..435df780 100644 --- a/huma_test.go +++ b/huma_test.go @@ -2298,6 +2298,32 @@ func TestResolverCustomTypePrimitive(t *testing.T) { }) } +func TestCustomValidationErrorStatus(t *testing.T) { + orig := huma.NewError + huma.NewError = func(status int, message string, errs ...error) huma.StatusError { + if status == 422 { + status = 400 + } + return orig(status, message, errs...) + } + t.Cleanup(func() { + huma.NewError = orig + }) + + _, api := humatest.New(t, huma.DefaultConfig("Test API", "1.0.0")) + huma.Post(api, "/test", func(ctx context.Context, input *struct { + Body struct { + Value string `json:"value" minLength:"5"` + } + }) (*struct{}, error) { + return nil, nil + }) + + resp := api.Post("/test", map[string]any{"value": "foo"}) + assert.Equal(t, http.StatusBadRequest, resp.Result().StatusCode) + assert.Contains(t, resp.Body.String(), "Bad Request") +} + // func BenchmarkSecondDecode(b *testing.B) { // //nolint: musttag // type MediumSized struct {