diff --git a/internal/eventtest/eventtest.go b/internal/eventtest/eventtest.go
index c06037e850..1158d69933 100644
--- a/internal/eventtest/eventtest.go
+++ b/internal/eventtest/eventtest.go
@@ -76,3 +76,11 @@ func (tpm *TestPoolMonitor) IsPoolCleared() bool {
 	})
 	return len(poolClearedEvents) > 0
 }
+
+// Interruptions returns the number of interruptions in the events recorded by the TestPoolMonitor.
+func (tpm *TestPoolMonitor) Interruptions() int {
+	interruptions := tpm.Events(func(evt *event.PoolEvent) bool {
+		return evt.Interruption
+	})
+	return len(interruptions)
+}
diff --git a/x/mongo/driver/topology/connection.go b/x/mongo/driver/topology/connection.go
index 66162723b0..b12f0109a1 100644
--- a/x/mongo/driver/topology/connection.go
+++ b/x/mongo/driver/topology/connection.go
@@ -48,7 +48,6 @@ type connection struct {
 	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
 	// - suggested layout: https://go101.org/article/memory-layout.html
 	state int64
-	inUse bool
 
 	err error
 	id  string
@@ -87,7 +86,6 @@ func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
 	id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
 
 	c := &connection{
-		inUse:       cfg.inUse,
 		id:          id,
 		addr:        addr,
 		idleTimeout: cfg.idleTimeout,
diff --git a/x/mongo/driver/topology/connection_options.go b/x/mongo/driver/topology/connection_options.go
index 7993268bd7..43e6f3f507 100644
--- a/x/mongo/driver/topology/connection_options.go
+++ b/x/mongo/driver/topology/connection_options.go
@@ -48,7 +48,6 @@ type Handshaker = driver.Handshaker
 type generationNumberFn func(serviceID *primitive.ObjectID) uint64
 
 type connectionConfig struct {
-	inUse          bool
 	connectTimeout time.Duration
 	dialer         Dialer
 	handshaker     Handshaker
diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go
index 756941ea69..968748f935 100644
--- a/x/mongo/driver/topology/pool.go
+++ b/x/mongo/driver/topology/pool.go
@@ -535,9 +535,6 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
 	// timed out).
 	w := newWantConn()
 	defer func() {
-		if conn != nil {
-			conn.inUse = true
-		}
 		if err != nil {
 			w.cancel(p, err)
 		}
@@ -592,11 +589,6 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
 
 	// If we didn't get an immediately available idle connection, also get in the queue for a new
 	// connection while we're waiting for an idle connection.
-	w.mu.Lock()
-	w.connOpts = append(w.connOpts, func(cfg *connectionConfig) {
-		cfg.inUse = true
-	})
-	w.mu.Unlock()
 	p.queueForNewConn(w)
 	p.stateMu.RUnlock()
 
@@ -759,8 +751,7 @@ func (p *pool) removeConnection(conn *connection, reason reason, err error) error {
 	return nil
 }
 
-// checkIn returns an idle connection to the pool. If the connection is perished or the pool is
-// closed, it is removed from the connection pool and closed.
+// checkIn returns an idle connection to the pool. It calls checkInWithCallback internally.
 func (p *pool) checkIn(conn *connection) error {
 	return p.checkInWithCallback(conn, func() (reason, bool) {
 		if mustLogPoolMessage(p) {
@@ -807,7 +798,12 @@ func (p *pool) checkInNoEvent(conn *connection) error {
 	})
 }
 
-func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) error {
+// checkInWithCallback returns a connection to the pool. If the connection is perished or the pool is
+// closed, it is removed from the connection pool and closed.
+// The callback parameter is expected to return the reason for the check-in and a boolean value
+// indicating whether the connection is perished.
+// Events and logs can also be emitted from within the callback.
+func (p *pool) checkInWithCallback(conn *connection, callback func() (reason, bool)) error {
 	if conn == nil {
 		return nil
 	}
@@ -815,8 +811,6 @@ func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) error {
 		return ErrWrongPool
 	}
 
-	conn.inUse = false
-
 	// Bump the connection idle deadline here because we're about to make the connection "available".
 	// The idle deadline is used to determine when a connection has reached its max idle time and
 	// should be closed. A connection reaches its max idle time when it has been "available" in the
@@ -827,8 +821,8 @@ func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) error {
 
 	var r reason
 	var perished bool
-	if cb != nil {
-		r, perished = cb()
+	if callback != nil {
+		r, perished = callback()
 	}
 	if perished {
 		_ = p.removeConnection(conn, r, nil)
@@ -861,45 +855,49 @@ func (p *pool) checkInWithCallback(conn *connection, cb func() (reason, bool)) error {
 	return nil
 }
 
-// clearAll does same as the "clear" method and interrupts all in-use connections as well.
+// clear calls clearImpl internally with interruptAllConnections set to false.
+func (p *pool) clear(err error, serviceID *primitive.ObjectID) {
+	p.clearImpl(err, serviceID, false)
+}
+
+// clearAll does the same as the "clear" method but also interrupts all non-idle connections.
 func (p *pool) clearAll(err error, serviceID *primitive.ObjectID) {
 	p.clearImpl(err, serviceID, true)
 }
 
-func (p *pool) interruptInUseConnections() {
-	for _, conn := range p.conns {
-		if conn.inUse && p.stale(conn) {
-			_ = conn.closeWithErr(poolClearedError{
-				err:     fmt.Errorf("interrupted"),
-				address: p.address,
-			})
-			_ = p.checkInWithCallback(conn, func() (reason, bool) {
-				if mustLogPoolMessage(p) {
-					keysAndValues := logger.KeyValues{
-						logger.KeyDriverConnectionID, conn.driverConnectionID,
-					}
-
-					logPoolMessage(p, logger.ConnectionCheckedIn, keysAndValues...)
+// interruptConnections interrupts the given connections.
+func (p *pool) interruptConnections(conns []*connection) {
+	for _, conn := range conns {
+		_ = conn.closeWithErr(poolClearedError{
+			err:     fmt.Errorf("interrupted"),
+			address: p.address,
+		})
+		_ = p.checkInWithCallback(conn, func() (reason, bool) {
+			if mustLogPoolMessage(p) {
+				keysAndValues := logger.KeyValues{
+					logger.KeyDriverConnectionID, conn.driverConnectionID,
 				}
-				if p.monitor != nil {
-					p.monitor.Event(&event.PoolEvent{
-						Type:         event.ConnectionCheckedIn,
-						ConnectionID: conn.driverConnectionID,
-						Address:      conn.addr.String(),
-					})
-				}
+				logPoolMessage(p, logger.ConnectionCheckedIn, keysAndValues...)
 			}
-			r, ok := connectionPerished(conn)
-			if ok {
-				r = reason{
-					loggerConn: logger.ReasonConnClosedStale,
-					event:      event.ReasonStale,
-				}
+			if p.monitor != nil {
+				p.monitor.Event(&event.PoolEvent{
+					Type:         event.ConnectionCheckedIn,
+					ConnectionID: conn.driverConnectionID,
+					Address:      conn.addr.String(),
+				})
+			}
+
+			r, ok := connectionPerished(conn)
+			if ok {
+				r = reason{
+					loggerConn: logger.ReasonConnClosedStale,
+					event:      event.ReasonStale,
 				}
-				return r, ok
-			})
-		}
+			}
+			return r, ok
+		})
 	}
 }
@@ -908,11 +906,9 @@ func (p *pool) interruptInUseConnections() {
 // "paused". If serviceID is nil, clear marks all connections as stale. If serviceID is not nil,
 // clear marks only connections associated with the given serviceID stale (for use in load balancer
 // mode).
-func (p *pool) clear(err error, serviceID *primitive.ObjectID) {
-	p.clearImpl(err, serviceID, false)
-}
-
-func (p *pool) clearImpl(err error, serviceID *primitive.ObjectID, interruptInUseConnections bool) {
+// If interruptAllConnections is true, this function calls interruptConnections to interrupt all
+// non-idle connections.
+func (p *pool) clearImpl(err error, serviceID *primitive.ObjectID, interruptAllConnections bool) {
 	if p.getState() == poolClosed {
 		return
 	}
@@ -953,15 +949,33 @@ func (p *pool) clearImpl(err error, serviceID *primitive.ObjectID, interruptInUseConnections bool) {
 			ServiceID: serviceID,
 			Error:     err,
 		}
-		if interruptInUseConnections {
+		if interruptAllConnections {
 			event.Interruption = true
 		}
 		p.monitor.Event(event)
 	}
 
 	p.removePerishedConns()
-	if interruptInUseConnections {
-		p.interruptInUseConnections()
+	if interruptAllConnections {
+		p.createConnectionsCond.L.Lock()
+		p.idleMu.Lock()
+
+		idleConns := make(map[*connection]bool, len(p.idleConns))
+		for _, idle := range p.idleConns {
+			idleConns[idle] = true
+		}
+
+		conns := make([]*connection, 0, len(p.conns))
+		for _, conn := range p.conns {
+			if _, ok := idleConns[conn]; !ok && p.stale(conn) {
+				conns = append(conns, conn)
+			}
+		}
+
+		p.idleMu.Unlock()
+		p.createConnectionsCond.L.Unlock()
+
+		p.interruptConnections(conns)
 	}
 
 	if serviceID == nil {
@@ -1093,11 +1107,7 @@ func (p *pool) createConnections(ctx context.Context, wg *sync.WaitGroup) {
 			return nil, nil, false
 		}
 
-		w.mu.Lock()
-		connOpts := w.connOpts
-		w.mu.Unlock()
-		connOpts = append(connOpts, p.connOpts...)
-		conn := newConnection(p.address, connOpts...)
+		conn := newConnection(p.address, p.connOpts...)
 		conn.pool = p
 		conn.driverConnectionID = atomic.AddInt64(&p.nextID, 1)
 		p.conns[conn.driverConnectionID] = conn
@@ -1311,8 +1321,6 @@ func compact(arr []*connection) []*connection {
 type wantConn struct {
 	ready chan struct{}
 
-	connOpts []ConnectionOption
-
 	mu   sync.Mutex // Guards conn, err
 	conn *connection
 	err  error
@@ -1349,7 +1357,6 @@ func (w *wantConn) tryDeliver(conn *connection, err error) bool {
 		panic("x/mongo/driver/topology: internal error: misuse of tryDeliver")
 	}
 
-	w.connOpts = w.connOpts[:0]
 	close(w.ready)
 
 	return true
@@ -1365,7 +1372,6 @@ func (w *wantConn) cancel(p *pool, err error) {
 
 	w.mu.Lock()
 	if w.conn == nil && w.err == nil {
-		w.connOpts = w.connOpts[:0]
 		close(w.ready) // catch misbehavior in future delivery
 	}
 	conn := w.conn
diff --git a/x/mongo/driver/topology/server_test.go b/x/mongo/driver/topology/server_test.go
index 79c4c67137..84c6481610 100644
--- a/x/mongo/driver/topology/server_test.go
+++ b/x/mongo/driver/topology/server_test.go
@@ -126,30 +126,26 @@ func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
 	return &timeoutConn{c, d.errors}, e
 }
 
-// TestServerHeartbeatTimeout tests timeout retry for GODRIVER-2577.
+// TestServerHeartbeatTimeout tests timeout retry and preemptive canceling.
 func TestServerHeartbeatTimeout(t *testing.T) {
-	if testing.Short() {
-		t.Skip("skipping integration test in short mode")
-	}
-
 	networkTimeoutError := &net.DNSError{
 		IsTimeout: true,
 	}
 
 	testCases := []struct {
-		desc              string
-		ioErrors          []error
-		expectPoolCleared bool
+		desc                 string
+		ioErrors             []error
+		expectInterruptions  int
 	}{
 		{
-			desc:              "one single timeout should not clear the pool",
-			ioErrors:          []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
-			expectPoolCleared: false,
+			desc:                 "one single timeout should not clear the pool",
+			ioErrors:             []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
+			expectInterruptions:  0,
 		},
 		{
-			desc:              "continuous timeouts should clear the pool",
-			ioErrors:          []error{nil, networkTimeoutError, networkTimeoutError, nil},
-			expectPoolCleared: true,
+			desc:                 "continuous timeouts should clear the pool with interruption",
+			ioErrors:             []error{nil, networkTimeoutError, networkTimeoutError, nil},
+			expectInterruptions:  1,
 		},
 	}
 	for _, tc := range testCases {
@@ -195,7 +191,8 @@ func TestServerHeartbeatTimeout(t *testing.T) {
 			)
 			require.NoError(t, server.Connect(nil))
 			wg.Wait()
-			assert.Equal(t, tc.expectPoolCleared, tpm.IsPoolCleared(), "expected pool cleared to be %v but was %v", tc.expectPoolCleared, tpm.IsPoolCleared())
+			interruptions := tpm.Interruptions()
+			assert.Equal(t, tc.expectInterruptions, interruptions, "expected %d interruptions but got %d", tc.expectInterruptions, interruptions)
 		})
 	}
 }
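For context, a brief usage sketch of the API split introduced above: clear keeps the existing behavior of marking connections stale, while clearAll additionally reports Interruption on the PoolCleared event and interrupts all non-idle connections. The wrapper function and the abandonInFlight parameter below are hypothetical, added only to illustrate the choice between the two entry points; they are not part of this patch.

// Hypothetical helper inside the topology package (illustration only).
// "abandonInFlight" stands in for whatever signal the caller uses to decide
// that checked-out connections must not be waited on (for example, repeated
// heartbeat timeouts).
func clearPoolOnError(p *pool, err error, serviceID *primitive.ObjectID, abandonInFlight bool) {
	if abandonInFlight {
		// Emits PoolCleared with Interruption set and interrupts all
		// non-idle connections via interruptConnections.
		p.clearAll(err, serviceID)
		return
	}
	// Marks connections stale without touching checked-out connections.
	p.clear(err, serviceID)
}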