Skip to content

Commit

Permalink
Revert mongo.Connect() changes
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Sep 22, 2023
1 parent be47df2 commit 5202542
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 34 deletions.
21 changes: 18 additions & 3 deletions benchmark/operation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,15 @@ func BenchmarkClientWrite(b *testing.B) {
}
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
client, err := mongo.Connect(context.Background(), bm.opt)
client, err := mongo.NewClient(bm.opt)
if err != nil {
b.Fatalf("error creating client: %v", err)
}
ctx := context.Background()
err = client.Connect(ctx)
if err != nil {
b.Fatalf("error connecting: %v", err)
}
defer client.Disconnect(context.Background())
coll := client.Database("test").Collection("test")
_, err = coll.DeleteMany(context.Background(), bson.D{})
Expand Down Expand Up @@ -71,10 +76,15 @@ func BenchmarkClientBulkWrite(b *testing.B) {
}
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
client, err := mongo.Connect(context.Background(), bm.opt)
client, err := mongo.NewClient(bm.opt)
if err != nil {
b.Fatalf("error creating client: %v", err)
}
ctx := context.Background()
err = client.Connect(ctx)
if err != nil {
b.Fatalf("error connecting: %v", err)
}
defer client.Disconnect(context.Background())
coll := client.Database("test").Collection("test")
_, err = coll.DeleteMany(context.Background(), bson.D{})
Expand Down Expand Up @@ -115,10 +125,15 @@ func BenchmarkClientRead(b *testing.B) {
}
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
client, err := mongo.Connect(context.Background(), bm.opt)
client, err := mongo.NewClient(bm.opt)
if err != nil {
b.Fatalf("error creating client: %v", err)
}
ctx := context.Background()
err = client.Connect(ctx)
if err != nil {
b.Fatalf("error connecting: %v", err)
}
defer client.Disconnect(context.Background())
coll := client.Database("test").Collection("test")
_, err = coll.DeleteMany(context.Background(), bson.D{})
Expand Down
5 changes: 4 additions & 1 deletion benchmark/single.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ func getClientDB(ctx context.Context) (*mongo.Database, error) {
if err != nil {
return nil, err
}
client, err := mongo.Connect(ctx, options.Client().ApplyURI(cs.String()))
client, err := mongo.NewClient(options.Client().ApplyURI(cs.String()))
if err != nil {
return nil, err
}
if err = client.Connect(ctx); err != nil {
return nil, err
}

db := client.Database(integtest.GetDBName(cs))
return db, nil
Expand Down
29 changes: 17 additions & 12 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ type Client struct {
encryptedFieldsMap map[string]interface{}
}

// Connect creates a new Client and then initializes it using the Connect method.
// Connect creates a new Client and then initializes it using the Connect method. This is equivalent to calling
// NewClient followed by Client.Connect.
//
// When creating an options.ClientOptions, the order the methods are called matters. Later Set*
// methods will overwrite the values from previous Set* method invocations. This includes the
Expand All @@ -103,18 +104,18 @@ type Client struct {
// The Client.Ping method can be used to verify that the deployment is successfully connected and the
// Client was correctly configured.
func Connect(ctx context.Context, opts ...*options.ClientOptions) (*Client, error) {
c, err := newClient(opts...)
c, err := NewClient(opts...)
if err != nil {
return nil, err
}
err = c.connect(ctx)
err = c.Connect(ctx)
if err != nil {
return nil, err
}
return c, nil
}

// newClient creates a new client to connect to a deployment specified by the uri.
// NewClient creates a new client to connect to a deployment specified by the uri.
//
// When creating an options.ClientOptions, the order the methods are called matters. Later Set*
// methods will overwrite the values from previous Set* method invocations. This includes the
Expand All @@ -127,7 +128,9 @@ func Connect(ctx context.Context, opts ...*options.ClientOptions) (*Client, erro
// option fields of previous options, there is no partial overwriting. For example, if Username is
// set in the Auth field for the first option, and Password is set for the second but with no
// Username, after the merge the Username field will be empty.
func newClient(opts ...*options.ClientOptions) (*Client, error) {
//
// Deprecated: Use [Connect] instead.
func NewClient(opts ...*options.ClientOptions) (*Client, error) {
clientOpt := options.MergeClientOptions(opts...)

id, err := uuid.New()
Expand Down Expand Up @@ -232,12 +235,14 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) {
return client, nil
}

// connect initializes the Client by starting background monitoring goroutines.
// Connect initializes the Client by starting background monitoring goroutines.
// If the Client was created using the NewClient function, this method must be called before a Client can be used.
//
// Connect starts background goroutines to monitor the state of the deployment and does not do any I/O in the main
// goroutine. The Client.Ping method can be used to verify that the connection was created successfully.
func (c *Client) connect(ctx context.Context) error {
//
// Deprecated: Use [mongo.Connect] instead.
func (c *Client) Connect(ctx context.Context) error {
if connector, ok := c.deployment.(driver.Connector); ok {
err := connector.Connect()
if err != nil {
Expand All @@ -252,19 +257,19 @@ func (c *Client) connect(ctx context.Context) error {
}

if c.internalClientFLE != nil {
if err := c.internalClientFLE.connect(ctx); err != nil {
if err := c.internalClientFLE.Connect(ctx); err != nil {
return err
}
}

if c.keyVaultClientFLE != nil && c.keyVaultClientFLE != c.internalClientFLE && c.keyVaultClientFLE != c {
if err := c.keyVaultClientFLE.connect(ctx); err != nil {
if err := c.keyVaultClientFLE.Connect(ctx); err != nil {
return err
}
}

if c.metadataClientFLE != nil && c.metadataClientFLE != c.internalClientFLE && c.metadataClientFLE != c {
if err := c.metadataClientFLE.connect(ctx); err != nil {
if err := c.metadataClientFLE.Connect(ctx); err != nil {
return err
}
}
Expand Down Expand Up @@ -484,7 +489,7 @@ func (c *Client) getOrCreateInternalClient(clientOpts *options.ClientOptions) (*
internalClientOpts.AutoEncryptionOptions = nil
internalClientOpts.SetMinPoolSize(0)
var err error
c.internalClientFLE, err = newClient(internalClientOpts)
c.internalClientFLE, err = NewClient(internalClientOpts)
return c.internalClientFLE, err
}

Expand All @@ -494,7 +499,7 @@ func (c *Client) configureKeyVaultClientFLE(clientOpts *options.ClientOptions) e
aeOpts := clientOpts.AutoEncryptionOptions
switch {
case aeOpts.KeyVaultClientOptions != nil:
c.keyVaultClientFLE, err = newClient(aeOpts.KeyVaultClientOptions)
c.keyVaultClientFLE, err = NewClient(aeOpts.KeyVaultClientOptions)
case clientOpts.MaxPoolSize != nil && *clientOpts.MaxPoolSize == 0:
c.keyVaultClientFLE = c
default:
Expand Down
18 changes: 9 additions & 9 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func setupClient(opts ...*options.ClientOptions) *Client {
integtest.AddTestServerAPIVersion(clientOpts)
opts = append(opts, clientOpts)
}
client, _ := newClient(opts...)
client, _ := NewClient(opts...)
return client
}

Expand Down Expand Up @@ -183,7 +183,7 @@ func TestClient(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := newClient(tc.opts)
_, err := NewClient(tc.opts)
assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
})
}
Expand Down Expand Up @@ -227,7 +227,7 @@ func TestClient(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := newClient(tc.opts)
_, err := NewClient(tc.opts)
assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
})
}
Expand All @@ -249,7 +249,7 @@ func TestClient(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
client, err := newClient(tc.opts)
client, err := NewClient(tc.opts)
if tc.expectErr {
assert.NotNil(t, err, "expected error, got nil")
return
Expand Down Expand Up @@ -277,7 +277,7 @@ func TestClient(t *testing.T) {
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
client, err := newClient(tc.opts)
client, err := NewClient(tc.opts)
if tc.expectErr {
assert.NotNil(t, err, "expected error, got nil")
return
Expand Down Expand Up @@ -412,22 +412,22 @@ func TestClient(t *testing.T) {

t.Run("success with all options", func(t *testing.T) {
serverAPIOptions := getServerAPIOptions()
client, err := newClient(options.Client().SetServerAPIOptions(serverAPIOptions))
client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
assert.Nil(t, err, "unexpected error from NewClient: %v", err)
convertedAPIOptions := topology.ConvertToDriverAPIOptions(serverAPIOptions)
assert.Equal(t, convertedAPIOptions, client.serverAPI,
"mismatch in serverAPI; expected %v, got %v", convertedAPIOptions, client.serverAPI)
})
t.Run("failure with unsupported version", func(t *testing.T) {
serverAPIOptions := options.ServerAPI("badVersion")
_, err := newClient(options.Client().SetServerAPIOptions(serverAPIOptions))
_, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
assert.NotNil(t, err, "expected error from NewClient, got nil")
errmsg := `api version "badVersion" not supported; this driver version only supports API version "1"`
assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error())
})
t.Run("cannot modify options after client creation", func(t *testing.T) {
serverAPIOptions := getServerAPIOptions()
client, err := newClient(options.Client().SetServerAPIOptions(serverAPIOptions))
client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
assert.Nil(t, err, "unexpected error from NewClient: %v", err)

expectedServerAPIOptions := getServerAPIOptions()
Expand Down Expand Up @@ -476,7 +476,7 @@ func TestClient(t *testing.T) {
extraOptions["__cryptSharedLibDisabledForTestOnly"] = true
}

_, err := newClient(options.Client().
_, err := NewClient(options.Client().
SetAutoEncryptionOptions(options.AutoEncryption().
SetKmsProviders(map[string]map[string]interface{}{
"local": {"key": make([]byte, 96)},
Expand Down
2 changes: 1 addition & 1 deletion mongo/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestDatabase(t *testing.T) {
})
t.Run("TransientTransactionError label", func(t *testing.T) {
client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second))
err := client.connect(bgCtx)
err := client.Connect(bgCtx)
defer client.Disconnect(bgCtx)
assert.Nil(t, err, "expected nil, got %v", err)

Expand Down
4 changes: 3 additions & 1 deletion mongo/integration/client_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ func TestClientOptions_CustomDialer(t *testing.T) {
cs := integtest.ConnString(t)
opts := options.Client().ApplyURI(cs.String()).SetDialer(td)
integtest.AddTestServerAPIVersion(opts)
client, err := mongo.Connect(context.Background(), opts)
client, err := mongo.NewClient(opts)
require.NoError(t, err)
err = client.Connect(context.Background())
require.NoError(t, err)
_, err = client.ListDatabases(context.Background(), bson.D{})
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion mongo/integration/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ func TestClient(t *testing.T) {
})
mt.RunOpts("watch", noClientOpts, func(mt *mtest.T) {
mt.Run("disconnected", func(mt *mtest.T) {
c, err := mongo.Connect(context.Background(), options.Client().ApplyURI(mtest.ClusterURI()))
c, err := mongo.NewClient(options.Client().ApplyURI(mtest.ClusterURI()))
assert.Nil(mt, err, "NewClient error: %v", err)
_, err = c.Watch(context.Background(), mongo.Pipeline{})
assert.Equal(mt, mongo.ErrClientDisconnected, err, "expected error %v, got %v", mongo.ErrClientDisconnected, err)
Expand Down
9 changes: 6 additions & 3 deletions mongo/integration/mtest/mongotest.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,13 +689,13 @@ func (t *T) createTestClient() {
// pin to first mongos
pinnedHostList := []string{testContext.connString.Hosts[0]}
uriOpts := options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList)
t.Client, err = mongo.Connect(context.Background(), uriOpts, clientOpts)
t.Client, err = mongo.NewClient(uriOpts, clientOpts)
case Mock:
// clear pool monitor to avoid configuration error
clientOpts.PoolMonitor = nil
t.mockDeployment = newMockDeployment()
clientOpts.Deployment = t.mockDeployment
t.Client, err = mongo.Connect(context.Background(), clientOpts)
t.Client, err = mongo.NewClient(clientOpts)
case Proxy:
t.proxyDialer = newProxyDialer()
clientOpts.SetDialer(t.proxyDialer)
Expand All @@ -713,11 +713,14 @@ func (t *T) createTestClient() {
}

// Pass in uriOpts first so clientOpts wins if there are any conflicting settings.
t.Client, err = mongo.Connect(context.Background(), uriOpts, clientOpts)
t.Client, err = mongo.NewClient(uriOpts, clientOpts)
}
if err != nil {
t.Fatalf("error creating client: %v", err)
}
if err := t.Client.Connect(context.Background()); err != nil {
t.Fatalf("error connecting client: %v", err)
}
}

func (t *T) createTestCollection() {
Expand Down
4 changes: 2 additions & 2 deletions mongo/mongocryptd.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func newMongocryptdClient(opts *options.AutoEncryptionOptions) (*mongocryptdClie
}

// create client
client, err := Connect(context.Background(), options.Client().ApplyURI(uri).SetServerSelectionTimeout(defaultServerSelectionTimeout))
client, err := NewClient(options.Client().ApplyURI(uri).SetServerSelectionTimeout(defaultServerSelectionTimeout))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -114,7 +114,7 @@ func (mc *mongocryptdClient) markCommand(ctx context.Context, dbName string, cmd

// connect connects the underlying Client instance. This must be called before performing any mark operations.
func (mc *mongocryptdClient) connect(ctx context.Context) error {
return mc.client.connect(ctx)
return mc.client.Connect(ctx)
}

// disconnect disconnects the underlying Client instance. This should be called after all operations have completed.
Expand Down
2 changes: 1 addition & 1 deletion x/mongo/driver/topology/topology_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ func TestSessionTimeout(t *testing.T) {

currDesc := topo.desc.Load().(description.Topology)
require.Nil(t, currDesc.SessionTimeoutMinutesPtr,
"session timeout minutes mismatch. got: %d. expected: nil", *currDesc.SessionTimeoutMinutesPtr)
"session timeout minutes mismatch. expected: nil")
})
}

Expand Down

0 comments on commit 5202542

Please sign in to comment.