From 7eecc892abf2f2a3dd0101cb2419fcc3b9538aaf Mon Sep 17 00:00:00 2001 From: zhyu <angellwings@gmail.com> Date: Wed, 8 May 2024 07:25:54 +0000 Subject: [PATCH] fix request context cancellation is ignored when retryBackoff (#539) Signed-off-by: zhyu <angellwings@gmail.com> --- CHANGELOG.md | 2 ++ opensearchtransport/opensearchtransport.go | 13 ++++++- .../opensearchtransport_internal_test.go | 34 +++++++++++++++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2176c81bf..4d2869fb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### Fixed +- Fixes opensearchtransport ignores request context cancellation when `retryBackoff` is configured ([#540](https://github.com/opensearch-project/opensearch-go/pull/540)) + ### Security ### Dependencies diff --git a/opensearchtransport/opensearchtransport.go b/opensearchtransport/opensearchtransport.go index ac8ad057e..ad5f0a21a 100644 --- a/opensearchtransport/opensearchtransport.go +++ b/opensearchtransport/opensearchtransport.go @@ -380,7 +380,18 @@ func (c *Client) Perform(req *http.Request) (*http.Response, error) { // Delay the retry if a backoff function is configured if c.retryBackoff != nil { - time.Sleep(c.retryBackoff(i + 1)) + var cancelled bool + timer := time.NewTimer(c.retryBackoff(i + 1)) + select { + case <-req.Context().Done(): + timer.Stop() + err = req.Context().Err() + cancelled = true + case <-timer.C: + } + if cancelled { + break + } } } // Read, close and replace the http response body to close the connection diff --git a/opensearchtransport/opensearchtransport_internal_test.go b/opensearchtransport/opensearchtransport_internal_test.go index fbe1f26a3..3ecbd5da8 100644 --- a/opensearchtransport/opensearchtransport_internal_test.go +++ b/opensearchtransport/opensearchtransport_internal_test.go @@ -31,6 +31,8 @@ package opensearchtransport import ( "bytes" "compress/gzip" + "context" + "errors" "fmt" "io" "math/rand" @@ -807,6 +809,38 @@ func TestTransportPerformRetries(t *testing.T) { t.Errorf("Unexpected duration, want=>%s, got=%s", expectedDuration, end) } }) + + t.Run("Delay the retry with retry on timeout and context deadline", func(t *testing.T) { + var i int + u, _ := url.Parse("http://foo.bar") + tp, _ := New(Config{ + EnableRetryOnTimeout: true, + MaxRetries: 100, + RetryBackoff: func(i int) time.Duration { return time.Hour }, + URLs: []*url.URL{u}, + Transport: &mockTransp{ + RoundTripFunc: func(req *http.Request) (*http.Response, error) { + i++ + <-req.Context().Done() + return nil, req.Context().Err() + }, + }, + }) + + req, _ := http.NewRequest(http.MethodGet, "/abc", nil) + ctx, cancel := context.WithTimeout(req.Context(), 50*time.Millisecond) + defer cancel() + req = req.WithContext(ctx) + + //nolint:bodyclose // Mock response does not have a body to close + _, err := tp.Perform(req) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("expected context.DeadlineExceeded, got %s", err) + } + if i != 1 { + t.Fatalf("unexpected number of requests: expected 1, got %d", i) + } + }) } func TestURLs(t *testing.T) {