Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GODRIVER-2935 Use OP_QUERY in connection handshakes #1377

Merged
merged 7 commits into from
Sep 7, 2023
65 changes: 65 additions & 0 deletions mongo/integration/mtest/sent_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ type sentMsgParseFn func([]byte) (*SentMessage, error)

func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) {
switch opcode {
case wiremessage.OpQuery:
return parseOpQuery, true
case wiremessage.OpMsg:
return parseSentOpMsg, true
case wiremessage.OpCompressed:
Expand All @@ -46,6 +48,69 @@ func getSentMessageParser(opcode wiremessage.OpCode) (sentMsgParseFn, bool) {
}
}

func parseOpQuery(wm []byte) (*SentMessage, error) {
Copy link
Collaborator Author

@prestonvasquez prestonvasquez Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a direct copy and paste of the code removed in this commit: 4449617

var ok bool

if _, wm, ok = wiremessage.ReadQueryFlags(wm); !ok {
return nil, errors.New("failed to read query flags")
}
if _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm); !ok {
return nil, errors.New("failed to read full collection name")
}
if _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm); !ok {
return nil, errors.New("failed to read number to skip")
}
if _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm); !ok {
return nil, errors.New("failed to read number to return")
}

query, wm, ok := wiremessage.ReadQueryQuery(wm)
if !ok {
return nil, errors.New("failed to read query")
}

// If there is no read preference document, the command document is query.
// Otherwise, query is in the format {$query: <command document>, $readPreference: <read preference document>}.
commandDoc := query
var rpDoc bsoncore.Document

dollarQueryVal, err := query.LookupErr("$query")
if err == nil {
commandDoc = dollarQueryVal.Document()

rpVal, err := query.LookupErr("$readPreference")
if err != nil {
return nil, fmt.Errorf("query %s contains $query but not $readPreference fields", query)
}
rpDoc = rpVal.Document()
}

// For OP_QUERY, inserts, updates, and deletes are sent as a BSON array of documents inside the main command
// document. Pull these sequences out into an ArrayStyle DocumentSequence.
var docSequence *bsoncore.DocumentSequence
cmdElems, _ := commandDoc.Elements()
for _, elem := range cmdElems {
switch elem.Key() {
case "documents", "updates", "deletes":
docSequence = &bsoncore.DocumentSequence{
Style: bsoncore.ArrayStyle,
Data: elem.Value().Array(),
}
}
if docSequence != nil {
// There can only be one of these arrays in a well-formed command, so we exit the loop once one is found.
break
}
}

sm := &SentMessage{
Command: commandDoc,
ReadPreference: rpDoc,
DocumentSequence: docSequence,
}
return sm, nil
}

func parseSentMessage(wm []byte) (*SentMessage, error) {
// Re-assign the wire message to "remaining" so "wm" continues to point to the entire message after parsing.
_, requestID, _, opcode, remaining, ok := wiremessage.ReadHeader(wm)
Expand Down
111 changes: 110 additions & 1 deletion x/mongo/driver/operation.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/bsontype"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/event"
"go.mongodb.org/mongo-driver/internal/csot"
Expand Down Expand Up @@ -629,7 +630,7 @@ func (op Operation) Execute(ctx context.Context) error {
}

var startedInfo startedInformation
*wm, startedInfo, err = op.createMsgWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn)
*wm, startedInfo, err = op.createWireMessage(ctx, (*wm)[:0], desc, maxTimeMS, conn)

