Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-3302 Handle malformatted message length properly. (#1758) [master] #1817

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 102 additions & 84 deletions x/mongo/driver/topology/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package topology
import (
"context"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -80,9 +81,9 @@ type connection struct {
// accessTokens in the OIDC authenticator cache.
oidcTokenGenID uint64

// awaitingResponse indicates that the server response was not completely
// awaitRemainingBytes indicates the size of server response that was not completely
// read before returning the connection to the pool.
awaitingResponse bool
awaitRemainingBytes *int32
}

// newConnection handles the creation of a connection. It does not connect the connection.
Expand Down Expand Up @@ -111,11 +112,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
return c
}

// DriverConnectionID returns the driver connection ID.
func (c *connection) DriverConnectionID() int64 {
return c.driverConnectionID
}

// setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
// configuration.
func (c *connection) setGenerationNumber() {
Expand All @@ -137,6 +133,39 @@ func (c *connection) hasGenerationNumber() bool {
return driverutil.IsServerLoadBalanced(c.desc)
}

func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
// handshakes. All errors returned by connect are considered "before the handshake completes" and
// must be handled by calling the appropriate SDAM handshake error handler.
Expand Down Expand Up @@ -291,6 +320,10 @@ func (c *connection) closeConnectContext() {
}
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
if originalError == nil {
return nil
Expand All @@ -313,10 +346,6 @@ func transformNetworkError(ctx context.Context, originalError error, contextDead
return originalError
}

func (c *connection) cancellationListenerCallback() {
_ = c.close()
}

func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
var err error
if atomic.LoadInt64(&c.state) != connConnected {
Expand Down Expand Up @@ -377,14 +406,9 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {

dst, errMsg, err := c.read(ctx)
if err != nil {
if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() {
// If the error was a timeout error, instead of closing the
// connection mark it as awaiting response so the pool can read the
// response before making it available to other operations.
c.awaitingResponse = true
} else {
// Otherwise, and close the connection because we don't know what
// the connection state is.
if c.awaitRemainingBytes == nil {
// If the connection was not marked as awaiting response, close the
// connection because we don't know what the connection state is.
c.close()
}
message := errMsg
Expand All @@ -401,6 +425,26 @@ func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
return dst, nil
}

func (c *connection) parseWmSizeBytes(wmSizeBytes [4]byte) (int32, error) {
// read the length as an int32
size := int32(binary.LittleEndian.Uint32(wmSizeBytes[:]))

if size < 4 {
return 0, fmt.Errorf("malformed message length: %d", size)
}
// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return 0, errResponseTooLarge
}

return size, nil
}

func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
defer func() {
Expand All @@ -414,36 +458,42 @@ func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string,
}
}()

isCSOTTimeout := func(err error) bool {
// If the error was a timeout error, instead of closing the
// connection mark it as awaiting response so the pool can read the
// response before making it available to other operations.
nerr := net.Error(nil)
return errors.As(err, &nerr) && nerr.Timeout()
}

// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
// reslice dst once instead of twice.
var sizeBuf [4]byte

// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
// because there might be more than one wire message waiting to be read, for example when
// reading messages from an exhaust cursor.
_, err = io.ReadFull(c.nc, sizeBuf[:])
n, err := io.ReadFull(c.nc, sizeBuf[:])
if err != nil {
if l := int32(n); l == 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &l
}
return nil, "incomplete read of message header", err
}

// read the length as an int32
size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)

// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
// defaultMaxMessageSize instead.
maxMessageSize := c.desc.MaxMessageSize
if maxMessageSize == 0 {
maxMessageSize = defaultMaxMessageSize
}
if uint32(size) > maxMessageSize {
return nil, errResponseTooLarge.Error(), errResponseTooLarge
size, err := c.parseWmSizeBytes(sizeBuf)
if err != nil {
return nil, err.Error(), err
}

dst := make([]byte, size)
copy(dst, sizeBuf[:])

_, err = io.ReadFull(c.nc, dst[4:])
n, err = io.ReadFull(c.nc, dst[4:])
if err != nil {
remainingBytes := size - 4 - int32(n)
if remainingBytes > 0 && isCSOTTimeout(err) {
c.awaitRemainingBytes = &remainingBytes
}
return dst, "incomplete read of full message", err
}

Expand Down Expand Up @@ -496,10 +546,6 @@ func (c *connection) setCanStream(canStream bool) {
c.canStream = canStream
}

func (c initConnection) supportsStreaming() bool {
return c.canStream
}

func (c *connection) setStreaming(streaming bool) {
c.currentlyStreaming = streaming
}
Expand All @@ -508,6 +554,14 @@ func (c *connection) getCurrentlyStreaming() bool {
return c.currentlyStreaming
}

func (c *connection) previousCanceled() bool {
if val := c.prevCanceled.Load(); val != nil {
return val.(bool)
}

return false
}

func (c *connection) ID() string {
return c.id
}
Expand All @@ -516,12 +570,17 @@ func (c *connection) ServerConnectionID() *int64 {
return c.serverConnectionID
}

func (c *connection) previousCanceled() bool {
if val := c.prevCanceled.Load(); val != nil {
return val.(bool)
}
// DriverConnectionID returns the driver connection ID.
func (c *connection) DriverConnectionID() int64 {
return c.driverConnectionID
}

return false
func (c *connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}

// initConnection is an adapter used during connection initialization. It has the minimum
Expand Down Expand Up @@ -562,7 +621,7 @@ func (c initConnection) CurrentlyStreaming() bool {
return c.getCurrentlyStreaming()
}
func (c initConnection) SupportsStreaming() bool {
return c.supportsStreaming()
return c.canStream
}

// Connection implements the driver.Connection interface to allow reading and writing wire
Expand Down Expand Up @@ -797,39 +856,6 @@ func (c *Connection) DriverConnectionID() int64 {
return c.connection.DriverConnectionID()
}

func configureTLS(ctx context.Context,
tlsConnSource tlsConnectionSource,
nc net.Conn,
addr address.Address,
config *tls.Config,
ocspOpts *ocsp.VerifyOptions,
) (net.Conn, error) {
// Ensure config.ServerName is always set for SNI.
if config.ServerName == "" {
hostname := addr.String()
colonPos := strings.LastIndex(hostname, ":")
if colonPos == -1 {
colonPos = len(hostname)
}

hostname = hostname[:colonPos]
config.ServerName = hostname
}

client := tlsConnSource.Client(nc, config)
if err := clientHandshake(ctx, client); err != nil {
return nil, err
}

// Only do OCSP verification if TLS verification is requested.
if !config.InsecureSkipVerify {
if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
return nil, ocspErr
}
}
return client, nil
}

// OIDCTokenGenID returns the OIDC token generation ID.
func (c *Connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
Expand All @@ -839,11 +865,3 @@ func (c *Connection) OIDCTokenGenID() uint64 {
func (c *Connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}

func (c *connection) OIDCTokenGenID() uint64 {
return c.oidcTokenGenID
}

func (c *connection) SetOIDCTokenGenID(genID uint64) {
c.oidcTokenGenID = genID
}
17 changes: 17 additions & 0 deletions x/mongo/driver/topology/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,23 @@ func TestConnection(t *testing.T) {
}
listener.assertCalledOnce(t)
})
t.Run("size too small errors", func(t *testing.T) {
err := errors.New("malformed message length: 3")
tnc := &testNetConn{readerr: err, buf: []byte{0x03, 0x00, 0x00, 0x00}}
conn := &connection{id: "foobar", nc: tnc, state: connConnected}
listener := newTestCancellationListener(false)
conn.cancellationListener = listener

want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
_, got := conn.readWireMessage(context.Background())
if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
t.Errorf("errors do not match. got %v; want %v", got, want)
}
if !tnc.closed {
t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
}
listener.assertCalledOnce(t)
})
t.Run("full message read errors", func(t *testing.T) {
err := errors.New("Read error")
tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
Expand Down
Loading
Loading