Skip to content

Commit

Permalink
improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Sep 13, 2024
1 parent b565a23 commit 2d988ed
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 183 deletions.
12 changes: 5 additions & 7 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package topology
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -472,10 +473,7 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {

func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
// read the length as an int32
size := (int32(wmSizeBytes[0])) |
(int32(wmSizeBytes[1]) << 8) |
(int32(wmSizeBytes[2]) << 16) |
(int32(wmSizeBytes[3]) << 24)
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))

if size < 4 {
return 0, fmt.Errorf("malformed message length: %d", size)
Expand Down Expand Up @@ -506,7 +504,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
}
}()

needToWait := func(err error) bool {
isCSOTTimeout := func(err error) bool {
// If the error was a timeout error and CSOT is enabled, instead of
// closing the connection mark it as awaiting response so the pool
// can read the response before making it available to other
Expand All @@ -524,7 +522,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
// reading messages from an exhaust cursor.
n, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
if l := int32(n); l == 0 && needToWait(err) {
if l := int32(n); l == 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &l
}
return nil, "incomplete read of message header", err
Expand All @@ -540,7 +538,7 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
n, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
remainingBytes := size - 4 - int32(n)
if remainingBytes > 0 && needToWait(err) {
if remainingBytes > 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &remainingBytes
}
return dst, "incomplete read of full message", err
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ func bgRead(pool *pool, conn *connection, size int32) {
}
_, err = io.CopyN(io.Discard, conn.nc, int64(size))
if err != nil {
err = fmt.Errorf("error reading message of %d: %w", size, err)
err = fmt.Errorf("error discarding %d byte message: %w", size, err)
}
}

Expand Down
Loading

0 comments on commit 2d988ed

Please sign in to comment.