From 7eecc892abf2f2a3dd0101cb2419fcc3b9538aaf Mon Sep 17 00:00:00 2001
From: zhyu <angellwings@gmail.com>
Date: Wed, 8 May 2024 07:25:54 +0000
Subject: [PATCH] fix request context cancellation is ignored when retryBackoff
 (#539)

Signed-off-by: zhyu <angellwings@gmail.com>
---
 CHANGELOG.md                                  |  2 ++
 opensearchtransport/opensearchtransport.go    | 13 ++++++-
 .../opensearchtransport_internal_test.go      | 34 +++++++++++++++++++
 3 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2176c81bf..4d2869fb9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 
 ### Fixed
 
+- Fixes opensearchtransport ignores request context cancellation when `retryBackoff` is configured ([#540](https://github.com/opensearch-project/opensearch-go/pull/540))
+
 ### Security
 
 ### Dependencies
diff --git a/opensearchtransport/opensearchtransport.go b/opensearchtransport/opensearchtransport.go
index ac8ad057e..ad5f0a21a 100644
--- a/opensearchtransport/opensearchtransport.go
+++ b/opensearchtransport/opensearchtransport.go
@@ -380,7 +380,18 @@ func (c *Client) Perform(req *http.Request) (*http.Response, error) {
 
 		// Delay the retry if a backoff function is configured
 		if c.retryBackoff != nil {
-			time.Sleep(c.retryBackoff(i + 1))
+			var cancelled bool
+			timer := time.NewTimer(c.retryBackoff(i + 1))
+			select {
+			case <-req.Context().Done():
+				timer.Stop()
+				err = req.Context().Err()
+				cancelled = true
+			case <-timer.C:
+			}
+			if cancelled {
+				break
+			}
 		}
 	}
 	// Read, close and replace the http response body to close the connection
diff --git a/opensearchtransport/opensearchtransport_internal_test.go b/opensearchtransport/opensearchtransport_internal_test.go
index fbe1f26a3..3ecbd5da8 100644
--- a/opensearchtransport/opensearchtransport_internal_test.go
+++ b/opensearchtransport/opensearchtransport_internal_test.go
@@ -31,6 +31,8 @@ package opensearchtransport
 import (
 	"bytes"
 	"compress/gzip"
+	"context"
+	"errors"
 	"fmt"
 	"io"
 	"math/rand"
@@ -807,6 +809,38 @@ func TestTransportPerformRetries(t *testing.T) {
 			t.Errorf("Unexpected duration, want=>%s, got=%s", expectedDuration, end)
 		}
 	})
+
+	t.Run("Delay the retry with retry on timeout and context deadline", func(t *testing.T) {
+		var i int
+		u, _ := url.Parse("http://foo.bar")
+		tp, _ := New(Config{
+			EnableRetryOnTimeout: true,
+			MaxRetries:           100,
+			RetryBackoff:         func(i int) time.Duration { return time.Hour },
+			URLs:                 []*url.URL{u},
+			Transport: &mockTransp{
+				RoundTripFunc: func(req *http.Request) (*http.Response, error) {
+					i++
+					<-req.Context().Done()
+					return nil, req.Context().Err()
+				},
+			},
+		})
+
+		req, _ := http.NewRequest(http.MethodGet, "/abc", nil)
+		ctx, cancel := context.WithTimeout(req.Context(), 50*time.Millisecond)
+		defer cancel()
+		req = req.WithContext(ctx)
+
+		//nolint:bodyclose // Mock response does not have a body to close
+		_, err := tp.Perform(req)
+		if !errors.Is(err, context.DeadlineExceeded) {
+			t.Fatalf("expected context.DeadlineExceeded, got %s", err)
+		}
+		if i != 1 {
+			t.Fatalf("unexpected number of requests: expected 1, got %d", i)
+		}
+	})
 }
 
 func TestURLs(t *testing.T) {