diff --git a/CHANGELOG.md b/CHANGELOG.md index 60d357919..9f648eea2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Fixed +- Fixes empty request body on retry with compression enabled ([#543](https://github.com/opensearch-project/opensearch-go/pull/543)) + ### Security ### Dependencies diff --git a/opensearchtransport/opensearchtransport.go b/opensearchtransport/opensearchtransport.go index ac8ad057e..c93b2cd5c 100644 --- a/opensearchtransport/opensearchtransport.go +++ b/opensearchtransport/opensearchtransport.go @@ -245,7 +245,9 @@ func (c *Client) Perform(req *http.Request) (*http.Response, error) { } req.GetBody = func() (io.ReadCloser, error) { - return io.NopCloser(buf), nil + // We have to return a new reader each time so that retries don't read from an already-consumed body. + reader := bytes.NewReader(buf.Bytes()) + return io.NopCloser(reader), nil } //nolint:errcheck // error is always nil req.Body, _ = req.GetBody() @@ -258,8 +260,9 @@ func (c *Client) Perform(req *http.Request) (*http.Response, error) { //nolint:errcheck // ignored as this is only for logging buf.ReadFrom(req.Body) req.GetBody = func() (io.ReadCloser, error) { - r := buf - return io.NopCloser(&r), nil + // Return a new reader each time + reader := bytes.NewReader(buf.Bytes()) + return io.NopCloser(reader), nil } //nolint:errcheck // error is always nil req.Body, _ = req.GetBody() diff --git a/opensearchtransport/opensearchtransport_internal_test.go b/opensearchtransport/opensearchtransport_internal_test.go index fbe1f26a3..6daaeba06 100644 --- a/opensearchtransport/opensearchtransport_internal_test.go +++ b/opensearchtransport/opensearchtransport_internal_test.go @@ -689,6 +689,47 @@ func TestTransportPerformRetries(t *testing.T) { } }) + t.Run("Reset request body during retry with request body compression", func(t *testing.T) { + var bodies []string + u, _ := url.Parse("https://foo.com/bar") + tp, _ := New( + Config{ + URLs: []*url.URL{u}, + CompressRequestBody: true, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + body, err := io.ReadAll(req.Body) + if err != nil { + panic(err) + } + bodies = append(bodies, string(body)) + return &http.Response{Status: "MOCK", StatusCode: http.StatusBadGateway}, nil + }, + }, + }, + ) + + foobar := "FOOBAR" + foobarGzipped := "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\xffr\xf3\xf7wr\f\x02\x04\x00\x00\xff\xff\x13\xd8\x0en\x06\x00\x00\x00" + + req, _ := http.NewRequest(http.MethodPost, "/abc", strings.NewReader(foobar)) + //nolint:bodyclose // Mock response does not have a body to close + res, err := tp.Perform(req) + if err != nil { + t.Fatalf("Unexpected error: %s", err) + } + _ = res + + if n := len(bodies); n != 4 { + t.Fatalf("expected 4 requests, got %d", n) + } + for i, body := range bodies { + if body != foobarGzipped { + t.Fatalf("request %d body: expected %q, got %q", i, foobarGzipped, body) + } + } + }) + t.Run("Don't retry request on regular error", func(t *testing.T) { var i int