From 5202542272753da6904925998a33b83dca56ed2e Mon Sep 17 00:00:00 2001 From: Qingyang Hu Date: Fri, 22 Sep 2023 17:26:40 -0400 Subject: [PATCH] Revert mongo.Connect() changes --- benchmark/operation_test.go | 21 ++++++++++++++--- benchmark/single.go | 5 +++- mongo/client.go | 29 ++++++++++++++---------- mongo/client_test.go | 18 +++++++-------- mongo/database_test.go | 2 +- mongo/integration/client_options_test.go | 4 +++- mongo/integration/client_test.go | 2 +- mongo/integration/mtest/mongotest.go | 9 +++++--- mongo/mongocryptd.go | 4 ++-- x/mongo/driver/topology/topology_test.go | 2 +- 10 files changed, 62 insertions(+), 34 deletions(-) diff --git a/benchmark/operation_test.go b/benchmark/operation_test.go index 04e618e7bc..80f20ddc75 100644 --- a/benchmark/operation_test.go +++ b/benchmark/operation_test.go @@ -32,10 +32,15 @@ func BenchmarkClientWrite(b *testing.B) { } for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { - client, err := mongo.Connect(context.Background(), bm.opt) + client, err := mongo.NewClient(bm.opt) if err != nil { b.Fatalf("error creating client: %v", err) } + ctx := context.Background() + err = client.Connect(ctx) + if err != nil { + b.Fatalf("error connecting: %v", err) + } defer client.Disconnect(context.Background()) coll := client.Database("test").Collection("test") _, err = coll.DeleteMany(context.Background(), bson.D{}) @@ -71,10 +76,15 @@ func BenchmarkClientBulkWrite(b *testing.B) { } for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { - client, err := mongo.Connect(context.Background(), bm.opt) + client, err := mongo.NewClient(bm.opt) if err != nil { b.Fatalf("error creating client: %v", err) } + ctx := context.Background() + err = client.Connect(ctx) + if err != nil { + b.Fatalf("error connecting: %v", err) + } defer client.Disconnect(context.Background()) coll := client.Database("test").Collection("test") _, err = coll.DeleteMany(context.Background(), bson.D{}) @@ -115,10 +125,15 @@ func BenchmarkClientRead(b *testing.B) { } for _, bm := range benchmarks { b.Run(bm.name, func(b *testing.B) { - client, err := mongo.Connect(context.Background(), bm.opt) + client, err := mongo.NewClient(bm.opt) if err != nil { b.Fatalf("error creating client: %v", err) } + ctx := context.Background() + err = client.Connect(ctx) + if err != nil { + b.Fatalf("error connecting: %v", err) + } defer client.Disconnect(context.Background()) coll := client.Database("test").Collection("test") _, err = coll.DeleteMany(context.Background(), bson.D{}) diff --git a/benchmark/single.go b/benchmark/single.go index 333a8f66be..b85b46f34f 100644 --- a/benchmark/single.go +++ b/benchmark/single.go @@ -29,10 +29,13 @@ func getClientDB(ctx context.Context) (*mongo.Database, error) { if err != nil { return nil, err } - client, err := mongo.Connect(ctx, options.Client().ApplyURI(cs.String())) + client, err := mongo.NewClient(options.Client().ApplyURI(cs.String())) if err != nil { return nil, err } + if err = client.Connect(ctx); err != nil { + return nil, err + } db := client.Database(integtest.GetDBName(cs)) return db, nil diff --git a/mongo/client.go b/mongo/client.go index 520250e43e..5929274831 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -81,7 +81,8 @@ type Client struct { encryptedFieldsMap map[string]interface{} } -// Connect creates a new Client and then initializes it using the Connect method. +// Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling +// NewClient followed by Client.Connect. // // When creating an options.ClientOptions, the order the methods are called matters. Later Set* // methods will overwrite the values from previous Set* method invocations. This includes the @@ -103,18 +104,18 @@ type Client struct { // The Client.Ping method can be used to verify that the deployment is successfully connected and the // Client was correctly configured. func Connect(ctx context.Context, opts ...*options.ClientOptions) (*Client, error) { - c, err := newClient(opts...) + c, err := NewClient(opts...) if err != nil { return nil, err } - err = c.connect(ctx) + err = c.Connect(ctx) if err != nil { return nil, err } return c, nil } -// newClient creates a new client to connect to a deployment specified by the uri. +// NewClient creates a new client to connect to a deployment specified by the uri. // // When creating an options.ClientOptions, the order the methods are called matters. Later Set* // methods will overwrite the values from previous Set* method invocations. This includes the @@ -127,7 +128,9 @@ func Connect(ctx context.Context, opts ...*options.ClientOptions) (*Client, erro // option fields of previous options, there is no partial overwriting. For example, if Username is // set in the Auth field for the first option, and Password is set for the second but with no // Username, after the merge the Username field will be empty. -func newClient(opts ...*options.ClientOptions) (*Client, error) { +// +// Deprecated: Use [Connect] instead. +func NewClient(opts ...*options.ClientOptions) (*Client, error) { clientOpt := options.MergeClientOptions(opts...) id, err := uuid.New() @@ -232,12 +235,14 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { return client, nil } -// connect initializes the Client by starting background monitoring goroutines. +// Connect initializes the Client by starting background monitoring goroutines. // If the Client was created using the NewClient function, this method must be called before a Client can be used. // // Connect starts background goroutines to monitor the state of the deployment and does not do any I/O in the main // goroutine. The Client.Ping method can be used to verify that the connection was created successfully. -func (c *Client) connect(ctx context.Context) error { +// +// Deprecated: Use [mongo.Connect] instead. +func (c *Client) Connect(ctx context.Context) error { if connector, ok := c.deployment.(driver.Connector); ok { err := connector.Connect() if err != nil { @@ -252,19 +257,19 @@ func (c *Client) connect(ctx context.Context) error { } if c.internalClientFLE != nil { - if err := c.internalClientFLE.connect(ctx); err != nil { + if err := c.internalClientFLE.Connect(ctx); err != nil { return err } } if c.keyVaultClientFLE != nil && c.keyVaultClientFLE != c.internalClientFLE && c.keyVaultClientFLE != c { - if err := c.keyVaultClientFLE.connect(ctx); err != nil { + if err := c.keyVaultClientFLE.Connect(ctx); err != nil { return err } } if c.metadataClientFLE != nil && c.metadataClientFLE != c.internalClientFLE && c.metadataClientFLE != c { - if err := c.metadataClientFLE.connect(ctx); err != nil { + if err := c.metadataClientFLE.Connect(ctx); err != nil { return err } } @@ -484,7 +489,7 @@ func (c *Client) getOrCreateInternalClient(clientOpts *options.ClientOptions) (* internalClientOpts.AutoEncryptionOptions = nil internalClientOpts.SetMinPoolSize(0) var err error - c.internalClientFLE, err = newClient(internalClientOpts) + c.internalClientFLE, err = NewClient(internalClientOpts) return c.internalClientFLE, err } @@ -494,7 +499,7 @@ func (c *Client) configureKeyVaultClientFLE(clientOpts *options.ClientOptions) e aeOpts := clientOpts.AutoEncryptionOptions switch { case aeOpts.KeyVaultClientOptions != nil: - c.keyVaultClientFLE, err = newClient(aeOpts.KeyVaultClientOptions) + c.keyVaultClientFLE, err = NewClient(aeOpts.KeyVaultClientOptions) case clientOpts.MaxPoolSize != nil && *clientOpts.MaxPoolSize == 0: c.keyVaultClientFLE = c default: diff --git a/mongo/client_test.go b/mongo/client_test.go index abda6002a5..e1e8f322a8 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -35,7 +35,7 @@ func setupClient(opts ...*options.ClientOptions) *Client { integtest.AddTestServerAPIVersion(clientOpts) opts = append(opts, clientOpts) } - client, _ := newClient(opts...) + client, _ := NewClient(opts...) return client } @@ -183,7 +183,7 @@ func TestClient(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - _, err := newClient(tc.opts) + _, err := NewClient(tc.opts) assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) }) } @@ -227,7 +227,7 @@ func TestClient(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - _, err := newClient(tc.opts) + _, err := NewClient(tc.opts) assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err) }) } @@ -249,7 +249,7 @@ func TestClient(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - client, err := newClient(tc.opts) + client, err := NewClient(tc.opts) if tc.expectErr { assert.NotNil(t, err, "expected error, got nil") return @@ -277,7 +277,7 @@ func TestClient(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - client, err := newClient(tc.opts) + client, err := NewClient(tc.opts) if tc.expectErr { assert.NotNil(t, err, "expected error, got nil") return @@ -412,7 +412,7 @@ func TestClient(t *testing.T) { t.Run("success with all options", func(t *testing.T) { serverAPIOptions := getServerAPIOptions() - client, err := newClient(options.Client().SetServerAPIOptions(serverAPIOptions)) + client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions)) assert.Nil(t, err, "unexpected error from NewClient: %v", err) convertedAPIOptions := topology.ConvertToDriverAPIOptions(serverAPIOptions) assert.Equal(t, convertedAPIOptions, client.serverAPI, @@ -420,14 +420,14 @@ func TestClient(t *testing.T) { }) t.Run("failure with unsupported version", func(t *testing.T) { serverAPIOptions := options.ServerAPI("badVersion") - _, err := newClient(options.Client().SetServerAPIOptions(serverAPIOptions)) + _, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions)) assert.NotNil(t, err, "expected error from NewClient, got nil") errmsg := `api version "badVersion" not supported; this driver version only supports API version "1"` assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error()) }) t.Run("cannot modify options after client creation", func(t *testing.T) { serverAPIOptions := getServerAPIOptions() - client, err := newClient(options.Client().SetServerAPIOptions(serverAPIOptions)) + client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions)) assert.Nil(t, err, "unexpected error from NewClient: %v", err) expectedServerAPIOptions := getServerAPIOptions() @@ -476,7 +476,7 @@ func TestClient(t *testing.T) { extraOptions["__cryptSharedLibDisabledForTestOnly"] = true } - _, err := newClient(options.Client(). + _, err := NewClient(options.Client(). SetAutoEncryptionOptions(options.AutoEncryption(). SetKmsProviders(map[string]map[string]interface{}{ "local": {"key": make([]byte, 96)}, diff --git a/mongo/database_test.go b/mongo/database_test.go index 276080be79..46e1fc3f19 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -97,7 +97,7 @@ func TestDatabase(t *testing.T) { }) t.Run("TransientTransactionError label", func(t *testing.T) { client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second)) - err := client.connect(bgCtx) + err := client.Connect(bgCtx) defer client.Disconnect(bgCtx) assert.Nil(t, err, "expected nil, got %v", err) diff --git a/mongo/integration/client_options_test.go b/mongo/integration/client_options_test.go index 43703d5e33..0fb068bc5e 100644 --- a/mongo/integration/client_options_test.go +++ b/mongo/integration/client_options_test.go @@ -24,7 +24,9 @@ func TestClientOptions_CustomDialer(t *testing.T) { cs := integtest.ConnString(t) opts := options.Client().ApplyURI(cs.String()).SetDialer(td) integtest.AddTestServerAPIVersion(opts) - client, err := mongo.Connect(context.Background(), opts) + client, err := mongo.NewClient(opts) + require.NoError(t, err) + err = client.Connect(context.Background()) require.NoError(t, err) _, err = client.ListDatabases(context.Background(), bson.D{}) require.NoError(t, err) diff --git a/mongo/integration/client_test.go b/mongo/integration/client_test.go index d76213e908..038ed25d72 100644 --- a/mongo/integration/client_test.go +++ b/mongo/integration/client_test.go @@ -341,7 +341,7 @@ func TestClient(t *testing.T) { }) mt.RunOpts("watch", noClientOpts, func(mt *mtest.T) { mt.Run("disconnected", func(mt *mtest.T) { - c, err := mongo.Connect(context.Background(), options.Client().ApplyURI(mtest.ClusterURI())) + c, err := mongo.NewClient(options.Client().ApplyURI(mtest.ClusterURI())) assert.Nil(mt, err, "NewClient error: %v", err) _, err = c.Watch(context.Background(), mongo.Pipeline{}) assert.Equal(mt, mongo.ErrClientDisconnected, err, "expected error %v, got %v", mongo.ErrClientDisconnected, err) diff --git a/mongo/integration/mtest/mongotest.go b/mongo/integration/mtest/mongotest.go index 78f21b0c12..d5235d228e 100644 --- a/mongo/integration/mtest/mongotest.go +++ b/mongo/integration/mtest/mongotest.go @@ -689,13 +689,13 @@ func (t *T) createTestClient() { // pin to first mongos pinnedHostList := []string{testContext.connString.Hosts[0]} uriOpts := options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList) - t.Client, err = mongo.Connect(context.Background(), uriOpts, clientOpts) + t.Client, err = mongo.NewClient(uriOpts, clientOpts) case Mock: // clear pool monitor to avoid configuration error clientOpts.PoolMonitor = nil t.mockDeployment = newMockDeployment() clientOpts.Deployment = t.mockDeployment - t.Client, err = mongo.Connect(context.Background(), clientOpts) + t.Client, err = mongo.NewClient(clientOpts) case Proxy: t.proxyDialer = newProxyDialer() clientOpts.SetDialer(t.proxyDialer) @@ -713,11 +713,14 @@ func (t *T) createTestClient() { } // Pass in uriOpts first so clientOpts wins if there are any conflicting settings. - t.Client, err = mongo.Connect(context.Background(), uriOpts, clientOpts) + t.Client, err = mongo.NewClient(uriOpts, clientOpts) } if err != nil { t.Fatalf("error creating client: %v", err) } + if err := t.Client.Connect(context.Background()); err != nil { + t.Fatalf("error connecting client: %v", err) + } } func (t *T) createTestCollection() { diff --git a/mongo/mongocryptd.go b/mongo/mongocryptd.go index bd86c3ced5..41aebc76c1 100644 --- a/mongo/mongocryptd.go +++ b/mongo/mongocryptd.go @@ -74,7 +74,7 @@ func newMongocryptdClient(opts *options.AutoEncryptionOptions) (*mongocryptdClie } // create client - client, err := Connect(context.Background(), options.Client().ApplyURI(uri).SetServerSelectionTimeout(defaultServerSelectionTimeout)) + client, err := NewClient(options.Client().ApplyURI(uri).SetServerSelectionTimeout(defaultServerSelectionTimeout)) if err != nil { return nil, err } @@ -114,7 +114,7 @@ func (mc *mongocryptdClient) markCommand(ctx context.Context, dbName string, cmd // connect connects the underlying Client instance. This must be called before performing any mark operations. func (mc *mongocryptdClient) connect(ctx context.Context) error { - return mc.client.connect(ctx) + return mc.client.Connect(ctx) } // disconnect disconnects the underlying Client instance. This should be called after all operations have completed. diff --git a/x/mongo/driver/topology/topology_test.go b/x/mongo/driver/topology/topology_test.go index a99728f9c8..909e2debf1 100644 --- a/x/mongo/driver/topology/topology_test.go +++ b/x/mongo/driver/topology/topology_test.go @@ -607,7 +607,7 @@ func TestSessionTimeout(t *testing.T) { currDesc := topo.desc.Load().(description.Topology) require.Nil(t, currDesc.SessionTimeoutMinutesPtr, - "session timeout minutes mismatch. got: %d. expected: nil", *currDesc.SessionTimeoutMinutesPtr) + "session timeout minutes mismatch. expected: nil") }) }