Skip to content

Commit

Permalink
removing atomic value and altering tests for replaceErrors for discon…
Browse files Browse the repository at this point in the history
…nected topology
  • Loading branch information
joyjwang committed Sep 16, 2024
1 parent 8210c5b commit 3a7ec76
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 25 deletions.
18 changes: 0 additions & 18 deletions mongo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"errors"
"fmt"
"net/http"
"sync/atomic"
"time"

"go.mongodb.org/mongo-driver/v2/bson"
Expand Down Expand Up @@ -74,7 +73,6 @@ type Client struct {
timeout *time.Duration
httpClient *http.Client
logger *logger.Logger
closed atomic.Value

// client-side encryption fields
keyVaultClientFLE *Client
Expand Down Expand Up @@ -252,8 +250,6 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) {
return nil, fmt.Errorf("invalid logger options: %w", err)
}

client.closed.Store(false)

return client, nil
}

Expand Down Expand Up @@ -315,10 +311,6 @@ func (c *Client) connect() error {
// or write operations. If this method returns with no errors, all connections
// associated with this Client have been closed.
func (c *Client) Disconnect(ctx context.Context) error {
if c.closed.Load().(bool) {
return ErrClientDisconnected
}

if c.logger != nil {
defer c.logger.Close()
}
Expand Down Expand Up @@ -358,8 +350,6 @@ func (c *Client) Disconnect(ctx context.Context) error {
c.cryptFLE.Close()
}

c.closed.Store(true)

if disconnector, ok := c.deployment.(driver.Disconnector); ok {
return replaceErrors(disconnector.Disconnect(ctx))
}
Expand All @@ -379,10 +369,6 @@ func (c *Client) Disconnect(ctx context.Context) error {
// Using Ping reduces application resilience because applications starting up will error if the server is temporarily
// unavailable or is failing over (e.g. during autoscaling due to a load spike).
func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
if c.closed.Load().(bool) {
return ErrClientDisconnected
}

if ctx == nil {
ctx = context.Background()
}
Expand Down Expand Up @@ -878,10 +864,6 @@ func (c *Client) UseSessionWithOptions(
// documentation).
func (c *Client) Watch(ctx context.Context, pipeline interface{},
opts ...options.Lister[options.ChangeStreamOptions]) (*ChangeStream, error) {
if c.closed.Load().(bool) {
return nil, ErrClientDisconnected
}

csConfig := changeStreamConfig{
readConcern: c.readConcern,
readPreference: c.readPreference,
Expand Down
20 changes: 14 additions & 6 deletions mongo/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/integtest"
"go.mongodb.org/mongo-driver/v2/internal/mongoutil"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
Expand Down Expand Up @@ -52,18 +53,25 @@ func TestClient(t *testing.T) {
assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name())
assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client())
})
t.Run("client disconnect error", func(t *testing.T) {
t.Run("replaceErrors for disconnected topology", func(t *testing.T) {
client := setupClient()
assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool))

err := client.Disconnect(bgCtx)
assert.Equal(t, nil, err, "expected nil, got %v", err)
assert.Equal(t, true, client.closed.Load().(bool), "expected error %v, got %v", true, client.closed.Load().(bool))
topo, ok := client.deployment.(*topology.Topology)
require.True(t, ok, "client deployment is not a topology")

err := topo.Disconnect(context.Background())
require.NoError(t, err)

_, err = client.ListDatabases(bgCtx, bson.D{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = client.Ping(bgCtx, nil)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = client.Watch(bgCtx, nil, nil)
err = client.Disconnect(bgCtx)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = client.Watch(bgCtx, []bson.D{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("nil document error", func(t *testing.T) {
Expand Down
62 changes: 62 additions & 0 deletions mongo/collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@ package mongo

import (
"bytes"
"context"
"errors"
"testing"

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/ptrutil"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
"go.mongodb.org/mongo-driver/v2/mongo/writeconcern"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology"
)

const (
Expand Down Expand Up @@ -78,6 +81,65 @@ func TestCollection(t *testing.T) {
}
compareColls(t, expected, coll)
})
t.Run("replaceErrors for disconnected topology", func(t *testing.T) {
coll := setupColl("foo")
doc := bson.D{}
update := bson.D{{"$update", bson.D{{"x", 1}}}}

topo, ok := coll.client.deployment.(*topology.Topology)
require.True(t, ok, "client deployment is not a topology")

err := topo.Disconnect(context.Background())
require.NoError(t, err)

_, err = coll.InsertOne(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.InsertMany(bgCtx, []interface{}{doc})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.DeleteOne(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.DeleteMany(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.UpdateOne(bgCtx, doc, update)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.UpdateMany(bgCtx, doc, update)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.ReplaceOne(bgCtx, doc, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.Aggregate(bgCtx, Pipeline{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.EstimatedDocumentCount(bgCtx)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.CountDocuments(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.Distinct(bgCtx, "x", doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = coll.Find(bgCtx, doc)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOne(bgCtx, doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOneAndDelete(bgCtx, doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOneAndReplace(bgCtx, doc, doc).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = coll.FindOneAndUpdate(bgCtx, doc, update).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("database accessor", func(t *testing.T) {
coll := setupColl("bar")
dbName := coll.Database().Name()
Expand Down
20 changes: 19 additions & 1 deletion mongo/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readconcern"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
Expand Down Expand Up @@ -83,9 +84,26 @@ func TestDatabase(t *testing.T) {
compareDbs(t, expected, got)
})
})
t.Run("replaceErrors for disconnected topology", func(t *testing.T) {
db := setupDb("foo")

topo, ok := db.client.deployment.(*topology.Topology)
require.True(t, ok, "client deployment is not a topology")

err := topo.Disconnect(context.Background())
require.NoError(t, err)

err = db.RunCommand(bgCtx, bson.D{{"x", 1}}).Err()
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

err = db.Drop(bgCtx)
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)

_, err = db.ListCollections(bgCtx, bson.D{})
assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
})
t.Run("TransientTransactionError label", func(t *testing.T) {
client := setupClient(options.Client().ApplyURI("mongodb://nonexistent").SetServerSelectionTimeout(3 * time.Second))
assert.Equal(t, false, client.closed.Load().(bool), "expected value %v, got %v", false, client.closed.Load().(bool))
defer func() { _ = client.Disconnect(bgCtx) }()

t.Run("negative case of non-transaction", func(t *testing.T) {
Expand Down

0 comments on commit 3a7ec76

Please sign in to comment.