Simplify logic to check in all non-idle connections for interruption.
qingyang-hu committed Dec 14, 2023
1 parent d938fed commit 9dfddd9
Showing 5 changed files with 89 additions and 81 deletions.
8 changes: 8 additions & 0 deletions internal/eventtest/eventtest.go
@@ -76,3 +76,11 @@ func (tpm *TestPoolMonitor) IsPoolCleared() bool {
})
return len(poolClearedEvents) > 0
}

// Interruptions returns the number of interruption events recorded by the TestPoolMonitor.
func (tpm *TestPoolMonitor) Interruptions() int {
interruptions := tpm.Events(func(evt *event.PoolEvent) bool {
return evt.Interruption
})
return len(interruptions)
}
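
A minimal usage sketch for the new helper, assuming a TestPoolMonitor has already been registered as the pool's event monitor; the assertInterrupted helper below is illustrative and not part of this change.

// assertInterrupted fails the test if the monitor recorded a different number of
// interruption events (pool-cleared events with the Interruption flag set).
func assertInterrupted(t *testing.T, tpm *eventtest.TestPoolMonitor, want int) {
	t.Helper()
	got := tpm.Interruptions()
	if got != want {
		t.Errorf("expected %d interruption events, got %d", want, got)
	}
}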
2 changes: 0 additions & 2 deletions x/mongo/driver/topology/connection.go
@@ -48,7 +48,6 @@ type connection struct {
// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
// - suggested layout: https://go101.org/article/memory-layout.html
state int64
inUse bool
err error

id string
@@ -87,7 +86,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())

c := &connection{
inUse: cfg.inUse,
id: id,
addr: addr,
idleTimeout: cfg.idleTimeout,
1 change: 0 additions & 1 deletion x/mongo/driver/topology/connection_options.go
@@ -48,7 +48,6 @@ type Handshaker = driver.Handshaker
type generationNumberFn func(serviceID *primitive.ObjectID) uint64

type connectionConfig struct {
inUse bool
connectTimeout time.Duration
dialer Dialer
handshaker Handshaker
132 changes: 69 additions & 63 deletions x/mongo/driver/topology/pool.go
@@ -535,9 +535,6 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
// timed out).
w := newWantConn()
defer func() {
if conn != nil {
conn.inUse = true
}
if err != nil {
w.cancel(p, err)
}
@@ -592,11 +589,6 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {

// If we didn't get an immediately available idle connection, also get in the queue for a new
// connection while we're waiting for an idle connection.
w.mu.Lock()
w.connOpts = append(w.connOpts, func(cfg *connectionConfig) {
cfg.inUse = true
})
w.mu.Unlock()
p.queueForNewConn(w)
p.stateMu.RUnlock()

@@ -759,8 +751,7 @@ func (p *pool) removeConnection(conn *connection, reason reason, err error) erro
return nil
}

// checkIn returns an idle connection to the pool. If the connection is perished or the pool is
// closed, it is removed from the connection pool and closed.
// checkIn returns an idle connection to the pool. It calls checkInWithCallback internally.
func (p *pool) checkIn(conn *connection) error {
return p.checkInWithCallback(conn, func() (reason, bool) {
if mustLogPoolMessage(p) {
@@ -807,16 +798,19 @@ func (p *pool) checkInNoEvent(conn *connection) error {
})
}

func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) error {
// checkInWithCallback returns a connection to the pool. If the connection is perished or the pool is
// closed, it is removed from the connection pool and closed.
// The callback parameter is expected to return the reason for the check-in and a boolean value
// indicating whether the connection is perished.
// Events and logs can also be emitted from the callback function.
func (p *pool) checkInWithCallback(conn *connection, callback func() (reason, bool)) error {
if conn == nil {
return nil
}
if conn.pool != p {
return ErrWrongPool
}

conn.inUse = false

// Bump the connection idle deadline here because we're about to make the connection "available".
// The idle deadline is used to determine when a connection has reached its max idle time and
// should be closed. A connection reaches its max idle time when it has been "available" in the
@@ -827,8 +821,8 @@ func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) e

var r reason
var perished bool
if cb != nil {
r, perished = cb()
if callback != nil {
r, perished = callback()
}
if perished {
_ = p.removeConnection(conn, r, nil)
@@ -861,45 +855,49 @@ func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) e
return nil
}

// clearAll does same as the "clear" method and interrupts all in-use connections as well.
// clear calls clearImpl internally with a false interruptAllConnections value.
func (p *pool) clear(err error, serviceID *primitive.ObjectID) {
p.clearImpl(err, serviceID, false)
}

// clearAll does the same as the "clear" method but also interrupts all non-idle connections.
func (p *pool) clearAll(err error, serviceID *primitive.ObjectID) {
p.clearImpl(err, serviceID, true)
}

func (p *pool) interruptInUseConnections() {
for _, conn := range p.conns {
if conn.inUse && p.stale(conn) {
_ = conn.closeWithErr(poolClearedError{
err: fmt.Errorf("interrupted"),
address: p.address,
})
_ = p.checkInWithCallback(conn, func() (reason, bool) {
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
}

logPoolMessage(p, logger.ConnectionCheckedIn, keysAndValues...)
// interruptConnections interrupts the input connections.
func (p *pool) interruptConnections(conns []*connection) {
for _, conn := range conns {
_ = conn.closeWithErr(poolClearedError{
err: fmt.Errorf("interrupted"),
address: p.address,
})
_ = p.checkInWithCallback(conn, func() (reason, bool) {
if mustLogPoolMessage(p) {
keysAndValues := logger.KeyValues{
logger.KeyDriverConnectionID, conn.driverConnectionID,
}

if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Type: event.ConnectionCheckedIn,
ConnectionID: conn.driverConnectionID,
Address: conn.addr.String(),
})
}
logPoolMessage(p, logger.ConnectionCheckedIn, keysAndValues...)
}

r, ok := connectionPerished(conn)
if ok {
r = reason{
loggerConn: logger.ReasonConnClosedStale,
event: event.ReasonStale,
}
if p.monitor != nil {
p.monitor.Event(&event.PoolEvent{
Type: event.ConnectionCheckedIn,
ConnectionID: conn.driverConnectionID,
Address: conn.addr.String(),
})
}

r, ok := connectionPerished(conn)
if ok {
r = reason{
loggerConn: logger.ReasonConnClosedStale,
event: event.ReasonStale,
}
return r, ok
})
}
}
return r, ok
})
}
}

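A sketch of the callback contract that checkInWithCallback expects, written against identifiers that already appear in this diff (reason, connectionPerished); this is an illustrative call site inside the topology package, not code from the change itself.

// The callback reports why the connection is being checked in and whether it has
// perished; returning perished == true makes checkInWithCallback remove and close
// the connection instead of returning it to the idle list.
cb := func() (reason, bool) {
	// Emit any check-in logs or events here, then classify the connection.
	return connectionPerished(conn)
}
_ = p.checkInWithCallback(conn, cb)
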
@@ -908,11 +906,9 @@ func (p *pool) interruptInUseConnections() {
// "paused". If serviceID is nil, clear marks all connections as stale. If serviceID is not nil,
// clear marks only connections associated with the given serviceID stale (for use in load balancer
// mode).
func (p *pool) clear(err error, serviceID *primitive.ObjectID) {
p.clearImpl(err, serviceID, false)
}

func (p *pool) clearImpl(err error, serviceID *primitive.ObjectID, interruptInUseConnections bool) {
// If interruptAllConnections is true, this function calls interruptConnections to interrupt all
// non-idle connections.
func (p *pool) clearImpl(err error, serviceID *primitive.ObjectID, interruptAllConnections bool) {
if p.getState() == poolClosed {
return
}
@@ -953,15 +949,33 @@ func (p *pool) clearImpl(err error, serviceID *primitive.ObjectID, interruptInUs
ServiceID: serviceID,
Error: err,
}
if interruptInUseConnections {
if interruptAllConnections {
event.Interruption = true
}
p.monitor.Event(event)
}

p.removePerishedConns()
if interruptInUseConnections {
p.interruptInUseConnections()
if interruptAllConnections {
p.createConnectionsCond.L.Lock()
p.idleMu.Lock()

idleConns := make(map[*connection]bool, len(p.idleConns))
for _, idle := range p.idleConns {
idleConns[idle] = true
}

conns := make([]*connection, 0, len(p.conns))
for _, conn := range p.conns {
if _, ok := idleConns[conn]; !ok && p.stale(conn) {
conns = append(conns, conn)
}
}

p.idleMu.Unlock()
p.createConnectionsCond.L.Unlock()

p.interruptConnections(conns)
}

if serviceID == nil {
@@ -1093,11 +1107,7 @@ func (p *pool) createConnections(ctx context.Context, wg *sync.WaitGroup) {
return nil, nil, false
}

w.mu.Lock()
connOpts := w.connOpts
w.mu.Unlock()
connOpts = append(connOpts, p.connOpts...)
conn := newConnection(p.address, connOpts...)
conn := newConnection(p.address, p.connOpts...)
conn.pool = p
conn.driverConnectionID = atomic.AddInt64(&p.nextID, 1)
p.conns[conn.driverConnectionID] = conn
@@ -1311,8 +1321,6 @@ func compact(arr []*connection) []*connection {
type wantConn struct {
ready chan struct{}

connOpts []ConnectionOption

mu sync.Mutex // Guards conn, err
conn *connection
err error
@@ -1349,7 +1357,6 @@ func (w *wantConn) tryDeliver(conn *connection, err error) bool {
panic("x/mongo/driver/topology: internal error: misuse of tryDeliver")
}

w.connOpts = w.connOpts[:0]
close(w.ready)

return true
@@ -1365,7 +1372,6 @@ func (w *wantConn) cancel(p *pool, err error) {

w.mu.Lock()
if w.conn == nil && w.err == nil {
w.connOpts = w.connOpts[:0]
close(w.ready) // catch misbehavior in future delivery
}
conn := w.conn
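The new interruption path in clearImpl selects connections by subtracting the idle set from the full connection map and keeping only stale entries. Below is a standalone sketch of that set-difference pattern with simplified inputs; the real code also holds createConnectionsCond.L and idleMu while taking the snapshot, and the nonIdleStale name is illustrative.

// nonIdleStale returns every connection in all that is not present in idle and
// that isStale reports as stale. Simplified stand-in for the snapshot logic in
// clearImpl; it is not part of the driver's API.
func nonIdleStale(all map[int64]*connection, idle []*connection, isStale func(*connection) bool) []*connection {
	idleSet := make(map[*connection]bool, len(idle))
	for _, c := range idle {
		idleSet[c] = true
	}
	out := make([]*connection, 0, len(all))
	for _, c := range all {
		if !idleSet[c] && isStale(c) {
			out = append(out, c)
		}
	}
	return out
}
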
27 changes: 12 additions & 15 deletions x/mongo/driver/topology/server_test.go
@@ -126,30 +126,26 @@ func (d *timeoutDialer) DialContext(ctx context.Context, network, address string
return &timeoutConn{c, d.errors}, e
}

// TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
// TestServerHeartbeatTimeout tests timeout retry and preemptive canceling.
func TestServerHeartbeatTimeout(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}

networkTimeoutError := &net.DNSError{
IsTimeout: true,
}

testCases := []struct {
desc string
ioErrors []error
expectPoolCleared bool
desc string
ioErrors []error
expectInterruptions int
}{
{
desc: "one single timeout should not clear the pool",
ioErrors: []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
expectPoolCleared: false,
desc: "one single timeout should not clear the pool",
ioErrors: []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
expectInterruptions: 0,
},
{
desc: "continuous timeouts should clear the pool",
ioErrors: []error{nil, networkTimeoutError, networkTimeoutError, nil},
expectPoolCleared: true,
desc: "continuous timeouts should clear the pool with interruption",
ioErrors: []error{nil, networkTimeoutError, networkTimeoutError, nil},
expectInterruptions: 1,
},
}
for _, tc := range testCases {
@@ -195,7 +191,8 @@ func TestServerHeartbeatTimeout(t *testing.T) {
)
require.NoError(t, server.Connect(nil))
wg.Wait()
assert.Equal(t, tc.expectPoolCleared, tpm.IsPoolCleared(), "expected pool cleared to be %v but was %v", tc.expectPoolCleared, tpm.IsPoolCleared())
interruptions := tpm.Interruptions()
assert.Equal(t, tc.expectInterruptions, interruptions, "expected %d interruptions but got %d", tc.expectInterruptions, interruptions)
})
}
}