if err != nil {
return err
Expand Down Expand Up @@ -1103,6 +1104,85 @@ func (op Operation) addBatchArray(dst []byte) []byte {
return dst
}

func (op Operation) createLegacyHandshakeWireMessage(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a direct copy and paste of the code removed in this commit: 4449617

maxTimeMS uint64,
dst []byte,
desc description.SelectedServer,
) ([]byte, startedInformation, error) {
var info startedInformation
flags := op.secondaryOK(desc)
var wmindex int32
info.requestID = wiremessage.NextRequestID()
wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
dst = wiremessage.AppendQueryFlags(dst, flags)

dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'}

// FullCollectionName
dst = append(dst, op.Database...)
dst = append(dst, dollarCmd[:]...)
dst = append(dst, 0x00)
dst = wiremessage.AppendQueryNumberToSkip(dst, 0)
dst = wiremessage.AppendQueryNumberToReturn(dst, -1)

wrapper := int32(-1)
rp, err := op.createReadPref(desc, true)
if err != nil {
return dst, info, err
}
if len(rp) > 0 {
wrapper, dst = bsoncore.AppendDocumentStart(dst)
dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query")
}
idx, dst := bsoncore.AppendDocumentStart(dst)
dst, err = op.CommandFn(dst, desc)
if err != nil {
return dst, info, err
}

if op.Batches != nil && len(op.Batches.Current) > 0 {
dst = op.addBatchArray(dst)
}

dst, err = op.addReadConcern(dst, desc)
if err != nil {
return dst, info, err
}

dst, err = op.addWriteConcern(dst, desc)
if err != nil {
return dst, info, err
}

dst, err = op.addSession(dst, desc)
if err != nil {
return dst, info, err
}

dst = op.addClusterTime(dst, desc)
dst = op.addServerAPI(dst)
// If maxTimeMS is greater than 0 append it to wire message. A maxTimeMS value of 0 only explicitly
// specifies the default behavior of no timeout server-side.
if maxTimeMS > 0 {
dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
}

dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
// Command monitoring only reports the document inside $query
info.cmd = dst[idx:]

if len(rp) > 0 {
var err error
dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
dst, err = bsoncore.AppendDocumentEnd(dst, wrapper)
if err != nil {
return dst, info, err
}
}

return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}

func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64, dst []byte, desc description.SelectedServer,
conn Connection,
) ([]byte, startedInformation, error) {
Expand Down Expand Up @@ -1186,6 +1266,33 @@ func (op Operation) createMsgWireMessage(ctx context.Context, maxTimeMS uint64,
return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}

// isLegacyHandshake returns "true" if the operation is the first message of
// the initial handshake and should use a legacy hello. The requirement for
// using a legacy hello as defined by the specifications is as follows:
//
// > If server API version is not requested and loadBalanced: False, drivers
// > MUST use legacy hello for the first message of the initial handshake with
// > the OP_QUERY protocol
func isLegacyHandshake(op Operation, desc description.SelectedServer) bool {
isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0

return desc.Kind != description.LoadBalanced && op.ServerAPI == nil && isInitialHandshake
}

func (op Operation) createWireMessage(
ctx context.Context,
dst []byte,
desc description.SelectedServer,
maxTimeMS uint64,
conn Connection,
) ([]byte, startedInformation, error) {
if isLegacyHandshake(op, desc) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See here for the exact logic of this:

If server API version is not requested and loadBalanced: False, drivers MUST use legacy hello for the first message of the initial handshake with the OP_QUERY protocol (before switching to OP_MSG if the maxWireVersion indicates compatibility), and include helloOk:true in the handshake request.

return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc)
}

return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn)
}

// addCommandFields adds the fields for a command to the wire message in dst. This assumes that the start of the document
// has already been added and does not add the final 0 byte.
func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) {
Expand Down Expand Up @@ -1830,6 +1937,8 @@ func (op Operation) publishFinishedEvent(ctx context.Context, info finishedInfor
logger.KeyFailure, formattedReply)...)
}

//fmt.Println("->", redactFinishedInformationResponse(info))

// If the finished event cannot be published, return early.
if !op.canPublishFinishedEvent(info) {
return
Expand Down
39 changes: 35 additions & 4 deletions x/mongo/driver/topology/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ import (
"go.mongodb.org/mongo-driver/internal/require"
"go.mongodb.org/mongo-driver/mongo/address"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
)

type channelNetConnDialer struct{}
Expand Down Expand Up @@ -1207,12 +1209,41 @@ func TestServer_ProcessError(t *testing.T) {
func includesClientMetadata(t *testing.T, wm []byte) bool {
t.Helper()

doc, err := drivertest.GetCommandFromMsgWireMessage(wm)
assert.NoError(t, err)
var ok bool
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a direct copy and paste of the code removed in this commit: 4449617

_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
if !ok {
t.Fatal("could not read header")
}
_, wm, ok = wiremessage.ReadQueryFlags(wm)
if !ok {
t.Fatal("could not read flags")
}
_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
if !ok {
t.Fatal("could not read fullCollectionName")
}
_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
if !ok {
t.Fatal("could not read numberToSkip")
}
_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
if !ok {
t.Fatal("could not read numberToReturn")
}
var query bsoncore.Document
query, wm, ok = wiremessage.ReadQueryQuery(wm)
if !ok {
t.Fatal("could not read query")
}

_, err = doc.LookupErr("client")
if _, err := query.LookupErr("client"); err == nil {
return true
}
if _, err := query.LookupErr("$query", "client"); err == nil {
return true
}

return err == nil
return false
}

// processErrorTestConn is a driver.Connection implementation used by tests
Expand Down
Loading