Skip to content

Commit

Permalink
GODRIVER-2935 Use OP_QUERY in connection handshakes (#1377)
Browse files Browse the repository at this point in the history
  • Loading branch information
prestonvasquez authored Sep 7, 2023
1 parent 68bf155 commit 3bef0e4
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 16 deletions.
4 changes: 2 additions & 2 deletions mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,8 @@ func TestClient(t *testing.T) {
pair := msgPairs[0]
assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s",
handshake.LegacyHello, pair.CommandName)
assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode,
"expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String())
assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode,
"expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String())

// Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire
// version is now known to be >= 6.
Expand Down
65 changes: 65 additions & 0 deletions mongo/integration/mtest/sent_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type sentMsgParseFn func([]byte) (*SentMessage, error)

func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) {
switch opcode {
case wiremessage.OpQuery:
return parseOpQuery, true
case wiremessage.OpMsg:
return parseSentOpMsg, true
case wiremessage.OpCompressed:
Expand All @@ -46,6 +48,69 @@ func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) {
}
}

func parseOpQuery(wm []byte) (*SentMessage, error) {
var ok bool

if _, wm, ok = wiremessage.ReadQueryFlags(wm); !ok {
return nil, errors.New("failed to read query flags")
}
if _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm); !ok {
return nil, errors.New("failed to read full collection name")
}
if _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm); !ok {
return nil, errors.New("failed to read number to skip")
}
if _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm); !ok {
return nil, errors.New("failed to read number to return")
}

query, wm, ok := wiremessage.ReadQueryQuery(wm)
if !ok {
return nil, errors.New("failed to read query")
}

// If there is no read preference document, the command document is query.
// Otherwise, query is in the format {$query: <command document>, $readPreference: <read preference document>}.
commandDoc := query
var rpDoc bsoncore.Document

dollarQueryVal, err := query.LookupErr("$query")
if err == nil {
commandDoc = dollarQueryVal.Document()

rpVal, err := query.LookupErr("$readPreference")
if err != nil {
return nil, fmt.Errorf("query %s contains $query but not $readPreference fields", query)
}
rpDoc = rpVal.Document()
}

// For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command
// document. Pull these sequences out into an ArrayStyle DocumentSequence.
var docSequence *bsoncore.DocumentSequence
cmdElems, _ := commandDoc.Elements()
for _, elem := range cmdElems {
switch elem.Key() {
case "documents", "updates", "deletes":
docSequence = &bsoncore.DocumentSequence{
Style: bsoncore.ArrayStyle,
Data: elem.Value().Array(),
}
}
if docSequence != nil {
// There can only be one of these arrays in a well-formed command, so we exit the loop once one is found.
break
}
}

sm := &SentMessage{
Command: commandDoc,
ReadPreference: rpDoc,
DocumentSequence: docSequence,
}
return sm, nil
}

func parseSentMessage(wm []byte) (*SentMessage, error) {
// Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing.
_, requestID, _, opcode, remaining, ok := wiremessage.ReadHeader(wm)
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/auth/speculative_scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func TestSpeculativeSCRAM(t *testing.T) {
// Assert that the driver sent hello with the speculative authentication message.
assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d",
len(tc.payloads), (conn.Written))
helloCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, helloCmd, handshake.LegacyHello)

Expand Down Expand Up @@ -177,7 +177,7 @@ func TestSpeculativeSCRAM(t *testing.T) {

assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, handshake.LegacyHello)
_, err = hello.LookupErr("speculativeAuthenticate")
Expand Down
4 changes: 2 additions & 2 deletions x/mongo/driver/auth/speculative_x509_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestSpeculativeX509(t *testing.T) {

assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, handshake.LegacyHello)

Expand Down Expand Up @@ -103,7 +103,7 @@ func TestSpeculativeX509(t *testing.T) {

assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d",
numResponses, len(conn.Written))
hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written)
hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written)
assert.Nil(t, err, "error parsing hello command: %v", err)
assertCommandName(t, hello, handshake.LegacyHello)
_, err = hello.LookupErr("speculativeAuthenticate")
Expand Down
32 changes: 32 additions & 0 deletions x/mongo/driver/drivertest/channel_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,38 @@ func MakeReply(doc bsoncore.Document) []byte {
return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
}

// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message.
func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) {
var ok bool
_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
if !ok {
return nil, errors.New("could not read header")
}
_, wm, ok = wiremessage.ReadQueryFlags(wm)
if !ok {
return nil, errors.New("could not read flags")
}
_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
if !ok {
return nil, errors.New("could not read fullCollectionName")
}
_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
if !ok {
return nil, errors.New("could not read numberToSkip")
}
_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
if !ok {
return nil, errors.New("could not read numberToReturn")
}

var query bsoncore.Document
query, wm, ok = wiremessage.ReadQueryQuery(wm)
if !ok {
return nil, errors.New("could not read query")
}
return query, nil
}

// GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message.
func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) {
var ok bool
Expand Down
1 change: 1 addition & 0 deletions x/mongo/driver/legacy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ const (
LegacyKillCursors
LegacyListCollections
LegacyListIndexes
LegacyHandshake
)
105 changes: 104 additions & 1 deletion x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/csot"
Expand Down Expand Up @@ -629,7 +630,7 @@ func (op Operation) Execute(ctx context.Context) error {
}

var startedInfo startedInformation
*wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)
*wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)

