Skip to content

Commit

Permalink
Fix possible race in MySQL startup and vttablet in parallel (#15538)
Browse files Browse the repository at this point in the history
Signed-off-by: Manan Gupta <manan@planetscale.com>
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
Co-authored-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
GuptaManan100 and dbussink committed Mar 21, 2024
1 parent ae066f6 commit 2dfc0dc
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 33 deletions.
8 changes: 8 additions & 0 deletions go/vt/mysqlctl/backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ func ShouldRestore(ctx context.Context, params RestoreParams) (bool, error) {
if err := params.Mysqld.Wait(ctx, params.Cnf); err != nil {
return false, err
}
if err := params.Mysqld.WaitForDBAGrants(ctx, DbaGrantWaitTime); err != nil {
params.Logger.Errorf("error waiting for the grants: %v", err)
return false, err
}
return checkNoDB(ctx, params.Mysqld, params.DbName)
}

Expand Down Expand Up @@ -403,6 +407,10 @@ func Restore(ctx context.Context, params RestoreParams) (*BackupManifest, error)
params.Logger.Errorf("mysqld is not running: %v", err)
return nil, err
}
if err = params.Mysqld.WaitForDBAGrants(ctx, DbaGrantWaitTime); err != nil {
params.Logger.Errorf("error waiting for the grants: %v", err)
return nil, err
}
// Since this is an empty database make sure we start replication at the beginning
if err := params.Mysqld.ResetReplication(ctx); err != nil {
params.Logger.Errorf("error resetting replication: %v. Continuing", err)
Expand Down
4 changes: 4 additions & 0 deletions go/vt/mysqlctl/fakemysqldaemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ func (fmd *FakeMysqlDaemon) Wait(ctx context.Context, cnf *Mycnf) error {
return nil
}

// WaitForDBAGrants is part of the MysqlDaemon interface.
// The fake does not wait for anything and always returns nil.
func (fmd *FakeMysqlDaemon) WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error) {
return nil
}

