diff --git a/backend/httpclient/count_bytes_reader.go b/backend/httpclient/count_bytes_reader.go new file mode 100644 index 000000000..278a0b378 --- /dev/null +++ b/backend/httpclient/count_bytes_reader.go @@ -0,0 +1,39 @@ +package httpclient + +import ( + "io" +) + +type CloseCallbackFunc func(bytesRead int64) + +// CountBytesReader counts the total amount of bytes read from the underlying reader. +// +// The provided callback func will be called before the underlying reader is closed. +func CountBytesReader(reader io.ReadCloser, callback CloseCallbackFunc) io.ReadCloser { + if reader == nil { + panic("reader cannot be nil") + } + + if callback == nil { + panic("callback cannot be nil") + } + + return &countBytesReader{reader: reader, callback: callback} +} + +type countBytesReader struct { + reader io.ReadCloser + callback CloseCallbackFunc + counter int64 +} + +func (r *countBytesReader) Read(p []byte) (int, error) { + n, err := r.reader.Read(p) + r.counter += int64(n) + return n, err +} + +func (r *countBytesReader) Close() error { + r.callback(r.counter) + return r.reader.Close() +} diff --git a/backend/httpclient/count_bytes_reader_test.go b/backend/httpclient/count_bytes_reader_test.go new file mode 100644 index 000000000..96e94cf55 --- /dev/null +++ b/backend/httpclient/count_bytes_reader_test.go @@ -0,0 +1,38 @@ +package httpclient + +import ( + "fmt" + "io" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCountBytesReader(t *testing.T) { + tcs := []struct { + body string + expectedBytesCount int64 + }{ + {body: "d", expectedBytesCount: 1}, + {body: "dummy", expectedBytesCount: 5}, + } + + for index, tc := range tcs { + t.Run(fmt.Sprintf("Test CountBytesReader %d", index), func(t *testing.T) { + body := io.NopCloser(strings.NewReader(tc.body)) + var actualBytesRead int64 + + readCloser := CountBytesReader(body, func(bytesRead int64) { + actualBytesRead = bytesRead + }) + + bodyBytes, err := io.ReadAll(readCloser) + require.NoError(t, err) + err = readCloser.Close() + require.NoError(t, err) + require.Equal(t, tc.expectedBytesCount, actualBytesRead) + require.Equal(t, string(bodyBytes), tc.body) + }) + } +}