diff --git a/client.go b/client.go index edf5599..c1ae0f1 100644 --- a/client.go +++ b/client.go @@ -73,7 +73,7 @@ func NewClient(configs ...Option) (*Client, error) { Transport: config.roundTripper(), } - client, err := internal.NewHttpClient(httpClient, config.baseUrl) + client, err := internal.NewHttpClient(httpClient, config.baseUrl, config.logger) if err != nil { return nil, err } diff --git a/internal/http_client.go b/internal/http_client.go index e16f8e2..be441b7 100644 --- a/internal/http_client.go +++ b/internal/http_client.go @@ -4,50 +4,108 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" "net/url" + "strconv" + "time" + + "github.com/avast/retry-go/v4" +) + +const ( + defaultWindowLimit = 400 + defaultWindowDuration = 1 * time.Minute + + headerRateLimitRemaining = "X-Rate-Limit-Remaining" ) type HttpClient struct { - client *http.Client - baseUrl *url.URL + client *http.Client + baseUrl *url.URL + rateLimiter RateLimiter + retryEnabled bool + retryMaxDelay time.Duration + retryDelay time.Duration + retryMaxAttempts uint + logger Log } -func NewHttpClient(client *http.Client, baseUrl string) (*HttpClient, error) { +func NewHttpClient(client *http.Client, baseUrl string, logger Log) (*HttpClient, error) { parsed, err := url.Parse(baseUrl) if err != nil { return nil, err } - return &HttpClient{client: client, baseUrl: parsed}, nil + + return &HttpClient{ + client: client, + baseUrl: parsed, + rateLimiter: newFixedWindowCountRateLimiter(defaultWindowLimit, defaultWindowDuration), + retryEnabled: true, + retryMaxAttempts: 10, + retryDelay: 1 * time.Second, + retryMaxDelay: defaultWindowDuration, + logger: logger, + }, nil } func (c *HttpClient) Get(ctx context.Context, name, path string, responseBody interface{}) error { - return c.connection(ctx, http.MethodGet, name, path, nil, nil, responseBody) + return c.connectionWithRetries(ctx, http.MethodGet, name, path, nil, nil, responseBody) } func (c *HttpClient) GetWithQuery(ctx context.Context, name, path string, query url.Values, responseBody interface{}) error { - return c.connection(ctx, http.MethodGet, name, path, query, nil, responseBody) + return c.connectionWithRetries(ctx, http.MethodGet, name, path, query, nil, responseBody) } func (c *HttpClient) Put(ctx context.Context, name, path string, requestBody interface{}, responseBody interface{}) error { - return c.connection(ctx, http.MethodPut, name, path, nil, requestBody, responseBody) + return c.connectionWithRetries(ctx, http.MethodPut, name, path, nil, requestBody, responseBody) } func (c *HttpClient) Post(ctx context.Context, name, path string, requestBody interface{}, responseBody interface{}) error { - return c.connection(ctx, http.MethodPost, name, path, nil, requestBody, responseBody) + return c.connectionWithRetries(ctx, http.MethodPost, name, path, nil, requestBody, responseBody) } func (c *HttpClient) Delete(ctx context.Context, name, path string, responseBody interface{}) error { - return c.connection(ctx, http.MethodDelete, name, path, nil, nil, responseBody) + return c.connectionWithRetries(ctx, http.MethodDelete, name, path, nil, nil, responseBody) } func (c *HttpClient) DeleteWithQuery(ctx context.Context, name, path string, requestBody interface{}, responseBody interface{}) error { - return c.connection(ctx, http.MethodDelete, name, path, nil, requestBody, responseBody) + return c.connectionWithRetries(ctx, http.MethodDelete, name, path, nil, requestBody, responseBody) +} + +func (c *HttpClient) connectionWithRetries(ctx context.Context, method, name, path string, query url.Values, requestBody interface{}, responseBody interface{}) error { + return retry.Do(func() error { + return c.connection(ctx, method, name, path, query, requestBody, responseBody) + }, + retry.Attempts(c.retryMaxAttempts), + retry.Delay(c.retryDelay), + retry.MaxDelay(c.retryMaxDelay), + retry.RetryIf(func(err error) bool { + if !c.retryEnabled { + return false + } + var target *HTTPError + if errors.As(err, &target) && target.StatusCode == http.StatusTooManyRequests { + c.logger.Println(fmt.Sprintf("status code 429 received, request will be retried")) + return true + } + return false + }), + retry.LastErrorOnly(true), + retry.Context(ctx), + ) } func (c *HttpClient) connection(ctx context.Context, method, name, path string, query url.Values, requestBody interface{}, responseBody interface{}) error { + if c.rateLimiter != nil { + err := c.rateLimiter.Wait(ctx) + if err != nil { + return err + } + } + parsed := new(url.URL) *parsed = *c.baseUrl @@ -81,6 +139,16 @@ func (c *HttpClient) connection(ctx context.Context, method, name, path string, return fmt.Errorf("failed to %s: %w", name, err) } + remainingLimit := response.Header.Get(headerRateLimitRemaining) + if remainingLimit != "" { + if limit, err := strconv.Atoi(remainingLimit); err == nil { + err = c.rateLimiter.Update(limit) + if err != nil { + return err + } + } + } + defer response.Body.Close() if response.StatusCode > 299 { diff --git a/internal/http_client_test.go b/internal/http_client_test.go index 2be0d68..920dfb2 100644 --- a/internal/http_client_test.go +++ b/internal/http_client_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -14,9 +15,81 @@ func TestHttpClient_Get_failsFor4xx(t *testing.T) { w.WriteHeader(418) })) - subject, err := NewHttpClient(s.Client(), s.URL) + subject, err := NewHttpClient(s.Client(), s.URL, &testLogger{t: t}) require.NoError(t, err) err = subject.Get(context.TODO(), "testing", "/", nil) require.Error(t, err) } + +func TestHttpClient_Retry(t *testing.T) { + testCase := []struct { + description string + retryEnabled bool + statusCode int + expectedCount int + expectedError string + }{ + { + description: "should retry 429 requests when retry is enabled", + retryEnabled: true, + statusCode: 429, + expectedCount: 3, + }, + { + description: "should not retry other status code when retry is enabled", + retryEnabled: true, + statusCode: 404, + expectedCount: 1, + expectedError: "failed to test get request: 404 - ", + }, + { + description: "should not retry 429 requests when retry is disabled", + retryEnabled: false, + statusCode: 429, + expectedCount: 1, + expectedError: "failed to test get request: 429 - ", + }, + } + + for _, test := range testCase { + t.Run(test.description, func(t *testing.T) { + + count := 0 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + if count < 3 { + w.WriteHeader(test.statusCode) + return + } + w.WriteHeader(200) + _, err := w.Write([]byte("{}")) + require.NoError(t, err) + })) + + subject, err := NewHttpClient(s.Client(), s.URL, &testLogger{t: t}) + require.NoError(t, err) + subject.retryEnabled = test.retryEnabled + + ctx := context.Background() + err = subject.Get(ctx, "test get request", "/", nil) + if test.expectedError != "" { + assert.EqualError(t, err, test.expectedError) + } else { + assert.NoError(t, err) + } + assert.Equal(t, test.expectedCount, count) + }) + } + +} + +type testLogger struct { + t *testing.T +} + +func (l *testLogger) Println(v ...interface{}) { + l.t.Log(v...) +} + +var _ Log = &testLogger{} diff --git a/internal/log.go b/internal/log.go new file mode 100644 index 0000000..7176728 --- /dev/null +++ b/internal/log.go @@ -0,0 +1,5 @@ +package internal + +type Log interface { + Println(v ...interface{}) +} diff --git a/internal/rate_limit.go b/internal/rate_limit.go new file mode 100644 index 0000000..cbd9380 --- /dev/null +++ b/internal/rate_limit.go @@ -0,0 +1,92 @@ +package internal + +import ( + "context" + "fmt" + "sync" + "time" +) + +type RateLimiter interface { + // Wait will verify one request can be sent or wait if it can't. + Wait(ctx context.Context) error + // Update the rate limiter when the server returns more information about the current limits. + Update(remaining int) error +} + +// A fixedWindowCountRateLimiter is a rate limiter that will count the number of requests within a period (or window) +// and block the caller for the expected remaining period in the window. +// +// The window will start again after the last one closes and the count will be reset. +// Since other requests can happen outside the SDK, callers can calls the Update() function to update the remaining +// event in the window. +// +// This rate limiter tries to model the server-side behaviour as best it can, however, it doesn't know exactly when +// the server-side window starts or ends, so it can be misaligned. Therefore, the callers still need to retry requests +// if a status code 429 (Too Many Requests) is received. +type fixedWindowCountRateLimiter struct { + limit int + period time.Duration + windowStart *time.Time + count int + mu *sync.Mutex +} + +func newFixedWindowCountRateLimiter(limit int, period time.Duration) *fixedWindowCountRateLimiter { + return &fixedWindowCountRateLimiter{ + limit: limit, + period: period, + mu: &sync.Mutex{}, + } +} + +// Wait will block the caller when the number of requests has exceeded the limit in the current window. +// This function allows bursting so it will only block when the limit is reached. +func (rl *fixedWindowCountRateLimiter) Wait(ctx context.Context) error { + rl.mu.Lock() + defer rl.mu.Unlock() + + // Start window on first requests + if rl.windowStart == nil { + now := time.Now() + rl.windowStart = &now + } + + windowEnd := rl.windowStart.Add(rl.period) + if time.Now().After(windowEnd) { + rl.count = 0 + rl.windowStart = &windowEnd + windowEnd = rl.windowStart.Add(rl.period) + } + + if rl.count == rl.limit { + delay := windowEnd.Sub(time.Now()) + err := sleepWithContext(ctx, delay) + if err != nil { + return err + } + } + rl.count++ + return nil +} + +func (rl *fixedWindowCountRateLimiter) Update(remaining int) error { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.count = rl.limit - remaining + return nil +} + +func sleepWithContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + select { + case <-ctx.Done(): + if !timer.Stop() { + return fmt.Errorf("context expired before timer stopped") + } + case <-timer.C: + } + return nil +} + +var _ RateLimiter = &fixedWindowCountRateLimiter{} diff --git a/internal/rate_limit_test.go b/internal/rate_limit_test.go new file mode 100644 index 0000000..956a352 --- /dev/null +++ b/internal/rate_limit_test.go @@ -0,0 +1,53 @@ +package internal + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFixedWindowCountRateLimiter_Wait(t *testing.T) { + windowSize := 2 * time.Second + windowLimit := 10 + + limiter := newFixedWindowCountRateLimiter(windowLimit, windowSize) + + ctx := context.Background() + start := time.Now() + runs := 3 + count := 0 + for range windowLimit * runs { + err := limiter.Wait(ctx) + require.NoError(t, err) + count++ + } + end := time.Now() + assert.Equal(t, runs*windowLimit, count) + assert.Greater(t, end.Sub(start), windowSize.Nanoseconds()*int64(runs-1)) +} + +func TestFixedWindowCountRateLimiter_Update(t *testing.T) { + windowSize := 2 * time.Second + windowLimit := 10 + + limiter := newFixedWindowCountRateLimiter(windowLimit, windowSize) + + ctx := context.Background() + start := time.Now() + runs := 2 + count := 0 + assert.NoError(t, limiter.Update(0)) + for range windowLimit * runs { + t.Logf("%s\n", time.Now().String()) + err := limiter.Wait(ctx) + require.NoError(t, err) + count++ + } + end := time.Now() + assert.Equal(t, runs*windowLimit, count) + assert.Greater(t, end.Sub(start), windowSize.Nanoseconds()*int64(runs)) + +} diff --git a/internal/service.go b/internal/service.go index 627de49..8c8960f 100644 --- a/internal/service.go +++ b/internal/service.go @@ -12,10 +12,6 @@ import ( "github.com/avast/retry-go/v4" ) -type Log interface { - Println(v ...interface{}) -} - type Api interface { // WaitForResourceId will poll the Task, waiting for the Task to finish processing, where it will then return. // An error will be returned if the Task couldn't be retrieved or the Task was not processed successfully.