if err != nil {
return err
Expand Down Expand Up @@ -1103,6 +1104,85 @@ func (op Operation) addBatchArray(dst []byte) []byte {
return dst
}

func (op Operation) createLegacyHandshakeWireMessage(
maxTimeMS uint64,
dst []byte,
desc description.SelectedServer,
) ([]byte, startedInformation, error) {
var info startedInformation
flags := op.secondaryOK(desc)
var wmindex int32
info.requestID = wiremessage.NextRequestID()
wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
dst = wiremessage.AppendQueryFlags(dst, flags)

dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'}

// FullCollectionName
dst = append(dst, op.Database...)
dst = append(dst, dollarCmd[:]...)
dst = append(dst, 0x00)
dst = wiremessage.AppendQueryNumberToSkip(dst, 0)
dst = wiremessage.AppendQueryNumberToReturn(dst, -1)

wrapper := int32(-1)
rp, err := op.createReadPref(desc, true)
if err != nil {
return dst, info, err
}
if len(rp) > 0 {
wrapper, dst = bsoncore.AppendDocumentStart(dst)
dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query")
}
idx, dst := bsoncore.AppendDocumentStart(dst)
dst, err = op.CommandFn(dst, desc)
if err != nil {
return dst, info, err
}

if op.Batches != nil && len(op.Batches.Current) > 0 {
dst = op.addBatchArray(dst)
}

dst, err = op.addReadConcern(dst, desc)
if err != nil {
return dst, info, err
}

dst, err = op.addWriteConcern(dst, desc)
if err != nil {
return dst, info, err
}

dst, err = op.addSession(dst, desc)
if err != nil {
return dst, info, err
}

dst = op.addClusterTime(dst, desc)
dst = op.addServerAPI(dst)
// If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly
// specifies the default behavior of no timeout server-side.
if maxTimeMS > 0 {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
}

dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
// Command monitoring only reports the document inside $query
info.cmd = dst[idx:]

if len(rp) > 0 {
var err error
dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
dst, err = bsoncore.AppendDocumentEnd(dst, wrapper)
if err != nil {
return dst, info, err
}
}

return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}

func (op Operation) createMsgWireMessage(
ctx context.Context,
maxTimeMS uint64,
Expand Down Expand Up @@ -1191,6 +1271,29 @@ func (op Operation) createMsgWireMessage(
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}

// isLegacyHandshake returns True if the operation is the first message of
// the initial handshake and should use a legacy hello.
func isLegacyHandshake(op Operation, desc description.SelectedServer) bool {
isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0

return op.Legacy == LegacyHandshake && isInitialHandshake
}

func (op Operation) createWireMessage(
ctx context.Context,
maxTimeMS uint64,
dst []byte,
desc description.SelectedServer,
conn Connection,
requestID int32,
) ([]byte, startedInformation, error) {
if isLegacyHandshake(op, desc) {
return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc)
}

return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID)
}

// addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document
// has already been added and does not add the final 0 byte.
func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) {
Expand Down
34 changes: 29 additions & 5 deletions x/mongo/driver/operation/hello.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,8 +537,16 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti
return h.createOperation().ExecuteExhaust(ctx, conn)
}

// isLegacyHandshake returns True if server API version is not requested and
// loadBalanced is False. If this is the case, then the drivers MUST use legacy
// hello for the first message of the initial handshake with the OP_QUERY
// protocol
func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool {
return srvAPI == nil && deployment.Kind() != description.LoadBalanced
}

func (h *Hello) createOperation() driver.Operation {
return driver.Operation{
op := driver.Operation{
Clock: h.clock,
CommandFn: h.command,
Database: "admin",
Expand All @@ -549,23 +557,36 @@ func (h *Hello) createOperation() driver.Operation {
},
ServerAPI: h.serverAPI,
}

if isLegacyHandshake(h.serverAPI, h.d) {
op.Legacy = driver.LegacyHandshake
}

return op
}

// GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant
// information about the server. This function implements the driver.Handshaker interface.
func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, c driver.Connection) (driver.HandshakeInformation, error) {
err := driver.Operation{
deployment := driver.SingleConnectionDeployment{C: c}

op := driver.Operation{
Clock: h.clock,
CommandFn: h.handshakeCommand,
Deployment: driver.SingleConnectionDeployment{C: c},
Deployment: deployment,
Database: "admin",
ProcessResponseFn: func(info driver.ResponseInfo) error {
h.res = info.ServerResponse
return nil
},
ServerAPI: h.serverAPI,
}.Execute(ctx)
if err != nil {
}

if isLegacyHandshake(h.serverAPI, deployment) {
op.Legacy = driver.LegacyHandshake
}

if err := op.Execute(ctx); err != nil {
return driver.HandshakeInformation{}, err
}

Expand All @@ -578,6 +599,9 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address,
if serverConnectionID, ok := h.res.Lookup("connectionId").AsInt64OK(); ok {
info.ServerConnectionID = &serverConnectionID
}

var err error

// Cast to bson.Raw to lookup saslSupportedMechs to avoid converting from bsoncore.Value to bson.RawValue for the
// StringSliceFromRawValue call.
if saslSupportedMechs, lookupErr := bson.Raw(h.res).LookupErr("saslSupportedMechs"); lookupErr == nil {
Expand Down
Loading

0 comments on commit 3bef0e4

Please sign in to comment.