diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 7d45755bd11..0144c44e533 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -375,11 +375,11 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R } key := strings.ToLower(query) db.mu.Lock() - defer db.mu.Unlock() db.queryCalled[key]++ db.querylog = append(db.querylog, key) // Check if we should close the connection and provoke errno 2013. if db.shouldClose.Load() { + defer db.mu.Unlock() c.Close() //log error @@ -393,7 +393,9 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R // The driver may send this at connection time, and we don't want it to // interfere. if key == "set names utf8" || strings.HasPrefix(key, "set collation_connection = ") { - //log error + defer db.mu.Unlock() + + // log error if err := callback(&sqltypes.Result{}); err != nil { log.Errorf("callback failed : %v", err) } @@ -402,12 +404,14 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R // check if we should reject it. if err, ok := db.rejectedData[key]; ok { + db.mu.Unlock() return err } // Check explicit queries from AddQuery(). result, ok := db.data[key] if ok { + db.mu.Unlock() if f := result.BeforeFunc; f != nil { f() } @@ -418,6 +422,7 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R for _, pat := range db.patternData { if pat.expr.MatchString(query) { userCallback, ok := db.queryPatternUserCallback[pat.expr] + db.mu.Unlock() if ok { userCallback(query) } @@ -428,6 +433,8 @@ func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.R } } + defer db.mu.Unlock() + if db.neverFail.Load() { return callback(&sqltypes.Result{}) } diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn.go b/go/vt/vttablet/tabletserver/connpool/dbconn.go index c2fdf991056..7affcb0a7c0 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn.go @@ -21,6 +21,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" "vitess.io/vitess/go/pools" @@ -30,7 +31,6 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/sync2" "vitess.io/vitess/go/trace" "vitess.io/vitess/go/vt/dbconnpool" "vitess.io/vitess/go/vt/log" @@ -41,6 +41,8 @@ import ( vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) +const defaultKillTimeout = 5 * time.Second + // DBConn is a db connection for tabletserver. // It performs automatic reconnects as needed. // Its Execute function has a timeout that can kill @@ -52,7 +54,7 @@ type DBConn struct { pool *Pool dbaPool *dbconnpool.ConnectionPool stats *tabletenv.Stats - current sync2.AtomicString + current atomic.Pointer[string] timeCreated time.Time setting string resetSetting string @@ -60,6 +62,8 @@ type DBConn struct { // err will be set if a query is killed through a Kill. errmu sync.Mutex err error + + killTimeout time.Duration } // NewDBConn creates a new DBConn. It triggers a CheckMySQL if creation fails. @@ -78,6 +82,7 @@ func NewDBConn(ctx context.Context, cp *Pool, appParams dbconfigs.Connector) (*D info: appParams, pool: cp, dbaPool: cp.dbaPool, + killTimeout: defaultKillTimeout, timeCreated: time.Now(), stats: cp.env.Stats(), }, nil @@ -95,6 +100,7 @@ func NewDBConnNoPool(ctx context.Context, params dbconfigs.Connector, dbaPool *d dbaPool: dbaPool, pool: nil, timeCreated: time.Now(), + killTimeout: defaultKillTimeout, stats: tabletenv.NewStats(servenv.NewExporter("Temp", "Tablet")), } if setting == nil { @@ -157,8 +163,8 @@ func (dbc *DBConn) Exec(ctx context.Context, query string, maxrows int, wantfiel } func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, wantfields bool) (*sqltypes.Result, error) { - dbc.current.Set(query) - defer dbc.current.Set("") + dbc.current.Store(&query) + defer dbc.current.Store(nil) // Check if the context is already past its deadline before // trying to execute the query. @@ -168,19 +174,33 @@ func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, want default: } - defer dbc.stats.MySQLTimings.Record("Exec", time.Now()) - - done, wg := dbc.setDeadline(ctx) - qr, err := dbc.conn.ExecuteFetch(query, maxrows, wantfields) + now := time.Now() + defer dbc.stats.MySQLTimings.Record("Exec", now) - if done != nil { - close(done) - wg.Wait() + type execResult struct { + result *sqltypes.Result + err error } - if dbcerr := dbc.Err(); dbcerr != nil { - return nil, dbcerr + + ch := make(chan execResult) + go func() { + result, err := dbc.conn.ExecuteFetch(query, maxrows, wantfields) + ch <- execResult{result, err} + }() + + select { + case <-ctx.Done(): + killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout) + defer cancel() + + _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) + return nil, dbc.Err() + case r := <-ch: + if dbcErr := dbc.Err(); dbcErr != nil { + return nil, dbcErr + } + return r.result, r.err } - return qr, err } // ExecOnce executes the specified query, but does not retry on connection errors. @@ -260,22 +280,30 @@ func (dbc *DBConn) Stream(ctx context.Context, query string, callback func(*sqlt } func (dbc *DBConn) streamOnce(ctx context.Context, query string, callback func(*sqltypes.Result) error, alloc func() *sqltypes.Result, streamBufferSize int) error { - defer dbc.stats.MySQLTimings.Record("ExecStream", time.Now()) + dbc.current.Store(&query) + defer dbc.current.Store(nil) - dbc.current.Set(query) - defer dbc.current.Set("") + now := time.Now() + defer dbc.stats.MySQLTimings.Record("ExecStream", now) - done, wg := dbc.setDeadline(ctx) - err := dbc.conn.ExecuteStreamFetch(query, callback, alloc, streamBufferSize) + ch := make(chan error) + go func() { + ch <- dbc.conn.ExecuteStreamFetch(query, callback, alloc, streamBufferSize) + }() - if done != nil { - close(done) - wg.Wait() - } - if dbcerr := dbc.Err(); dbcerr != nil { - return dbcerr + select { + case <-ctx.Done(): + killCtx, cancel := context.WithTimeout(context.Background(), dbc.killTimeout) + defer cancel() + + _ = dbc.KillWithContext(killCtx, ctx.Err().Error(), time.Since(now)) + return dbc.Err() + case err := <-ch: + if dbcErr := dbc.Err(); dbcErr != nil { + return dbcErr + } + return err } - return err } // StreamOnce executes the query and streams the results. But, does not retry on connection errors. @@ -414,6 +442,16 @@ func (dbc *DBConn) Taint() { // and on the connection side. If no query is executing, it's a no-op. // Kill will also not kill a query more than once. func (dbc *DBConn) Kill(reason string, elapsed time.Duration) error { + return dbc.KillWithContext(context.Background(), reason, elapsed) +} + +// KillWithContext kills the currently executing query both on MySQL side +// and on the connection side. If no query is executing, it's a no-op. +// Kill will also not kill a query more than once. +func (dbc *DBConn) KillWithContext(ctx context.Context, reason string, elapsed time.Duration) error { + if cause := context.Cause(ctx); cause != nil { + return cause + } dbc.stats.KillCounters.Add("Queries", 1) log.Infof("Due to %s, elapsed time: %v, killing query ID %v %s", reason, elapsed, dbc.conn.ID(), dbc.CurrentForLogging()) @@ -424,25 +462,43 @@ func (dbc *DBConn) Kill(reason string, elapsed time.Duration) error { dbc.conn.Close() // Server side action. Kill the session. - killConn, err := dbc.dbaPool.Get(context.TODO()) + killConn, err := dbc.dbaPool.Get(ctx) if err != nil { log.Warningf("Failed to get conn from dba pool: %v", err) return err } defer killConn.Recycle() + + ch := make(chan error) sql := fmt.Sprintf("kill %d", dbc.conn.ID()) - _, err = killConn.ExecuteFetch(sql, 10000, false) - if err != nil { - log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), - dbc.CurrentForLogging(), err) - return err + go func() { + _, err := killConn.ExecuteFetch(sql, -1, false) + ch <- err + }() + + select { + case <-ctx.Done(): + killConn.Close() + + dbc.stats.InternalErrors.Add("HungQuery", 1) + log.Warningf("Query may be hung: %s", dbc.CurrentForLogging()) + + return context.Cause(ctx) + case err := <-ch: + if err != nil { + log.Errorf("Could not kill query ID %v %s: %v", dbc.conn.ID(), dbc.CurrentForLogging(), err) + return err + } + return nil } - return nil } // Current returns the currently executing query. func (dbc *DBConn) Current() string { - return dbc.current.Get() + if q := dbc.current.Load(); q != nil { + return *q + } + return "" } // ID returns the connection id. @@ -480,45 +536,6 @@ func (dbc *DBConn) reconnect(ctx context.Context) error { return nil } -// setDeadline starts a goroutine that will kill the currently executing query -// if the deadline is exceeded. It returns a channel and a waitgroup. After the -// query is done executing, the caller is required to close the done channel -// and wait for the waitgroup to make sure that the necessary cleanup is done. -func (dbc *DBConn) setDeadline(ctx context.Context) (chan bool, *sync.WaitGroup) { - if ctx.Done() == nil { - return nil, nil - } - done := make(chan bool) - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - startTime := time.Now() - select { - case <-ctx.Done(): - dbc.Kill(ctx.Err().Error(), time.Since(startTime)) - case <-done: - return - } - elapsed := time.Since(startTime) - - // Give 2x the elapsed time and some buffer as grace period - // for the query to get killed. - tmr2 := time.NewTimer(2*elapsed + 5*time.Second) - defer tmr2.Stop() - select { - case <-tmr2.C: - dbc.stats.InternalErrors.Add("HungQuery", 1) - log.Warningf("Query may be hung: %s", dbc.CurrentForLogging()) - case <-done: - return - } - <-done - log.Warningf("Hung query returned") - }() - return done, &wg -} - // CurrentForLogging applies transformations to the query making it suitable to log. // It applies sanitization rules based on tablet settings and limits the max length of // queries. diff --git a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go index 62ec0b6d12e..9e30d069fec 100644 --- a/go/vt/vttablet/tabletserver/connpool/dbconn_test.go +++ b/go/vt/vttablet/tabletserver/connpool/dbconn_test.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "strings" + "sync/atomic" "testing" "time" @@ -32,6 +33,8 @@ import ( "vitess.io/vitess/go/pools" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" ) func compareTimingCounts(t *testing.T, op string, delta int64, before, after map[string]int64) { @@ -290,6 +293,59 @@ func TestDBConnKill(t *testing.T) { } } +func TestDBKillWithContext(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + params := db.ConnParams() + connPool.Open(params, params, params) + defer connPool.Close() + dbConn, err := NewDBConn(context.Background(), connPool, params) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + query := fmt.Sprintf("kill %d", dbConn.ID()) + db.AddQuery(query, &sqltypes.Result{}) + db.SetBeforeFunc(query, func() { + // should take longer than our context deadline below. + time.Sleep(200 * time.Millisecond) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // KillWithContext should return context.DeadlineExceeded + err = dbConn.KillWithContext(ctx, "test kill", 0) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestDBKillWithContextDoneContext(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + params := db.ConnParams() + connPool.Open(params, params, params) + defer connPool.Close() + dbConn, err := NewDBConn(context.Background(), connPool, params) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + query := fmt.Sprintf("kill %d", dbConn.ID()) + db.AddRejectedQuery(query, errors.New("rejected")) + + contextErr := errors.New("context error") + ctx, cancel := context.WithCancelCause(context.Background()) + cancel(contextErr) // cancel the context immediately + + // KillWithContext should return the cancellation cause + err = dbConn.KillWithContext(ctx, "test kill", 0) + require.ErrorIs(t, err, contextErr) +} + // TestDBConnClose tests that an Exec returns immediately if a connection // is asynchronously killed (and closed) in the middle of an execution. func TestDBConnClose(t *testing.T) { @@ -518,3 +574,51 @@ func TestDBConnReApplySetting(t *testing.T) { db.VerifyAllExecutedOrFail() } + +func TestDBExecOnceKillTimeout(t *testing.T) { + db := fakesqldb.New(t) + defer db.Close() + connPool := newPool() + params := db.ConnParams() + connPool.Open(params, params, params) + defer connPool.Close() + dbConn, err := NewDBConn(context.Background(), connPool, params) + if dbConn != nil { + defer dbConn.Close() + } + require.NoError(t, err) + + // A very long running query that will be killed. + expectedQuery := "select 1" + var timestampQuery atomic.Int64 + db.AddQuery(expectedQuery, &sqltypes.Result{}) + db.SetBeforeFunc(expectedQuery, func() { + timestampQuery.Store(time.Now().UnixMicro()) + // should take longer than our context deadline below. + time.Sleep(1000 * time.Millisecond) + }) + + // We expect a kill-query to be fired, too. + // It should also run into a timeout. + var timestampKill atomic.Int64 + dbConn.killTimeout = 100 * time.Millisecond + db.AddQueryPatternWithCallback(`kill \d+`, &sqltypes.Result{}, func(string) { + timestampKill.Store(time.Now().UnixMicro()) + // should take longer than the configured kill timeout above. + time.Sleep(200 * time.Millisecond) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + result, err := dbConn.ExecOnce(ctx, "select 1", 1, false) + timeDone := time.Now() + + require.Error(t, err) + require.Equal(t, vtrpcpb.Code_CANCELED, vterrors.Code(err)) + require.Nil(t, result) + timeQuery := time.UnixMicro(timestampQuery.Load()) + timeKill := time.UnixMicro(timestampKill.Load()) + require.WithinDuration(t, timeQuery, timeKill, 150*time.Millisecond) + require.WithinDuration(t, timeKill, timeDone, 150*time.Millisecond) +} diff --git a/go/vt/vttablet/tabletserver/tabletserver_test.go b/go/vt/vttablet/tabletserver/tabletserver_test.go index ac6b20cab8a..b524cc117e4 100644 --- a/go/vt/vttablet/tabletserver/tabletserver_test.go +++ b/go/vt/vttablet/tabletserver/tabletserver_test.go @@ -992,6 +992,20 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { db.SetBeforeFunc("update test_table set name_string = 'tx1' where pk = 1 and `name` = 1 limit 10001", func() { close(tx1Started) + + // Wait for other queries to be pending. + <-allQueriesPending + }) + + db.SetBeforeFunc("update test_table set name_string = 'tx2' where pk = 1 and `name` = 1 limit 10001", + func() { + // Wait for other queries to be pending. + <-allQueriesPending + }) + + db.SetBeforeFunc("update test_table set name_string = 'tx3' where pk = 1 and `name` = 1 limit 10001", + func() { + // Wait for other queries to be pending. <-allQueriesPending }) @@ -1060,6 +1074,8 @@ func TestSerializeTransactionsSameRow_ConcurrentTransactions(t *testing.T) { // to allow more than connection attempt at a time. err := waitForTxSerializationPendingQueries(tsv, "test_table where pk = 1 and `name` = 1", 3) require.NoError(t, err) + + // Signal that all queries are pending now. close(allQueriesPending) wg.Wait()