diff --git a/go/test/endtoend/utils/mysql_test.go b/go/test/endtoend/utils/mysql_test.go index ae550e34864..c8c09a3f979 100644 --- a/go/test/endtoend/utils/mysql_test.go +++ b/go/test/endtoend/utils/mysql_test.go @@ -21,6 +21,7 @@ import ( "fmt" "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -129,7 +130,9 @@ func TestSetSuperReadOnlyMySQL(t *testing.T) { func TestGetMysqlPort(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) // Expected port should be one less than the port returned by GetAndReservePort // As we are calling this second time to get port @@ -161,7 +164,9 @@ func TestReplicationStatus(t *testing.T) { conn, err := mysql.Connect(ctx, &mysqlParams) require.NoError(t, err) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -234,7 +239,9 @@ func TestSetReplicationPosition(t *testing.T) { func TestSetAndResetReplication(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -387,7 +394,9 @@ func TestWaitForReplicationStart(t *testing.T) { err := mysqlctl.WaitForReplicationStart(mysqld, 1) assert.ErrorContains(t, err, "no replication status") - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -407,7 +416,9 @@ func TestStartReplication(t *testing.T) { err := mysqld.StartReplication(map[string]string{}) assert.ErrorContains(t, err, "The server is not configured as replica") - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -425,7 +436,9 @@ func TestStartReplication(t *testing.T) { func TestStopReplication(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" @@ -449,7 +462,9 @@ func TestStopReplication(t *testing.T) { func TestStopSQLThread(t *testing.T) { require.NotNil(t, mysqld) - port, err := mysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + port, err := mysqld.GetMysqlPort(ctx) require.NoError(t, err) host := "localhost" diff --git a/go/vt/mysqlctl/fakemysqldaemon.go b/go/vt/mysqlctl/fakemysqldaemon.go index 94c1f7f52c1..e0d9a6c1252 100644 --- a/go/vt/mysqlctl/fakemysqldaemon.go +++ b/go/vt/mysqlctl/fakemysqldaemon.go @@ -273,7 +273,7 @@ func (fmd *FakeMysqlDaemon) WaitForDBAGrants(ctx context.Context, waitTime time. } // GetMysqlPort is part of the MysqlDaemon interface. -func (fmd *FakeMysqlDaemon) GetMysqlPort() (int32, error) { +func (fmd *FakeMysqlDaemon) GetMysqlPort(ctx context.Context) (int32, error) { if fmd.MysqlPort.Load() == -1 { return 0, fmt.Errorf("FakeMysqlDaemon.GetMysqlPort returns an error") } diff --git a/go/vt/mysqlctl/mysql_daemon.go b/go/vt/mysqlctl/mysql_daemon.go index 4829af3d4f7..cb9882e7052 100644 --- a/go/vt/mysqlctl/mysql_daemon.go +++ b/go/vt/mysqlctl/mysql_daemon.go @@ -43,7 +43,7 @@ type MysqlDaemon interface { WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error) // GetMysqlPort returns the current port mysql is listening on. - GetMysqlPort() (int32, error) + GetMysqlPort(ctx context.Context) (int32, error) // GetServerID returns the servers ID. GetServerID(ctx context.Context) (uint32, error) diff --git a/go/vt/mysqlctl/mysqld.go b/go/vt/mysqlctl/mysqld.go index a1f6e257887..433ccc7de64 100644 --- a/go/vt/mysqlctl/mysqld.go +++ b/go/vt/mysqlctl/mysqld.go @@ -521,11 +521,15 @@ func (mysqld *Mysqld) WaitForDBAGrants(ctx context.Context, waitTime time.Durati if waitTime == 0 { return nil } + params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() + if err != nil { + return err + } timer := time.NewTimer(waitTime) ctx, cancel := context.WithTimeout(ctx, waitTime) defer cancel() for { - conn, connErr := dbconnpool.NewDBConnection(ctx, mysqld.dbcfgs.DbaConnector()) + conn, connErr := mysql.Connect(ctx, params) if connErr == nil { res, fetchErr := conn.ExecuteFetch("SHOW GRANTS", 1000, false) conn.Close() diff --git a/go/vt/mysqlctl/replication.go b/go/vt/mysqlctl/replication.go index e53caac593d..8603b172606 100644 --- a/go/vt/mysqlctl/replication.go +++ b/go/vt/mysqlctl/replication.go @@ -29,6 +29,7 @@ import ( "strings" "time" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/replication" "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/vt/hook" @@ -174,8 +175,21 @@ func (mysqld *Mysqld) RestartReplication(hookExtraEnv map[string]string) error { } // GetMysqlPort returns mysql port -func (mysqld *Mysqld) GetMysqlPort() (int32, error) { - qr, err := mysqld.FetchSuperQuery(context.TODO(), "SHOW VARIABLES LIKE 'port'") +func (mysqld *Mysqld) GetMysqlPort(ctx context.Context) (int32, error) { + // We can not use the connection pool here. This check runs very early + // during MySQL startup when we still might be loading things like grants. + // This means we need to use an isolated connection to avoid poisoning the + // DBA connection pool for further queries. + params, err := mysqld.dbcfgs.DbaConnector().MysqlParams() + if err != nil { + return 0, err + } + conn, err := mysql.Connect(ctx, params) + if err != nil { + return 0, err + } + defer conn.Close() + qr, err := conn.ExecuteFetch("SHOW VARIABLES LIKE 'port'", 1, false) if err != nil { return 0, err } diff --git a/go/vt/mysqlctl/replication_test.go b/go/vt/mysqlctl/replication_test.go index e171379f668..d117a96ab89 100644 --- a/go/vt/mysqlctl/replication_test.go +++ b/go/vt/mysqlctl/replication_test.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -133,12 +134,14 @@ func TestGetMysqlPort(t *testing.T) { testMysqld := NewMysqld(dbc) defer testMysqld.Close() - res, err := testMysqld.GetMysqlPort() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + res, err := testMysqld.GetMysqlPort(ctx) assert.Equal(t, int32(12), res) assert.NoError(t, err) db.AddQuery("SHOW VARIABLES LIKE 'port'", &sqltypes.Result{}) - res, err = testMysqld.GetMysqlPort() + res, err = testMysqld.GetMysqlPort(ctx) assert.ErrorContains(t, err, "no port variable in mysql") assert.Equal(t, int32(0), res) } diff --git a/go/vt/vttablet/tabletmanager/tm_init.go b/go/vt/vttablet/tabletmanager/tm_init.go index efb6c5e878f..5a4d7bcdb16 100644 --- a/go/vt/vttablet/tabletmanager/tm_init.go +++ b/go/vt/vttablet/tabletmanager/tm_init.go @@ -366,7 +366,7 @@ func (tm *TabletManager) Start(tablet *topodatapb.Tablet, config *tabletenv.Tabl if err := tm.checkPrimaryShip(ctx, si); err != nil { return err } - if err := tm.checkMysql(); err != nil { + if err := tm.checkMysql(ctx); err != nil { return err } if err := tm.initTablet(ctx); err != nil { @@ -702,7 +702,7 @@ func (tm *TabletManager) checkPrimaryShip(ctx context.Context, si *topo.ShardInf return nil } -func (tm *TabletManager) checkMysql() error { +func (tm *TabletManager) checkMysql(ctx context.Context) error { appConfig, err := tm.DBConfigs.AppWithDB().MysqlParams() if err != nil { return err @@ -717,7 +717,7 @@ func (tm *TabletManager) checkMysql() error { tm.tmState.UpdateTablet(func(tablet *topodatapb.Tablet) { tablet.MysqlHostname = tablet.Hostname }) - mysqlPort, err := tm.MysqlDaemon.GetMysqlPort() + mysqlPort, err := tm.MysqlDaemon.GetMysqlPort(ctx) if err != nil { log.Warningf("Cannot get current mysql port, will keep retrying every %v: %v", mysqlPortRetryInterval, err) go tm.findMysqlPort(mysqlPortRetryInterval) @@ -730,10 +730,18 @@ func (tm *TabletManager) checkMysql() error { return nil } +const portCheckTimeout = 5 * time.Second + +func (tm *TabletManager) getMysqlPort() (int32, error) { + ctx, cancel := context.WithTimeout(context.Background(), portCheckTimeout) + defer cancel() + return tm.MysqlDaemon.GetMysqlPort(ctx) +} + func (tm *TabletManager) findMysqlPort(retryInterval time.Duration) { for { time.Sleep(retryInterval) - mport, err := tm.MysqlDaemon.GetMysqlPort() + mport, err := tm.getMysqlPort() if err != nil || mport == 0 { continue }