// GetMysqlPort is part of the MysqlDaemon interface.
func (fmd *FakeMysqlDaemon) GetMysqlPort() (int32, error) {
if fmd.MysqlPort.Load() == -1 {
Expand Down
1 change: 1 addition & 0 deletions go/vt/mysqlctl/mysql_daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type MysqlDaemon interface {
ReadBinlogFilesTimestamps(ctx context.Context, req *mysqlctlpb.ReadBinlogFilesTimestampsRequest) (*mysqlctlpb.ReadBinlogFilesTimestampsResponse, error)
ReinitConfig(ctx context.Context, cnf *Mycnf) error
Wait(ctx context.Context, cnf *Mycnf) error
WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error)

// GetMysqlPort returns the current port mysql is listening on.
GetMysqlPort() (int32, error)
Expand Down
42 changes: 39 additions & 3 deletions go/vt/mysqlctl/mysqld.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ import (
"vitess.io/vitess/go/vt/hook"
"vitess.io/vitess/go/vt/log"
"vitess.io/vitess/go/vt/mysqlctl/mysqlctlclient"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"

mysqlctlpb "vitess.io/vitess/go/vt/proto/mysqlctl"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
"vitess.io/vitess/go/vt/servenv"
"vitess.io/vitess/go/vt/vterrors"
)

// The string we expect before the MySQL version number
Expand All @@ -68,6 +67,9 @@ const versionStringPrefix = "Ver "
// How many bytes from MySQL error log to sample for error messages
const maxLogFileSampleSize = 4096

// DbaGrantWaitTime is the amount of time to wait for the grants to have applied.
// It is passed as the waitTime argument to WaitForDBAGrants.
const DbaGrantWaitTime = 10 * time.Second

var (
// DisableActiveReparents is a flag to disable active
// reparents for safety reasons. It is used in three places:
Expand Down Expand Up @@ -514,6 +516,40 @@ func (mysqld *Mysqld) Wait(ctx context.Context, cnf *Mycnf) error {
return mysqld.wait(ctx, cnf, params)
}

// WaitForDBAGrants waits for the grants to have applied for all the users.
//
// It repeatedly opens a connection with the DBA credentials and runs
// SHOW GRANTS until the first row of the output mentions either SUPER or
// ALL PRIVILEGES, retrying every 100ms.
//
// A waitTime of zero disables the check and returns nil immediately.
// An error is returned when waitTime elapses before the privileges are seen,
// or when the caller's ctx is done before then (in which case the returned
// error wraps ctx.Err()).
func (mysqld *Mysqld) WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error) {
	if waitTime == 0 {
		return nil
	}
	// Keep a handle on the caller's context so that a cancellation coming from
	// the caller can be told apart from our own timeout below.
	callerCtx := ctx
	ctx, cancel := context.WithTimeout(ctx, waitTime)
	defer cancel()

	// The ticker paces the retries; stopping it on return releases its resources.
	retry := time.NewTicker(100 * time.Millisecond)
	defer retry.Stop()

	for {
		conn, connErr := dbconnpool.NewDBConnection(ctx, mysqld.dbcfgs.DbaConnector())
		if connErr == nil {
			res, fetchErr := conn.ExecuteFetch("SHOW GRANTS", 1000, false)
			conn.Close()
			if fetchErr != nil {
				log.Errorf("Error running SHOW GRANTS - %v", fetchErr)
			}
			if fetchErr == nil && res != nil && len(res.Rows) > 0 && len(res.Rows[0]) > 0 {
				privileges := res.Rows[0][0].ToString()
				// In MySQL 8.0, all the privileges are listed out explicitly, so we can search for SUPER in the output.
				// In MySQL 5.7, all the privileges are not listed explicitly, instead ALL PRIVILEGES is written, so we search for that too.
				if strings.Contains(privileges, "SUPER") || strings.Contains(privileges, "ALL PRIVILEGES") {
					return nil
				}
			}
		}

		// Wait for the next retry slot, but give up as soon as the context is
		// done instead of sleeping through a cancellation.
		select {
		case <-ctx.Done():
			if callerErr := callerCtx.Err(); callerErr != nil {
				return fmt.Errorf("waiting for the dba user to have the required permissions: %w", callerErr)
			}
			return fmt.Errorf("timed out after %v waiting for the dba user to have the required permissions", waitTime)
		case <-retry.C:
		}
	}
}

// wait is the internal version of Wait, that takes credentials.
func (mysqld *Mysqld) wait(ctx context.Context, cnf *Mycnf, params *mysql.ConnParams) error {
log.Infof("Waiting for mysqld socket file (%v) to be ready...", cnf.SocketFile)
Expand Down
33 changes: 3 additions & 30 deletions go/vt/vttablet/tabletmanager/tm_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ import (
const (
// Query rules from denylist
denyListQueryList string = "DenyListQueryRules"
dbaGrantWaitTime = 10 * time.Second
)

var (
Expand Down Expand Up @@ -424,7 +423,7 @@ func (tm *TabletManager) Start(tablet *topodatapb.Tablet, config *tabletenv.Tabl
}

// Make sure we have the correct privileges for the DBA user before we start the state manager.
err = tm.waitForDBAGrants(config, dbaGrantWaitTime)
err = tm.waitForDBAGrants(config, mysqlctl.DbaGrantWaitTime)
if err != nil {
return err
}
Expand Down Expand Up @@ -822,7 +821,7 @@ func (tm *TabletManager) handleRestore(ctx context.Context, config *tabletenv.Ta
}

// Make sure we have the correct privileges for the DBA user before we start the state manager.
err := tm.waitForDBAGrants(config, dbaGrantWaitTime)
err := tm.waitForDBAGrants(config, mysqlctl.DbaGrantWaitTime)
if err != nil {
log.Exitf("Failed waiting for DBA grants: %v", err)
}
Expand All @@ -849,33 +848,7 @@ func (tm *TabletManager) waitForDBAGrants(config *tabletenv.TabletConfig, waitTi
if config == nil || config.DB.HasGlobalSettings() || waitTime == 0 {
return nil
}
timer := time.NewTimer(waitTime)
ctx, cancel := context.WithTimeout(context.Background(), waitTime)
defer cancel()
for {
conn, connErr := dbconnpool.NewDBConnection(ctx, config.DB.DbaConnector())
if connErr == nil {
res, fetchErr := conn.ExecuteFetch("SHOW GRANTS", 1000, false)
conn.Close()
if fetchErr != nil {
log.Errorf("Error running SHOW GRANTS - %v", fetchErr)
}
if fetchErr == nil && res != nil && len(res.Rows) > 0 && len(res.Rows[0]) > 0 {
privileges := res.Rows[0][0].ToString()
// In MySQL 8.0, all the privileges are listed out explicitly, so we can search for SUPER in the output.
// In MySQL 5.7, all the privileges are not listed explicitly, instead ALL PRIVILEGES is written, so we search for that too.
if strings.Contains(privileges, "SUPER") || strings.Contains(privileges, "ALL PRIVILEGES") {
return nil
}
}
}
select {
case <-timer.C:
return fmt.Errorf("timed out after %v waiting for the dba user to have the required permissions", waitTime)
default:
time.Sleep(100 * time.Millisecond)
}
}
return tm.MysqlDaemon.WaitForDBAGrants(context.Background(), waitTime)
}

func (tm *TabletManager) exportStats() {
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vttablet/tabletmanager/tm_init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -858,8 +858,13 @@ func TestWaitForDBAGrants(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
config, cleanup := tt.setupFunc(t)
defer cleanup()
var dm mysqlctl.MysqlDaemon
if config != nil {
dm = mysqlctl.NewMysqld(config.DB)
}
tm := TabletManager{
_waitForGrantsComplete: make(chan struct{}),
MysqlDaemon: dm,
}
err := tm.waitForDBAGrants(config, tt.waitTime)
if tt.errWanted == "" {
Expand Down

0 comments on commit 2dfc0dc

Please sign in to comment.