diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d5da8a5163..061f5f53d3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Fixed - Change the `http-server-duration` instrument in `go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp` to record milliseconds instead of microseconds match what is specified in the OpenTelemetry specification. (#1414, #1537) +- The `"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp".Transport` type now correctly handles protocol switching responses. + The returned response body implements the `io.ReadWriteCloser` interface if the underlying one does. + This ensures that protocol switching requests receive a response body that they can write to. (#1329, #1628) ### Removed diff --git a/instrumentation/net/http/otelhttp/transport.go b/instrumentation/net/http/otelhttp/transport.go index c9d2de2835d..121ad99b0a6 100644 --- a/instrumentation/net/http/otelhttp/transport.go +++ b/instrumentation/net/http/otelhttp/transport.go @@ -123,18 +123,51 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) { span.SetAttributes(semconv.HTTPAttributesFromHTTPStatusCode(res.StatusCode)...) span.SetStatus(semconv.SpanStatusFromHTTPStatusCode(res.StatusCode)) - res.Body = &wrappedBody{ctx: ctx, span: span, body: res.Body} + res.Body = newWrappedBody(span, res.Body) return res, err } +// newWrappedBody returns a new and appropriately scoped *wrappedBody as an +// io.ReadCloser. If the passed body implements io.Writer, the returned value +// will implement io.ReadWriteCloser. +func newWrappedBody(span trace.Span, body io.ReadCloser) io.ReadCloser { + // The successful protocol switch responses will have a body that + // implement an io.ReadWriteCloser. Ensure this interface type continues + // to be satisfied if that is the case. + if _, ok := body.(io.ReadWriteCloser); ok { + return &wrappedBody{span: span, body: body} + } + + // Remove the implementation of the io.ReadWriteCloser and only implement + // the io.ReadCloser. + return struct{ io.ReadCloser }{&wrappedBody{span: span, body: body}} +} + +// wrappedBody is the response body type returned by the transport +// instrumentation to complete a span. Errors encountered when using the +// response body are recorded in span tracking the response. +// +// The span tracking the response is ended when this body is closed. +// +// If the response body implements the io.Writer interface (i.e. for +// successful protocol switches), the wrapped body also will. type wrappedBody struct { - ctx context.Context span trace.Span body io.ReadCloser } -var _ io.ReadCloser = &wrappedBody{} +var _ io.ReadWriteCloser = &wrappedBody{} + +func (wb *wrappedBody) Write(p []byte) (int, error) { + // This will not panic given the guard in newWrappedBody. + n, err := wb.body.(io.Writer).Write(p) + if err != nil { + wb.span.RecordError(err) + wb.span.SetStatus(codes.Error, err.Error()) + } + return n, err +} func (wb *wrappedBody) Read(b []byte) (int, error) { n, err := wb.body.Read(b) diff --git a/instrumentation/net/http/otelhttp/transport_test.go b/instrumentation/net/http/otelhttp/transport_test.go index eed7581c18d..e155d3e088f 100644 --- a/instrumentation/net/http/otelhttp/transport_test.go +++ b/instrumentation/net/http/otelhttp/transport_test.go @@ -22,9 +22,11 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/propagation" @@ -286,3 +288,90 @@ func TestWrappedBodyCloseError(t *testing.T) { assert.Equal(t, expectedErr, wb.Close()) s.assert(t, true, nil, codes.Unset, "") } + +type readWriteCloser struct { + readCloser + + writeErr error +} + +const writeSize = 1 + +func (rwc readWriteCloser) Write([]byte) (int, error) { + return writeSize, rwc.writeErr +} + +func TestNewWrappedBodyReadWriteCloserImplementation(t *testing.T) { + wb := newWrappedBody(nil, readWriteCloser{}) + assert.Implements(t, (*io.ReadWriteCloser)(nil), wb) +} + +func TestNewWrappedBodyReadCloserImplementation(t *testing.T) { + wb := newWrappedBody(nil, readCloser{}) + assert.Implements(t, (*io.ReadCloser)(nil), wb) + + _, ok := wb.(io.ReadWriteCloser) + assert.False(t, ok, "wrappedBody should not implement io.ReadWriteCloser") +} + +func TestWrappedBodyWrite(t *testing.T) { + s := new(span) + var rwc io.ReadWriteCloser + assert.NotPanics(t, func() { + rwc = newWrappedBody(s, readWriteCloser{}).(io.ReadWriteCloser) + }) + + n, err := rwc.Write([]byte{}) + assert.Equal(t, writeSize, n, "wrappedBody returned wrong bytes") + assert.NoError(t, err) + s.assert(t, false, nil, codes.Unset, "") +} + +func TestWrappedBodyWriteError(t *testing.T) { + s := new(span) + expectedErr := errors.New("test") + var rwc io.ReadWriteCloser + assert.NotPanics(t, func() { + rwc = newWrappedBody(s, readWriteCloser{ + writeErr: expectedErr, + }).(io.ReadWriteCloser) + }) + n, err := rwc.Write([]byte{}) + assert.Equal(t, writeSize, n, "wrappedBody returned wrong bytes") + assert.ErrorIs(t, err, expectedErr) + s.assert(t, false, expectedErr, codes.Error, expectedErr.Error()) +} + +func TestTransportProtocolSwitch(t *testing.T) { + // This test validates the fix to #1329. + + // Simulate a "101 Switching Protocols" response from the test server. + response := []byte(strings.Join([]string{ + "HTTP/1.1 101 Switching Protocols", + "Upgrade: WebSocket", + "Connection: Upgrade", + "", "", // Needed for extra CRLF. + }, "\r\n")) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + conn, buf, err := w.(http.Hijacker).Hijack() + require.NoError(t, err) + + _, err = buf.Write(response) + require.NoError(t, err) + require.NoError(t, buf.Flush()) + require.NoError(t, conn.Close()) + })) + defer ts.Close() + + ctx := context.Background() + r, err := http.NewRequestWithContext(ctx, http.MethodGet, ts.URL, http.NoBody) + require.NoError(t, err) + + c := http.Client{Transport: NewTransport(http.DefaultTransport)} + res, err := c.Do(r) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, res.Body.Close()) }) + + assert.Implements(t, (*io.ReadWriteCloser)(nil), res.Body, "invalid body returned for protocol switch") +}