From 5807bcd690d5ca291d7d3306c90caeeab85f083d Mon Sep 17 00:00:00 2001 From: Irina Date: Fri, 28 Jul 2023 12:36:50 +0100 Subject: [PATCH] fix: handle timeout and unexpected eof error by returning 408 or 499 (#64) Co-authored-by: John Levey --- pkg/server/middleware/request_limits.go | 36 ++++++++-- pkg/server/middleware/request_limits_test.go | 71 ++++++++++++++++++++ 2 files changed, 103 insertions(+), 4 deletions(-) diff --git a/pkg/server/middleware/request_limits.go b/pkg/server/middleware/request_limits.go index 33d6b29..8b6ea92 100644 --- a/pkg/server/middleware/request_limits.go +++ b/pkg/server/middleware/request_limits.go @@ -2,14 +2,22 @@ package middleware import ( "bytes" + "context" "fmt" "io" + "net" "net/http" + "github.com/pkg/errors" + "github.com/go-kit/log/level" "github.com/grafana/mimir/pkg/util/spanlogger" ) +const ( + StatusClientClosedRequest = 499 +) + type RequestLimits struct { maxRequestBodySize int64 } @@ -28,13 +36,23 @@ func (l RequestLimits) Wrap(next http.Handler) http.Handler { reader := io.LimitReader(r.Body, int64(l.maxRequestBodySize)+1) body, err := io.ReadAll(reader) if err != nil { - level.Warn(log).Log("msg", "failed to read request body", "err", err) - http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusInternalServerError) - return + _ = level.Warn(log).Log("msg", "failed to read request body", "err", err) + + switch { + case isNetworkError(err): + http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusRequestTimeout) + return + case errors.Is(err, context.Canceled) || errors.Is(err, io.ErrUnexpectedEOF): + http.Error(w, fmt.Sprintf("failed to read request body: %v", err), StatusClientClosedRequest) + return + default: + http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusInternalServerError) + return + } } if int64(len(body)) > l.maxRequestBodySize { msg := fmt.Sprintf("trying to send message larger than max (%d vs %d)", len(body), l.maxRequestBodySize) - level.Warn(log).Log("msg", msg) + _ = level.Warn(log).Log("msg", msg) http.Error(w, msg, http.StatusRequestEntityTooLarge) return } @@ -44,3 +62,13 @@ func (l RequestLimits) Wrap(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +// isNetworkError determines if an error is caused by a network timeout +func isNetworkError(err error) bool { + if err == nil { + return false + } + + netErr, ok := errors.Cause(err).(net.Error) + return ok && netErr.Timeout() +} diff --git a/pkg/server/middleware/request_limits_test.go b/pkg/server/middleware/request_limits_test.go index 1db4927..d8dbb93 100644 --- a/pkg/server/middleware/request_limits_test.go +++ b/pkg/server/middleware/request_limits_test.go @@ -2,12 +2,15 @@ package middleware import ( "bytes" + "errors" + "io" "net/http" "net/http/httptest" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) func TestRequestLimitsMiddleware(t *testing.T) { @@ -60,3 +63,71 @@ func TestRequestLimitsMiddleware(t *testing.T) { }) } } + +type errReader struct { + mock.Mock +} + +func (m errReader) Read(p []byte) (n int, err error) { + args := m.Called(p) + return args.Int(0), args.Error(1) +} + +type TimeoutError struct { + error +} + +func (e TimeoutError) Timeout() bool { + return true +} + +func (e TimeoutError) Temporary() bool { + return true +} + +func (e TimeoutError) Error() string { + return "" +} + +func TestRequestLimitsMiddlewareReadError(t *testing.T) { + for _, tc := range []struct { + name string + readerErr error + expectedStatus int + }{ + { + name: "in case of unexpected EOF should return 499", + readerErr: io.ErrUnexpectedEOF, + expectedStatus: StatusClientClosedRequest, + }, + { + name: "in case of timeout error should return 408", + readerErr: new(TimeoutError), + expectedStatus: http.StatusRequestTimeout, + }, + { + name: "in case other errors should return 500", + readerErr: errors.New("other error"), + expectedStatus: http.StatusInternalServerError, + }, + } { + middleware := NewRequestLimitsMiddleware(1 * mb) + handler := middleware.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + reader := new(errReader) + reader.Mock.On("Read", mock.Anything).Return(0, tc.readerErr) + + req := httptest.NewRequest( + "GET", + "https://example.com", + reader, + ) + resp := httptest.NewRecorder() + + handler.ServeHTTP(resp, req) + + assert.Equal(t, tc.expectedStatus, resp.Code) + } +}