diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index d564fab2f8..0dfe740358 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -731,8 +731,8 @@ func TestClient(t *testing.T) { pair := msgPairs[0] assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s", handshake.LegacyHello, pair.CommandName) - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode, + "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String()) // Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire // version is now known to be >= 6. diff --git a/mongo/integration/mtest/sent_message.go b/mongo/integration/mtest/sent_message.go index d36075bf81..6b96e061bc 100644 --- a/mongo/integration/mtest/sent_message.go +++ b/mongo/integration/mtest/sent_message.go @@ -37,6 +37,8 @@ type sentMsgParseFn func([]byte) (*SentMessage, error) func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { switch opcode { + case wiremessage.OpQuery: + return parseOpQuery, true case wiremessage.OpMsg: return parseSentOpMsg, true case wiremessage.OpCompressed: @@ -46,6 +48,69 @@ func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) { } } +func parseOpQuery(wm []byte) (*SentMessage, error) { + var ok bool + + if _, wm, ok = wiremessage.ReadQueryFlags(wm); !ok { + return nil, errors.New("failed to read query flags") + } + if _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm); !ok { + return nil, errors.New("failed to read full collection name") + } + if _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm); !ok { + return nil, errors.New("failed to read number to skip") + } + if _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm); !ok { + return nil, errors.New("failed to read number to return") + } + + query, wm, ok := wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("failed to read query") + } + + // If there is no read preference document, the command document is query. + // Otherwise, query is in the format {$query: , $readPreference: }. + commandDoc := query + var rpDoc bsoncore.Document + + dollarQueryVal, err := query.LookupErr("$query") + if err == nil { + commandDoc = dollarQueryVal.Document() + + rpVal, err := query.LookupErr("$readPreference") + if err != nil { + return nil, fmt.Errorf("query %s contains $query but not $readPreference fields", query) + } + rpDoc = rpVal.Document() + } + + // For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command + // document. Pull these sequences out into an ArrayStyle DocumentSequence. + var docSequence *bsoncore.DocumentSequence + cmdElems, _ := commandDoc.Elements() + for _, elem := range cmdElems { + switch elem.Key() { + case "documents", "updates", "deletes": + docSequence = &bsoncore.DocumentSequence{ + Style: bsoncore.ArrayStyle, + Data: elem.Value().Array(), + } + } + if docSequence != nil { + // There can only be one of these arrays in a well-formed command, so we exit the loop once one is found. + break + } + } + + sm := &SentMessage{ + Command: commandDoc, + ReadPreference: rpDoc, + DocumentSequence: docSequence, + } + return sm, nil +} + func parseSentMessage(wm []byte) (*SentMessage, error) { // Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing. _, requestID, _, opcode, remaining, ok := wiremessage.ReadHeader(wm) diff --git a/x/mongo/driver/auth/speculative_scram_test.go b/x/mongo/driver/auth/speculative_scram_test.go index f2234e227c..a159891adc 100644 --- a/x/mongo/driver/auth/speculative_scram_test.go +++ b/x/mongo/driver/auth/speculative_scram_test.go @@ -93,7 +93,7 @@ func TestSpeculativeSCRAM(t *testing.T) { // Assert that the driver sent hello with the speculative authentication message. assert.Equal(t, len(tc.payloads), len(conn.Written), "expected %d wire messages to be sent, got %d", len(tc.payloads), (conn.Written)) - helloCmd, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + helloCmd, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, helloCmd, handshake.LegacyHello) @@ -177,7 +177,7 @@ func TestSpeculativeSCRAM(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/auth/speculative_x509_test.go b/x/mongo/driver/auth/speculative_x509_test.go index 13fdf2b185..cf46de6ffd 100644 --- a/x/mongo/driver/auth/speculative_x509_test.go +++ b/x/mongo/driver/auth/speculative_x509_test.go @@ -58,7 +58,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) @@ -103,7 +103,7 @@ func TestSpeculativeX509(t *testing.T) { assert.Equal(t, numResponses, len(conn.Written), "expected %d wire messages to be sent, got %d", numResponses, len(conn.Written)) - hello, err := drivertest.GetCommandFromMsgWireMessage(<-conn.Written) + hello, err := drivertest.GetCommandFromQueryWireMessage(<-conn.Written) assert.Nil(t, err, "error parsing hello command: %v", err) assertCommandName(t, hello, handshake.LegacyHello) _, err = hello.LookupErr("speculativeAuthenticate") diff --git a/x/mongo/driver/drivertest/channel_conn.go b/x/mongo/driver/drivertest/channel_conn.go index d2ae8df248..27be4c264d 100644 --- a/x/mongo/driver/drivertest/channel_conn.go +++ b/x/mongo/driver/drivertest/channel_conn.go @@ -99,6 +99,38 @@ func MakeReply(doc bsoncore.Document) []byte { return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))) } +// GetCommandFromQueryWireMessage returns the command sent in an OP_QUERY wire message. +func GetCommandFromQueryWireMessage(wm []byte) (bsoncore.Document, error) { + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + return nil, errors.New("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + return nil, errors.New("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + return nil, errors.New("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + return nil, errors.New("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + return nil, errors.New("could not read numberToReturn") + } + + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + return nil, errors.New("could not read query") + } + return query, nil +} + // GetCommandFromMsgWireMessage returns the command document sent in an OP_MSG wire message. func GetCommandFromMsgWireMessage(wm []byte) (bsoncore.Document, error) { var ok bool diff --git a/x/mongo/driver/legacy.go b/x/mongo/driver/legacy.go index 9f3b8a39ac..c40f1f8091 100644 --- a/x/mongo/driver/legacy.go +++ b/x/mongo/driver/legacy.go @@ -19,4 +19,5 @@ const ( LegacyKillCursors LegacyListCollections LegacyListIndexes + LegacyHandshake ) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 8e52773503..229988e133 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -19,6 +19,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/bsontype" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/event" "go.mongodb.org/mongo-driver/internal/csot" @@ -629,7 +630,7 @@ func (op Operation) Execute(ctx context.Context) error { } var startedInfo startedInformation - *wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) + *wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID) if err != nil { return err @@ -1103,6 +1104,85 @@ func (op Operation) addBatchArray(dst []byte) []byte { return dst } +func (op Operation) createLegacyHandshakeWireMessage( + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, +) ([]byte, startedInformation, error) { + var info startedInformation + flags := op.secondaryOK(desc) + var wmindex int32 + info.requestID = wiremessage.NextRequestID() + wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery) + dst = wiremessage.AppendQueryFlags(dst, flags) + + dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'} + + // FullCollectionName + dst = append(dst, op.Database...) + dst = append(dst, dollarCmd[:]...) + dst = append(dst, 0x00) + dst = wiremessage.AppendQueryNumberToSkip(dst, 0) + dst = wiremessage.AppendQueryNumberToReturn(dst, -1) + + wrapper := int32(-1) + rp, err := op.createReadPref(desc, true) + if err != nil { + return dst, info, err + } + if len(rp) > 0 { + wrapper, dst = bsoncore.AppendDocumentStart(dst) + dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query") + } + idx, dst := bsoncore.AppendDocumentStart(dst) + dst, err = op.CommandFn(dst, desc) + if err != nil { + return dst, info, err + } + + if op.Batches != nil && len(op.Batches.Current) > 0 { + dst = op.addBatchArray(dst) + } + + dst, err = op.addReadConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addWriteConcern(dst, desc) + if err != nil { + return dst, info, err + } + + dst, err = op.addSession(dst, desc) + if err != nil { + return dst, info, err + } + + dst = op.addClusterTime(dst, desc) + dst = op.addServerAPI(dst) + // If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly + // specifies the default behavior of no timeout server-side. + if maxTimeMS > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS)) + } + + dst, _ = bsoncore.AppendDocumentEnd(dst, idx) + // Command monitoring only reports the document inside $query + info.cmd = dst[idx:] + + if len(rp) > 0 { + var err error + dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp) + dst, err = bsoncore.AppendDocumentEnd(dst, wrapper) + if err != nil { + return dst, info, err + } + } + + return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil +} + func (op Operation) createMsgWireMessage( ctx context.Context, maxTimeMS uint64, @@ -1191,6 +1271,29 @@ func (op Operation) createMsgWireMessage( return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil } +// isLegacyHandshake returns True if the operation is the first message of +// the initial handshake and should use a legacy hello. +func isLegacyHandshake(op Operation, desc description.SelectedServer) bool { + isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0 + + return op.Legacy == LegacyHandshake && isInitialHandshake +} + +func (op Operation) createWireMessage( + ctx context.Context, + maxTimeMS uint64, + dst []byte, + desc description.SelectedServer, + conn Connection, + requestID int32, +) ([]byte, startedInformation, error) { + if isLegacyHandshake(op, desc) { + return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc) + } + + return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID) +} + // addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document // has already been added and does not add the final 0 byte. func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) { diff --git a/x/mongo/driver/operation/hello.go b/x/mongo/driver/operation/hello.go index 3cfa2d450a..16d5809130 100644 --- a/x/mongo/driver/operation/hello.go +++ b/x/mongo/driver/operation/hello.go @@ -537,8 +537,16 @@ func (h *Hello) StreamResponse(ctx context.Context, conn driver.StreamerConnecti return h.createOperation().ExecuteExhaust(ctx, conn) } +// isLegacyHandshake returns True if server API version is not requested and +// loadBalanced is False. If this is the case, then the drivers MUST use legacy +// hello for the first message of the initial handshake with the OP_QUERY +// protocol +func isLegacyHandshake(srvAPI *driver.ServerAPIOptions, deployment driver.Deployment) bool { + return srvAPI == nil && deployment.Kind() != description.LoadBalanced +} + func (h *Hello) createOperation() driver.Operation { - return driver.Operation{ + op := driver.Operation{ Clock: h.clock, CommandFn: h.command, Database: "admin", @@ -549,23 +557,36 @@ func (h *Hello) createOperation() driver.Operation { }, ServerAPI: h.serverAPI, } + + if isLegacyHandshake(h.serverAPI, h.d) { + op.Legacy = driver.LegacyHandshake + } + + return op } // GetHandshakeInformation performs the MongoDB handshake for the provided connection and returns the relevant // information about the server. This function implements the driver.Handshaker interface. func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, c driver.Connection) (driver.HandshakeInformation, error) { - err := driver.Operation{ + deployment := driver.SingleConnectionDeployment{C: c} + + op := driver.Operation{ Clock: h.clock, CommandFn: h.handshakeCommand, - Deployment: driver.SingleConnectionDeployment{C: c}, + Deployment: deployment, Database: "admin", ProcessResponseFn: func(info driver.ResponseInfo) error { h.res = info.ServerResponse return nil }, ServerAPI: h.serverAPI, - }.Execute(ctx) - if err != nil { + } + + if isLegacyHandshake(h.serverAPI, deployment) { + op.Legacy = driver.LegacyHandshake + } + + if err := op.Execute(ctx); err != nil { return driver.HandshakeInformation{}, err } @@ -578,6 +599,9 @@ func (h *Hello) GetHandshakeInformation(ctx context.Context, _ address.Address, if serverConnectionID, ok := h.res.Lookup("connectionId").AsInt64OK(); ok { info.ServerConnectionID = &serverConnectionID } + + var err error + // Cast to bson.Raw to lookup saslSupportedMechs to avoid converting from bsoncore.Value to bson.RawValue for the // StringSliceFromRawValue call. if saslSupportedMechs, lookupErr := bson.Raw(h.res).LookupErr("saslSupportedMechs"); lookupErr == nil { diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go index a2abd1fb1f..ba92b6dd94 100644 --- a/x/mongo/driver/topology/server_test.go +++ b/x/mongo/driver/topology/server_test.go @@ -31,9 +31,11 @@ import ( "go.mongodb.org/mongo-driver/internal/require" "go.mongodb.org/mongo-driver/mongo/address" "go.mongodb.org/mongo-driver/mongo/description" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/x/mongo/driver" "go.mongodb.org/mongo-driver/x/mongo/driver/auth" "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage" ) type channelNetConnDialer struct{} @@ -1207,12 +1209,41 @@ func TestServer_ProcessError(t *testing.T) { func includesClientMetadata(t *testing.T, wm []byte) bool { t.Helper() - doc, err := drivertest.GetCommandFromMsgWireMessage(wm) - assert.NoError(t, err) + var ok bool + _, _, _, _, wm, ok = wiremessage.ReadHeader(wm) + if !ok { + t.Fatal("could not read header") + } + _, wm, ok = wiremessage.ReadQueryFlags(wm) + if !ok { + t.Fatal("could not read flags") + } + _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm) + if !ok { + t.Fatal("could not read fullCollectionName") + } + _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm) + if !ok { + t.Fatal("could not read numberToSkip") + } + _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm) + if !ok { + t.Fatal("could not read numberToReturn") + } + var query bsoncore.Document + query, wm, ok = wiremessage.ReadQueryQuery(wm) + if !ok { + t.Fatal("could not read query") + } - _, err = doc.LookupErr("client") + if _, err := query.LookupErr("client"); err == nil { + return true + } + if _, err := query.LookupErr("$query", "client"); err == nil { + return true + } - return err == nil + return false } // processErrorTestConn is a driver.Connection implementation used by tests