From f4609f8c1cf3c98e2e821e5f8e3a922acc2106fa Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 17 Apr 2024 10:48:14 +0200 Subject: [PATCH] Fix the race condition during vttablet startup This avoids the problem where the connection pool is poisoned when we check for the MySQL port, by not using the pool in the first place. We only ever run this once at startup, so we can create a new connection here and then dispose of it once we've retrieved the port. That way we know the connection pool is still clean and doesn't have any problems. Fixes #15730 Signed-off-by: Dirkjan Bussink --- go/test/endtoend/utils/mysql_test.go | 29 +++++++++++++++++++------ go/vt/mysqlctl/fakemysqldaemon.go | 2 +- go/vt/mysqlctl/mysql_daemon.go | 2 +- go/vt/mysqlctl/mysqld.go | 6 ++++- go/vt/mysqlctl/replication.go | 18 +++++++++++++-- go/vt/mysqlctl/replication_test.go | 7 ++++-- go/vt/vttablet/tabletmanager/tm_init.go | 16 ++++++++++---- 7 files changed, 62 insertions(+), 18 deletions(-) diff --git a/go/test/endtoend/utils/mysql_test.go b/go/test/endtoend/utils/mysql_test.go index ae550e34864..c8c09a3f979 100644 --- a/go/test/endtoend/utils/mysql_test.go +++ b/go/test/endtoend/utils/mysql_test.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -129,7 +130,9 @@ func TestSetSuperReadOnlyMySQL(t *testing.T) { func TestGetMysqlPort(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) // Expected port should be one less than the port returned by GetAndReservePort // As we are calling this second time to get port @@ -161,7 +164,9 @@ func TestReplicationStatus(t *testing.T) { conn, err := mysql.Connect(ctx, &mysqlParams) require.NoError(t, err) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 
5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -234,7 +239,9 @@ func TestSetReplicationPosition(t *testing.T) { func TestSetAndResetReplication(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -387,7 +394,9 @@ func TestWaitForReplicationStart(t *testing.T) { err := mysqlctl.WaitForReplicationStart(mysqld, 1) assert.ErrorContains(t, err, "no replication status") - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -407,7 +416,9 @@ func TestStartReplication(t *testing.T) { err := mysqld.StartReplication(map[string]string{}) assert.ErrorContains(t, err, "The server is not configured as replica") - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -425,7 +436,9 @@ func TestStartReplication(t *testing.T) { func TestStopReplication(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -449,7 +462,9 @@ func TestStopReplication(t *testing.T) { func TestStopSQLThread(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" diff --git a/go/vt/mysqlctl/fakemysqldaemon.go 
b/go/vt/mysqlctl/fakemysqldaemon.go index 94c1f7f52c1..e0d9a6c1252 100644 --- a/go/vt/mysqlctl/fakemysqldaemon.go +++ b/go/vt/mysqlctl/fakemysqldaemon.go @@ -273,7 +273,7 @@ func (fmd *FakeMysqlDaemon) WaitForDBAGrants(ctx context.Context, waitTime time. } // GetMysqlPort is part of the MysqlDaemon interface. -func (fmd *FakeMysqlDaemon) GetMysqlPort() (int32, error) { +func (fmd *FakeMysqlDaemon) GetMysqlPort(ctx context.Context) (int32, error) { if fmd.MysqlPort.Load() == -1 { return 0, fmt.Errorf("FakeMysqlDaemon.GetMysqlPort returns an error") } diff --git a/go/vt/mysqlctl/mysql_daemon.go b/go/vt/mysqlctl/mysql_daemon.go index 4829af3d4f7..cb9882e7052 100644 --- a/go/vt/mysqlctl/mysql_daemon.go +++ b/go/vt/mysqlctl/mysql_daemon.go @@ -43,7 +43,7 @@ type MysqlDaemon interface { WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error) // GetMysqlPort returns the current port mysql is listening on. - GetMysqlPort() (int32, error) + GetMysqlPort(ctx context.Context) (int32, error) // GetServerID returns the servers ID. 
GetServerID(ctx context.Context) (uint32, error) diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index a1f6e257887..433ccc7de64 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -521,11 +521,15 @@ func (mysqld *Mysqld) WaitForDBAGrants(ctx context.Context, waitTime time.Durati if waitTime == 0 { return nil } + params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() + if err != nil { + return err + } timer := time.NewTimer(waitTime) ctx, cancel := context.WithTimeout(ctx, waitTime) defer cancel() for { - conn, connErr := dbconnpool.NewDBConnection(ctx, mysqld.dbcfgs.DbaConnector()) + conn, connErr := mysql.Connect(ctx, params) if connErr == nil { res, fetchErr := conn.ExecuteFetch("SHOW GRANTS", 1000, false) conn.Close() diff --git a/go/vt/mysqlctl/replication.go b/go/vt/mysqlctl/replication.go index e53caac593d..8603b172606 100644 --- a/go/vt/mysqlctl/replication.go +++ b/go/vt/mysqlctl/replication.go @@ -29,6 +29,7 @@ import ( "strings" "time" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/replication" "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/vt/hook" @@ -174,8 +175,21 @@ func (mysqld *Mysqld) RestartReplication(hookExtraEnv map[string]string) error { } // GetMysqlPort returns mysql port -func (mysqld *Mysqld) GetMysqlPort() (int32, error) { - qr, err := mysqld.FetchSuperQuery(context.TODO(), "SHOW VARIABLES LIKE 'port'") +func (mysqld *Mysqld) GetMysqlPort(ctx context.Context) (int32, error) { + // We can not use the connection pool here. This check runs very early + // during MySQL startup when we still might be loading things like grants. + // This means we need to use an isolated connection to avoid poisoning the + // DBA connection pool for further queries. 
+ params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() + if err != nil { + return 0, err + } + conn, err := mysql.Connect(ctx, params) + if err != nil { + return 0, err + } + defer conn.Close() + qr, err := conn.ExecuteFetch("SHOW VARIABLES LIKE 'port'", 1, false) if err != nil { return 0, err } diff --git a/go/vt/mysqlctl/replication_test.go b/go/vt/mysqlctl/replication_test.go index e171379f668..d117a96ab89 100644 --- a/go/vt/mysqlctl/replication_test.go +++ b/go/vt/mysqlctl/replication_test.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -133,12 +134,14 @@ func TestGetMysqlPort(t *testing.T) { testMysqld := NewMysqld(dbc) defer testMysqld.Close() - res, err := testMysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + res, err := testMysqld.GetMysqlPort(ctx) assert.Equal(t, int32(12), res) assert.NoError(t, err) db.AddQuery("SHOW VARIABLES LIKE 'port'", &sqltypes.Result{}) - res, err = testMysqld.GetMysqlPort() + res, err = testMysqld.GetMysqlPort(ctx) assert.ErrorContains(t, err, "no port variable in mysql") assert.Equal(t, int32(0), res) } diff --git a/go/vt/vttablet/tabletmanager/tm_init.go b/go/vt/vttablet/tabletmanager/tm_init.go index efb6c5e878f..5a4d7bcdb16 100644 --- a/go/vt/vttablet/tabletmanager/tm_init.go +++ b/go/vt/vttablet/tabletmanager/tm_init.go @@ -366,7 +366,7 @@ func (tm *TabletManager) Start(tablet *topodatapb.Tablet, config *tabletenv.Tabl if err := tm.checkPrimaryShip(ctx, si); err != nil { return err } - if err := tm.checkMysql(); err != nil { + if err := tm.checkMysql(ctx); err != nil { return err } if err := tm.initTablet(ctx); err != nil { @@ -702,7 +702,7 @@ func (tm *TabletManager) checkPrimaryShip(ctx context.Context, si *topo.ShardInf return nil } -func (tm *TabletManager) checkMysql() error { +func (tm *TabletManager) checkMysql(ctx context.Context) error { appConfig, err := 
tm.DBConfigs.AppWithDB().MysqlParams() if err != nil { return err @@ -717,7 +717,7 @@ func (tm *TabletManager) checkMysql() error { tm.tmState.UpdateTablet(func(tablet *topodatapb.Tablet) { tablet.MysqlHostname = tablet.Hostname }) - mysqlPort, err := tm.MysqlDaemon.GetMysqlPort() + mysqlPort, err := tm.MysqlDaemon.GetMysqlPort(ctx) if err != nil { log.Warningf("Cannot get current mysql port, will keep retrying every %v: %v", mysqlPortRetryInterval, err) go tm.findMysqlPort(mysqlPortRetryInterval) @@ -730,10 +730,18 @@ func (tm *TabletManager) checkMysql() error { return nil } +const portCheckTimeout = 5 * time.Second + +func (tm *TabletManager) getMysqlPort() (int32, error) { + ctx, cancel := context.WithTimeout(context.Background(), portCheckTimeout) + defer cancel() + return tm.MysqlDaemon.GetMysqlPort(ctx) +} + func (tm *TabletManager) findMysqlPort(retryInterval time.Duration) { for { time.Sleep(retryInterval) - mport, err := tm.MysqlDaemon.GetMysqlPort() + mport, err := tm.getMysqlPort() if err != nil || mport == 0 { continue }