From 46675be6793f4016593c2830f956c40101f0fc79 Mon Sep 17 00:00:00 2001 From: Will Browne Date: Tue, 27 Aug 2024 11:24:32 +0100 Subject: [PATCH] add tests --- backend/data_adapter.go | 16 +++++++ backend/error_source.go | 5 +-- backend/error_source_test.go | 85 ++++++++++++++++++++++++++++++++++++ backend/request_status.go | 11 +++++ 4 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 backend/error_source_test.go diff --git a/backend/data_adapter.go b/backend/data_adapter.go index e08d01d1a..bd8078ecb 100644 --- a/backend/data_adapter.go +++ b/backend/data_adapter.go @@ -37,11 +37,16 @@ func (a *dataSDKAdapter) QueryData(ctx context.Context, req *pluginv2.QueryDataR return RequestStatusCancelled, nil } + if isHTTPTimeoutError(innerErr) { + return RequestStatusError, nil + } + // Set downstream status source in the context if there's at least one response with downstream status source, // and if there's no plugin error var hasPluginError bool var hasDownstreamError bool var hasCancelledError bool + var hasHTTPTimeoutError bool for _, r := range resp.Responses { if r.Error == nil { continue @@ -50,6 +55,10 @@ func (a *dataSDKAdapter) QueryData(ctx context.Context, req *pluginv2.QueryDataR if isCancelledError(r.Error) { hasCancelledError = true } + if isHTTPTimeoutError(r.Error) { + hasHTTPTimeoutError = true + } + if r.ErrorSource == ErrorSourceDownstream { hasDownstreamError = true } else { @@ -64,6 +73,13 @@ func (a *dataSDKAdapter) QueryData(ctx context.Context, req *pluginv2.QueryDataR return RequestStatusCancelled, nil } + if hasHTTPTimeoutError { + if err := WithDownstreamErrorSource(ctx); err != nil { + return RequestStatusError, fmt.Errorf("failed to set downstream status source: %w", errors.Join(innerErr, err)) + } + return RequestStatusError, nil + } + // A plugin error has higher priority than a downstream error, // so set to downstream only if there's no plugin error if hasDownstreamError && !hasPluginError { diff --git a/backend/error_source.go b/backend/error_source.go index fcdc27b01..1df6a0dd2 100644 --- a/backend/error_source.go +++ b/backend/error_source.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "net/http" ) @@ -63,9 +62,7 @@ func IsDownstreamError(err error) bool { return true } - // if error is HTTP network timeout error, we should treat it as downstream error - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { + if isHTTPTimeoutError(err) { return true } diff --git a/backend/error_source_test.go b/backend/error_source_test.go new file mode 100644 index 000000000..264ad7f4f --- /dev/null +++ b/backend/error_source_test.go @@ -0,0 +1,85 @@ +package backend + +import ( + "fmt" + "net" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsDownstreamError(t *testing.T) { + tcs := []struct { + name string + err error + expected bool + }{ + { + name: "nil", + err: nil, + expected: false, + }, + { + name: "downstream error", + err: DownstreamError(nil), + expected: true, + }, + { + name: "timeout network error", + err: newFakeNetworkError(true, false), + expected: true, + }, + { + name: "temporary timeout network error", + err: newFakeNetworkError(true, true), + expected: true, + }, + { + name: "non-timeout network error", + err: newFakeNetworkError(false, false), + expected: false, + }, + { + name: "os.ErrDeadlineExceeded", + err: os.ErrDeadlineExceeded, + expected: true, + }, + { + name: "other error", + err: fmt.Errorf("other error"), + expected: false, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + assert.Equalf(t, tc.expected, IsDownstreamError(tc.err), "IsDownstreamError(%v)", tc.err) + }) + } +} + +var _ net.Error = &fakeNetworkError{} + +type fakeNetworkError struct { + timeout bool + temporary bool +} + +func newFakeNetworkError(timeout, temporary bool) *fakeNetworkError { + return &fakeNetworkError{ + timeout: timeout, + temporary: temporary, + } +} + +func (d *fakeNetworkError) Error() string { + return "dummy timeout error" +} + +func (d *fakeNetworkError) Timeout() bool { + return d.timeout +} + +func (d *fakeNetworkError) Temporary() bool { + return d.temporary +} diff --git a/backend/request_status.go b/backend/request_status.go index 905907ba8..2fbe243fa 100644 --- a/backend/request_status.go +++ b/backend/request_status.go @@ -3,6 +3,8 @@ package backend import ( "context" "errors" + "net" + "os" "strings" grpccodes "google.golang.org/grpc/codes" @@ -105,3 +107,12 @@ func RequestStatusFromProtoQueryDataResponse(res *pluginv2.QueryDataResponse, er func isCancelledError(err error) bool { return errors.Is(err, context.Canceled) || grpcstatus.Code(err) == grpccodes.Canceled } + +func isHTTPTimeoutError(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + return errors.Is(err, os.ErrDeadlineExceeded) // relacement for os.IsTimeout(err) +}