diff --git a/mongo/integration/unified_spec_test.go b/mongo/integration/unified_spec_test.go index c38548ec70..4da42e6a68 100644 --- a/mongo/integration/unified_spec_test.go +++ b/mongo/integration/unified_spec_test.go @@ -458,46 +458,64 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, var fp mtest.FailPoint if err := bson.Unmarshal(fpDoc.Document(), &fp); err != nil { - return fmt.Errorf("Unmarshal error: %v", err) + return fmt.Errorf("Unmarshal error: %w", err) } + if clientSession == nil { + return errors.New("expected valid session, got nil") + } targetHost := clientSession.PinnedServer.Addr.String() opts := options.Client().ApplyURI(mtest.ClusterURI()).SetHosts([]string{targetHost}) integtest.AddTestServerAPIVersion(opts) client, err := mongo.Connect(context.Background(), opts) if err != nil { - return fmt.Errorf("Connect error for targeted client: %v", err) + return fmt.Errorf("Connect error for targeted client: %w", err) } defer func() { _ = client.Disconnect(context.Background()) }() if err = client.Database("admin").RunCommand(context.Background(), fp).Err(); err != nil { - return fmt.Errorf("error setting targeted fail point: %v", err) + return fmt.Errorf("error setting targeted fail point: %w", err) } mt.TrackFailPoint(fp.ConfigureFailPoint) case "configureFailPoint": fp, err := op.Arguments.LookupErr("failPoint") - assert.Nil(mt, err, "failPoint not found in arguments") + if err != nil { + return fmt.Errorf("unable to find 'failPoint' in arguments: %w", err) + } mt.SetFailPointFromDocument(fp.Document()) case "assertSessionTransactionState": stateVal, err := op.Arguments.LookupErr("state") - assert.Nil(mt, err, "state not found in arguments") + if err != nil { + return fmt.Errorf("unable to find 'state' in arguments: %w", err) + } expectedState, ok := stateVal.StringValueOK() - assert.True(mt, ok, "state argument is not a string") + if !ok { + return errors.New("expected 'state' argument to be string") + } - assert.NotNil(mt, clientSession, "expected valid session, got nil") + if clientSession == nil { + return errors.New("expected valid session, got nil") + } actualState := clientSession.TransactionState.String() // actualState should match expectedState, but "in progress" is the same as // "in_progress". stateMatch := actualState == expectedState || actualState == "in progress" && expectedState == "in_progress" - assert.True(mt, stateMatch, "expected transaction state %v, got %v", - expectedState, actualState) + if !stateMatch { + return fmt.Errorf("expected transaction state %v, got %v", expectedState, actualState) + } case "assertSessionPinned": + if clientSession == nil { + return errors.New("expected valid session, got nil") + } if clientSession.PinnedServer == nil { return errors.New("expected pinned server, got nil") } case "assertSessionUnpinned": + if clientSession == nil { + return errors.New("expected valid session, got nil") + } // We don't use a combined helper for assertSessionPinned and assertSessionUnpinned because the unpinned // case provides the pinned server address in the error msg for debugging. if clientSession.PinnedServer != nil { @@ -540,7 +558,7 @@ func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, case "waitForThread": waitForThread(mt, testCase, op) default: - mt.Fatalf("unrecognized testRunner operation %v", op.Name) + return fmt.Errorf("unrecognized testRunner operation %v", op.Name) } return nil @@ -567,7 +585,7 @@ func indexExists(dbName, collName, indexName string) (bool, error) { iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes() cursor, err := iv.List(context.Background()) if err != nil { - return false, fmt.Errorf("IndexView.List error: %v", err) + return false, fmt.Errorf("IndexView.List error: %w", err) } defer cursor.Close(context.Background()) @@ -602,7 +620,7 @@ func collectionExists(dbName, collName string) (bool, error) { // Use global client because listCollections cannot be executed inside a transaction. collections, err := mtest.GlobalClient().Database(dbName).ListCollectionNames(context.Background(), filter) if err != nil { - return false, fmt.Errorf("ListCollectionNames error: %v", err) + return false, fmt.Errorf("ListCollectionNames error: %w", err) } return len(collections) > 0, nil @@ -632,9 +650,8 @@ func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) err case "withTransaction": return executeWithTransaction(mt, sess, op.Arguments) default: - mt.Fatalf("unrecognized session operation: %v", op.Name) + return fmt.Errorf("unrecognized session operation: %v", op.Name) } - return nil } func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) error {