diff --git a/go/mysql/conn.go b/go/mysql/conn.go index f13c3b2242f..17dc9ccca89 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -199,6 +199,12 @@ type Conn struct { // enableQueryInfo controls whether we parse the INFO field in QUERY_OK packets // See: ConnParams.EnableQueryInfo enableQueryInfo bool + + // mu protects the fields below + mu sync.Mutex + // this is used to mark the connection to be closed so that the command phase for the connection can be stopped and + // the connection gets closed. + closing bool } // splitStatementFunciton is the function that is used to split the statement in case of a multi-statement query. @@ -899,6 +905,11 @@ func (c *Conn) handleNextCommand(handler Handler) bool { return false } + // before continue to process the packet, check if the connection should be closed or not. + if c.IsMarkedForClose() { + return false + } + switch data[0] { case ComQuit: c.recycleReadPacket() @@ -1634,3 +1645,21 @@ func (c *Conn) IsUnixSocket() bool { func (c *Conn) GetRawConn() net.Conn { return c.conn } + +// MarkForClose marks the connection for close. +func (c *Conn) MarkForClose() { + c.mu.Lock() + defer c.mu.Unlock() + c.closing = true +} + +// IsMarkedForClose return true if the connection should be closed. +func (c *Conn) IsMarkedForClose() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closing +} + +func (c *Conn) IsShuttingDown() bool { + return c.listener.isShutdown() +} diff --git a/go/mysql/server.go b/go/mysql/server.go index 4d65ce93a81..0e782847adb 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -526,7 +526,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti for { kontinue := c.handleNextCommand(l.handler) - if !kontinue { + // before going for next command check if the connection should be closed or not. + if !kontinue || c.IsMarkedForClose() { return } } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 4a6f45ce629..d42128c74fc 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -199,6 +199,12 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co } func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + session := vh.session(c) + if c.IsShuttingDown() && !session.InTransaction { + c.MarkForClose() + return mysql.NewSQLError(mysql.ERServerShutdown, mysql.SSNetError, "Server shutdown in progress") + } + ctx := context.Background() var cancel context.CancelFunc if mysqlQueryTimeout != 0 { @@ -226,7 +232,6 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq "VTGate MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) if !session.InTransaction { atomic.AddInt32(&busyConnections, 1) } @@ -561,11 +566,11 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys func shutdownMysqlProtocolAndDrain() { if mysqlListener != nil { - mysqlListener.Close() + mysqlListener.Shutdown() mysqlListener = nil } if mysqlUnixListener != nil { - mysqlUnixListener.Close() + mysqlUnixListener.Shutdown() mysqlUnixListener = nil } if sigChan != nil {