Skip to content

Commit

Permalink
removed sessionPool nil checks, add closed atomic value in client str…
Browse files Browse the repository at this point in the history
…uct, alter tests
  • Loading branch information
joyjwang committed Sep 13, 2024
1 parent 9e7ccb0 commit 2fa662e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
24 changes: 15 additions & 9 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"net/http"
"sync/atomic"
"time"

"go.mongodb.org/mongo-driver/v2/bson"
Expand Down Expand Up @@ -73,6 +74,7 @@ type Client struct {
timeout *time.Duration
httpClient *http.Client
logger *logger.Logger
closed atomic.Value

// client-side encryption fields
keyVaultClientFLE *Client
Expand Down Expand Up @@ -250,6 +252,8 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) {
return nil, fmt.Errorf("invalid logger options: %w", err)
}

client.closed.Store(false)

return client, nil
}

Expand Down Expand Up @@ -311,6 +315,10 @@ func (c *Client) connect() error {
// or write operations. If this method returns with no errors, all connections
// associated with this Client have been closed.
func (c *Client) Disconnect(ctx context.Context) error {
if c.closed.Load().(bool) {
return ErrClientDisconnected
}

if c.logger != nil {
defer c.logger.Close()
}
Expand Down Expand Up @@ -350,6 +358,8 @@ func (c *Client) Disconnect(ctx context.Context) error {
c.cryptFLE.Close()
}

c.closed.Store(true)

if disconnector, ok := c.deployment.(driver.Disconnector); ok {
return replaceErrors(disconnector.Disconnect(ctx))
}
Expand All @@ -369,6 +379,10 @@ func (c *Client) Disconnect(ctx context.Context) error {
// Using Ping reduces application resilience because applications starting up will error if the server is temporarily
// unavailable or is failing over (e.g. during autoscaling due to a load spike).
func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
if c.closed.Load().(bool) {
return ErrClientDisconnected
}

if ctx == nil {
ctx = context.Background()
}
Expand Down Expand Up @@ -396,10 +410,6 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
// If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, the client's read
// concern, write concern, or read preference will be used, respectively.
func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (*Session, error) {
if c.sessionPool == nil {
return nil, ErrClientDisconnected
}

sessArgs, err := mongoutil.NewOptions(opts...)
if err != nil {
return nil, err
Expand Down Expand Up @@ -454,10 +464,6 @@ func (c *Client) StartSession(opts ...options.Lister[options.SessionOptions]) (*
}

func (c *Client) endSessions(ctx context.Context) {
if c.sessionPool == nil {
return
}

sessionIDs := c.sessionPool.IDSlice()
op := operation.NewEndSessions(nil).ClusterClock(c.clock).Deployment(c.deployment).
ServerSelector(&serverselector.ReadPref{ReadPref: readpref.PrimaryPreferred()}).
Expand Down Expand Up @@ -872,7 +878,7 @@ func (c *Client) UseSessionWithOptions(
// documentation).
func (c *Client) Watch(ctx context.Context, pipeline interface{},
opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) {
if c.sessionPool == nil {
if c.closed.Load().(bool) {
return nil, ErrClientDisconnected
}

Expand Down
21 changes: 7 additions & 14 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/tag"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology"
)

Expand All @@ -37,7 +36,7 @@ func setupClient(opts ...options.Lister[options.ClientOptions]) *Client {
integtest.AddTestServerAPIVersion(clientOpts)
opts = append(opts, clientOpts)
}
client, _ := newClient(opts...)
client, _ := Connect(opts...)
return client
}

Expand All @@ -53,28 +52,22 @@ func TestClient(t *testing.T) {
assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name())
assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client())
})
t.Run("replace topology error", func(t *testing.T) {
t.Run("client disconnect error", func(t *testing.T) {
client := setupClient()
assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool))

_, err := client.StartSession()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = client.ListDatabases(bgCtx, bson.D{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
err := client.Disconnect(bgCtx)
assert.Equal(t, nil, err, "expected nil, got %v", err)
assert.Equal(t, true, client.closed.Load().(bool), "expected error %v, got %v", true, client.closed.Load().(bool))

err = client.Ping(bgCtx, nil)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = client.Disconnect(bgCtx)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = client.Watch(bgCtx, []bson.D{})
_, err = client.Watch(bgCtx, nil, nil)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("nil document error", func(t *testing.T) {
// manually set session pool to non-nil because Watch will return ErrClientDisconnected
client := setupClient()
client.sessionPool = &session.Pool{}

_, err := client.Watch(bgCtx, nil)
watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")
Expand Down

0 comments on commit 2fa662e

Please sign in to comment.