diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go
index c30aaa5c79..0099babebf 100644
--- a/x/mongo/driver/topology/pool.go
+++ b/x/mongo/driver/topology/pool.go
@@ -228,6 +228,8 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {
 	}
 	pool.connOpts = append(pool.connOpts, withGenerationNumberFn(func(_ generationNumberFn) generationNumberFn { return pool.getGenerationForNewConnection }))
 
+	pool.generation.connect()
+
 	// Create a Context with cancellation that's used to signal the createConnections() and
 	// maintain() background goroutines to stop. Also create a "backgroundDone" WaitGroup that is
 	// used to wait for the background goroutines to return.
@@ -273,18 +275,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {
 
 // stale checks if a given connection's generation is below the generation of the pool
 func (p *pool) stale(conn *connection) bool {
-	if conn == nil {
-		return true
-	}
-	p.stateMu.RLock()
-	defer p.stateMu.RUnlock()
-	if p.state == poolClosed {
-		return true
-	}
-	if generation, ok := p.generation.getGeneration(conn.desc.ServiceID); ok {
-		return conn.generation < generation
-	}
-	return false
+	return conn == nil || p.generation.stale(conn.desc.ServiceID, conn.generation)
 }
 
 // ready puts the pool into the "ready" state and starts the background connection creation and
@@ -353,6 +344,8 @@ func (p *pool) close(ctx context.Context) {
 	// Wait for all background goroutines to exit.
 	p.backgroundDone.Wait()
 
+	p.generation.disconnect()
+
 	if ctx == nil {
 		ctx = context.Background()
 	}
diff --git a/x/mongo/driver/topology/pool_generation_counter.go b/x/mongo/driver/topology/pool_generation_counter.go
index 4d9070e8f0..dd10c0ce7a 100644
--- a/x/mongo/driver/topology/pool_generation_counter.go
+++ b/x/mongo/driver/topology/pool_generation_counter.go
@@ -8,6 +8,7 @@ package topology
 
 import (
 	"sync"
+	"sync/atomic"
 
 	"go.mongodb.org/mongo-driver/bson/primitive"
 )
@@ -29,6 +30,10 @@ type generationStats struct {
 // load balancer, there is only one service ID: primitive.NilObjectID. For load-balanced deployments, each server behind
 // the load balancer will have a unique service ID.
 type poolGenerationMap struct {
+	// state must be accessed using the atomic package and should be at the beginning of the struct.
+	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
+	// - suggested layout: https://go101.org/article/memory-layout.html
+	state         int64
 	generationMap map[primitive.ObjectID]*generationStats
 
 	sync.Mutex
@@ -42,6 +47,14 @@ func newPoolGenerationMap() *poolGenerationMap {
 	return pgm
 }
 
+func (p *poolGenerationMap) connect() {
+	atomic.StoreInt64(&p.state, generationConnected)
+}
+
+func (p *poolGenerationMap) disconnect() {
+	atomic.StoreInt64(&p.state, generationDisconnected)
+}
+
 // addConnection increments the connection count for the generation associated with the given service ID and returns the
 // generation number for the connection.
 func (p *poolGenerationMap) addConnection(serviceIDPtr *primitive.ObjectID) uint64 {
@@ -93,6 +106,18 @@ func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {
 	}
 }
 
+func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
+	// If the map has been disconnected, all connections should be considered stale to ensure that they're closed.
+	if atomic.LoadInt64(&p.state) == generationDisconnected {
+		return true
+	}
+
+	if generation, ok := p.getGeneration(serviceIDPtr); ok {
+		return knownGeneration < generation
+	}
+	return false
+}
+
 func (p *poolGenerationMap) getGeneration(serviceIDPtr *primitive.ObjectID) (uint64, bool) {
 	serviceID := getServiceID(serviceIDPtr)
 	p.Lock()
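Not part of the patch: a minimal sketch, written as a hypothetical test inside package topology, of how the new connect/disconnect/stale lifecycle is expected to behave. It relies on the generationConnected and generationDisconnected constants referenced above, assumes a nil service ID maps to primitive.NilObjectID (per the struct's doc comment), and assumes clear() bumps the generation for a service ID as in the existing counter implementation.

// pool_generation_counter_stale_sketch_test.go (hypothetical, not in the patch)
package topology

import "testing"

func TestPoolGenerationMapStaleSketch(t *testing.T) {
	pgm := newPoolGenerationMap()
	pgm.connect() // newPool() now calls this when the pool is constructed

	gen := pgm.addConnection(nil) // nil service ID maps to primitive.NilObjectID
	if pgm.stale(nil, gen) {
		t.Error("a connection at the current generation should not be stale")
	}

	pgm.clear(nil) // assumed to bump the generation for the service ID
	if !pgm.stale(nil, gen) {
		t.Error("a connection below the current generation should be stale")
	}

	pgm.disconnect() // pool.close() now calls this after the background goroutines exit
	if !pgm.stale(nil, gen+1) {
		t.Error("every connection should be stale once the map is disconnected")
	}
}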