diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go index d0dfe08789..7a8427ccee 100644 --- a/x/mongo/driver/topology/connection.go +++ b/x/mongo/driver/topology/connection.go @@ -9,6 +9,7 @@ package topology import ( "context" "crypto/tls" + "encoding/binary" "errors" "fmt" "io" @@ -472,10 +473,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) { func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) { // read the length as an int32 - size := (int32(wmSizeBytes[0])) | - (int32(wmSizeBytes[1]) << 8) | - (int32(wmSizeBytes[2]) << 16) | - (int32(wmSizeBytes[3]) << 24) + size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:])) if size < 4 { return 0, fmt.Errorf("malformed message length: %d", size) @@ -506,7 +504,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, } }() - needToWait := func(err error) bool { + isCSOTTimeout := func(err error) bool { // If the error was a timeout error and CSOT is enabled, instead of // closing the connection mark it as awaiting response so the pool // can read the response before making it available to other @@ -524,7 +522,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, // reading messages from an exhaust cursor. n, err := io.ReadFull(c.nc, sizeBuf[:]) if err != nil { - if l := int32(n); l == 0 && needToWait(err) { + if l := int32(n); l == 0 && isCSOTTimeout(err) { c.awaitRemainingBytes = &l } return nil, "incomplete read of message header", err @@ -540,7 +538,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, n, err = io.ReadFull(c.nc, dst[4:]) if err != nil { remainingBytes := size - 4 - int32(n) - if remainingBytes > 0 && needToWait(err) { + if remainingBytes > 0 && isCSOTTimeout(err) { c.awaitRemainingBytes = &remainingBytes } return dst, "incomplete read of full message", err diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go index 5d232f1ebc..ddb69ada76 100644 --- a/x/mongo/driver/topology/pool.go +++ b/x/mongo/driver/topology/pool.go @@ -840,7 +840,7 @@ func bgRead(pool *pool, conn *connection, size int32) { } _, err = io.CopyN(io.Discard, conn.nc, int64(size)) if err != nil { - err = fmt.Errorf("error reading message of %d: %w", size, err) + err = fmt.Errorf("error discarding %d byte message: %w", size, err) } } diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go index ebb342e17c..514d393a93 100644 --- a/x/mongo/driver/topology/pool_test.go +++ b/x/mongo/driver/topology/pool_test.go @@ -11,7 +11,6 @@ import ( "errors" "io" "net" - "os" "regexp" "sync" "testing" @@ -1126,212 +1125,207 @@ func TestPool(t *testing.T) { p.close(context.Background()) }) }) - t.Run("bgRead", func(t *testing.T) { - t.Parallel() +} + +func TestBackgroundRead(t *testing.T) { + t.Parallel() - var errCh chan error - BGReadCallback = func(addr string, start, read time.Time, errs []error, connClosed bool) { + newBGReadCallback := func(errCh chan error) func(string, time.Time, time.Time, []error, bool) { + return func(_ string, _, _ time.Time, errs []error, _ bool) { defer close(errCh) for _, err := range errs { errCh <- err } } + } - const sockPath = "./test.sock" - - var socket net.Listener + t.Run("incomplete read of message header", func(t *testing.T) { + errCh := make(chan error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + t.Cleanup(func() { + BGReadCallback = originalCallback + }) - setup := func(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(1) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { t.Helper() - errCh = make(chan error) + defer func() { + _ = nc.Close() + wg.Done() + }() - var err error - socket, err = net.Listen("unix", sockPath) + _, err := nc.Write([]byte{10, 0, 0}) noerr(t, err) - } - teardown := func(t *testing.T) { - t.Helper() - - os.Remove(sockPath) - } - - t.Run("incomplete read of message header", func(t *testing.T) { - setup(t) - defer teardown(t) - - wg := &sync.WaitGroup{} - wg.Add(1) - go func(t *testing.T) { - t.Helper() + time.Sleep(1500 * time.Millisecond) + }) - defer wg.Done() + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return net.Dial("tcp", addr.String()) + }) + }), + ) + err := p.ready() + noerr(t, err) - conn, err := socket.Accept() - noerr(t, err) - defer conn.Close() + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitRemainingBytes should be nil") + wg.Wait() + p.close(context.Background()) + close(errCh) + }) + t.Run("timeout on reading the message header", func(t *testing.T) { + errCh := make(chan error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + t.Cleanup(func() { + BGReadCallback = originalCallback + }) - _, err = conn.Write([]byte{10, 0, 0}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - }(t) + wg := &sync.WaitGroup{} + wg.Add(1) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + t.Helper() - p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return net.Dial("unix", sockPath) - }) - }), - ) - err := p.ready() - noerr(t, err) + defer func() { + _ = nc.Close() + wg.Done() + }() - conn, err := p.checkOut(context.Background()) + time.Sleep(1500 * time.Millisecond) + _, err := nc.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) - defer cancel() - _, err = conn.readWireMessage(ctx) - regex := regexp.MustCompile( - `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, - ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - assert.Nil(t, conn.awaitRemainingBytes, "conn.awaitingResponse should be nil") - wg.Wait() - p.close(context.Background()) - close(errCh) + time.Sleep(1500 * time.Millisecond) }) - t.Run("timeout on reading the message header", func(t *testing.T) { - setup(t) - defer teardown(t) + go func(t *testing.T) { + }(t) - wg := &sync.WaitGroup{} - wg.Add(1) - go func(t *testing.T) { - t.Helper() - - defer wg.Done() - - conn, err := socket.Accept() - noerr(t, err) - defer conn.Close() - - time.Sleep(1500 * time.Millisecond) - _, err = conn.Write([]byte{10, 0, 0, 0, 0, 0, 0, 0}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - }(t) - - p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - return net.Dial("unix", sockPath) - }) - }), - ) - err := p.ready() - noerr(t, err) + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + return net.Dial("tcp", addr.String()) + }) + }), + ) + err := p.ready() + noerr(t, err) - conn, err := p.checkOut(context.Background()) - noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) - defer cancel() - _, err = conn.readWireMessage(ctx) - regex := regexp.MustCompile( - `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, - ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - err = p.checkIn(conn) - noerr(t, err) - wg.Wait() - p.close(context.Background()) - errs := []*regexp.Regexp{ - regexp.MustCompile( - `^error reading message of 6: read unix .*->\.\/test.sock: i\/o timeout$`, - ), - } - for i := 0; true; i++ { - err, ok := <-errCh - if !ok { - if i != len(errs) { - assert.Fail(t, "expected more errors") - } - break - } else if i < len(errs) { - assert.True(t, errs[i].MatchString(err.Error()), "mismatched err: %v", err) - } else { - assert.Fail(t, "unexpected error", "got unexpected error: %v", err) + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of message header: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + err = p.checkIn(conn) + noerr(t, err) + wg.Wait() + p.close(context.Background()) + errs := []*regexp.Regexp{ + regexp.MustCompile( + `^error discarding 6 byte message: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ), + } + for i := 0; true; i++ { + err, ok := <-errCh + if !ok { + if i != len(errs) { + assert.Fail(t, "expected more errors") } + break + } else if i < len(errs) { + assert.True(t, errs[i].MatchString(err.Error()), "mismatched err: %v", err) + } else { + assert.Fail(t, "unexpected error", "got unexpected error: %v", err) } + } + }) + t.Run("timeout on reading the full message", func(t *testing.T) { + errCh := make(chan error) + var originalCallback func(string, time.Time, time.Time, []error, bool) + originalCallback, BGReadCallback = BGReadCallback, newBGReadCallback(errCh) + t.Cleanup(func() { + BGReadCallback = originalCallback }) - t.Run("timeout on reading the full message", func(t *testing.T) { - setup(t) - defer teardown(t) - - wg := &sync.WaitGroup{} - wg.Add(1) - go func(t *testing.T) { - t.Helper() - - defer wg.Done() - conn, err := socket.Accept() - noerr(t, err) - defer conn.Close() - - _, err = conn.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - _, err = conn.Write([]byte{2, 3, 4}) - noerr(t, err) - time.Sleep(1500 * time.Millisecond) - }(t) + wg := &sync.WaitGroup{} + wg.Add(1) + addr := bootstrapConnections(t, 1, func(nc net.Conn) { + t.Helper() - p := newPool( - poolConfig{}, - WithDialer(func(Dialer) Dialer { - return DialerFunc(func(context.Context, string, string) (net.Conn, error) { - conn, err := net.Dial("unix", sockPath) - noerr(t, err) - return newLimitConn(conn, 10), nil - }) - }), - ) - err := p.ready() - noerr(t, err) + defer func() { + _ = nc.Close() + wg.Done() + }() - conn, err := p.checkOut(context.Background()) + var err error + _, err = nc.Write([]byte{12, 0, 0, 0, 0, 0, 0, 0, 1}) noerr(t, err) - ctx, cancel := csot.MakeTimeoutContext(context.Background(), 1*time.Second) - defer cancel() - _, err = conn.readWireMessage(ctx) - regex := regexp.MustCompile( - `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read unix .*->\.\/test.sock: i\/o timeout$`, - ) - assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) - err = p.checkIn(conn) + time.Sleep(1500 * time.Millisecond) + _, err = nc.Write([]byte{2, 3, 4}) noerr(t, err) - wg.Wait() - p.close(context.Background()) - errs := []string{ - "error reading message of 3: EOF", - } - for i := 0; true; i++ { - err, ok := <-errCh - if !ok { - if i != len(errs) { - assert.Fail(t, "expected more errors") - } - break - } else if i < len(errs) { - assert.EqualError(t, err, errs[i], "mismatched err: %v", err) - } else { - assert.Fail(t, "unexpected error", "got unexpected error: %v", err) + time.Sleep(1500 * time.Millisecond) + }) + + p := newPool( + poolConfig{}, + WithDialer(func(Dialer) Dialer { + return DialerFunc(func(context.Context, string, string) (net.Conn, error) { + conn, err := net.Dial("tcp", addr.String()) + noerr(t, err) + return newLimitConn(conn, 10), nil + }) + }), + ) + err := p.ready() + noerr(t, err) + + conn, err := p.checkOut(context.Background()) + noerr(t, err) + ctx, cancel := csot.MakeTimeoutContext(context.Background(), 1*time.Second) + defer cancel() + _, err = conn.readWireMessage(ctx) + regex := regexp.MustCompile( + `^connection\(.*\[-\d+\]\) incomplete read of full message: context deadline exceeded: read tcp 127.0.0.1:.*->127.0.0.1:.*: i\/o timeout$`, + ) + assert.True(t, regex.MatchString(err.Error()), "mismatched err: %v", err) + err = p.checkIn(conn) + noerr(t, err) + wg.Wait() + p.close(context.Background()) + errs := []string{ + "error discarding 3 byte message: EOF", + } + for i := 0; true; i++ { + err, ok := <-errCh + if !ok { + if i != len(errs) { + assert.Fail(t, "expected more errors") } + break + } else if i < len(errs) { + assert.EqualError(t, err, errs[i], "mismatched err: %v", err) + } else { + assert.Fail(t, "unexpected error", "got unexpected error: %v", err) } - }) + } }) }