Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu committed Jan 4, 2024
1 parent 61e0bc0 commit 31b50f1
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
17 changes: 5 additions & 12 deletions x/mongo/driver/topology/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {
}
pool.connOpts = append(pool.connOpts, withGenerationNumberFn(func(_ generationNumberFn) generationNumberFn { return pool.getGenerationForNewConnection }))

pool.generation.connect()

// Create a Context with cancellation that's used to signal the createConnections() and
// maintain() background goroutines to stop. Also create a "backgroundDone" WaitGroup that is
// used to wait for the background goroutines to return.
Expand Down Expand Up @@ -273,18 +275,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {

// stale checks if a given connection's generation is below the generation of the pool
func (p *pool) stale(conn *connection) bool {
if conn == nil {
return true
}
p.stateMu.RLock()
defer p.stateMu.RUnlock()
if p.state == poolClosed {
return true
}
if generation, ok := p.generation.getGeneration(conn.desc.ServiceID); ok {
return conn.generation < generation
}
return false
return conn == nil || p.generation.stale(conn.desc.ServiceID, conn.generation)
}

// ready puts the pool into the "ready" state and starts the background connection creation and
Expand Down Expand Up @@ -353,6 +344,8 @@ func (p *pool) close(ctx context.Context) {
// Wait for all background goroutines to exit.
p.backgroundDone.Wait()

p.generation.disconnect()

if ctx == nil {
ctx = context.Background()
}
Expand Down
25 changes: 25 additions & 0 deletions x/mongo/driver/topology/pool_generation_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package topology

import (
"sync"
"sync/atomic"

"go.mongodb.org/mongo-driver/bson/primitive"
)
Expand All @@ -29,6 +30,10 @@ type generationStats struct {
// load balancer, there is only one service ID: primitive.NilObjectID. For load-balanced deployments, each server behind
// the load balancer will have a unique service ID.
type poolGenerationMap struct {
// state must be accessed using the atomic package and should be at the beginning of the struct.
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
generationMap map[primitive.ObjectID]*generationStats

sync.Mutex
Expand All @@ -42,6 +47,14 @@ func newPoolGenerationMap() *poolGenerationMap {
return pgm
}

func (p *poolGenerationMap) connect() {
atomic.StoreInt64(&p.state, generationConnected)
}

func (p *poolGenerationMap) disconnect() {
atomic.StoreInt64(&p.state, generationDisconnected)
}

// addConnection increments the connection count for the generation associated with the given service ID and returns the
// generation number for the connection.
func (p *poolGenerationMap) addConnection(serviceIDPtr *primitive.ObjectID) uint64 {
Expand Down Expand Up @@ -93,6 +106,18 @@ func (p *poolGenerationMap) clear(serviceIDPtr *primitive.ObjectID) {
}
}

func (p *poolGenerationMap) stale(serviceIDPtr *primitive.ObjectID, knownGeneration uint64) bool {
// If the map has been disconnected, all connections should be considered stale to ensure that they're closed.
if atomic.LoadInt64(&p.state) == generationDisconnected {
return true
}

if generation, ok := p.getGeneration(serviceIDPtr); ok {
return knownGeneration < generation
}
return false
}

func (p *poolGenerationMap) getGeneration(serviceIDPtr *primitive.ObjectID) (uint64, bool) {
serviceID := getServiceID(serviceIDPtr)
p.Lock()
Expand Down

0 comments on commit 31b50f1

Please sign in to comment.