Skip to content

Commit

Permalink
fix: handle timeout and unexpected eof error by returning 408 or 499 (#…
Browse files Browse the repository at this point in the history
…64)

Co-authored-by: John Levey <john.levey@grafana.com>
  • Loading branch information
mar4uk and leveyjam authored Jul 28, 2023
1 parent 6b2a4f9 commit 5807bcd
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 4 deletions.
36 changes: 32 additions & 4 deletions pkg/server/middleware/request_limits.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@ package middleware

import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"

"github.com/pkg/errors"

"github.com/go-kit/log/level"
"github.com/grafana/mimir/pkg/util/spanlogger"
)

const (
StatusClientClosedRequest = 499
)

type RequestLimits struct {
maxRequestBodySize int64
}
Expand All @@ -28,13 +36,23 @@ func (l RequestLimits) Wrap(next http.Handler) http.Handler {
reader := io.LimitReader(r.Body, int64(l.maxRequestBodySize)+1)
body, err := io.ReadAll(reader)
if err != nil {
level.Warn(log).Log("msg", "failed to read request body", "err", err)
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusInternalServerError)
return
_ = level.Warn(log).Log("msg", "failed to read request body", "err", err)

switch {
case isNetworkError(err):
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusRequestTimeout)
return
case errors.Is(err, context.Canceled) || errors.Is(err, io.ErrUnexpectedEOF):
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), StatusClientClosedRequest)
return
default:
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusInternalServerError)
return
}
}
if int64(len(body)) > l.maxRequestBodySize {
msg := fmt.Sprintf("trying to send message larger than max (%d vs %d)", len(body), l.maxRequestBodySize)
level.Warn(log).Log("msg", msg)
_ = level.Warn(log).Log("msg", msg)
http.Error(w, msg, http.StatusRequestEntityTooLarge)
return
}
Expand All @@ -44,3 +62,13 @@ func (l RequestLimits) Wrap(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}

// isNetworkError determines if an error is caused by a network timeout
func isNetworkError(err error) bool {
if err == nil {
return false
}

netErr, ok := errors.Cause(err).(net.Error)
return ok && netErr.Timeout()
}
71 changes: 71 additions & 0 deletions pkg/server/middleware/request_limits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ package middleware

import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

func TestRequestLimitsMiddleware(t *testing.T) {
Expand Down Expand Up @@ -60,3 +63,71 @@ func TestRequestLimitsMiddleware(t *testing.T) {
})
}
}

type errReader struct {
mock.Mock
}

func (m errReader) Read(p []byte) (n int, err error) {
args := m.Called(p)
return args.Int(0), args.Error(1)
}

type TimeoutError struct {
error
}

func (e TimeoutError) Timeout() bool {
return true
}

func (e TimeoutError) Temporary() bool {
return true
}

func (e TimeoutError) Error() string {
return ""
}

func TestRequestLimitsMiddlewareReadError(t *testing.T) {
for _, tc := range []struct {
name string
readerErr error
expectedStatus int
}{
{
name: "in case of unexpected EOF should return 499",
readerErr: io.ErrUnexpectedEOF,
expectedStatus: StatusClientClosedRequest,
},
{
name: "in case of timeout error should return 408",
readerErr: new(TimeoutError),
expectedStatus: http.StatusRequestTimeout,
},
{
name: "in case other errors should return 500",
readerErr: errors.New("other error"),
expectedStatus: http.StatusInternalServerError,
},
} {
middleware := NewRequestLimitsMiddleware(1 * mb)
handler := middleware.Wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

reader := new(errReader)
reader.Mock.On("Read", mock.Anything).Return(0, tc.readerErr)

req := httptest.NewRequest(
"GET",
"https://example.com",
reader,
)
resp := httptest.NewRecorder()

handler.ServeHTTP(resp, req)

assert.Equal(t, tc.expectedStatus, resp.Code)
}
}

0 comments on commit 5807bcd

Please sign in to comment.