diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index f880b6c0412..df1911a6759 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -91,6 +91,8 @@ type vtgateHandler struct { vtg *VTGate connections map[*mysql.Conn]bool + + shutdown atomic.Bool } func newVtgateHandler(vtg *VTGate) *vtgateHandler { @@ -147,6 +149,14 @@ func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) { _ = vh.vtg.CloseSession(ctx, session) } +func (vh *vtgateHandler) Shutdown() bool { + return vh.shutdown.CompareAndSwap(false, true) +} + +func (vh *vtgateHandler) IsShutdown() bool { + return vh.shutdown.Load() +} + // Regexp to extract parent span id over the sql query var r = regexp.MustCompile(`/\*VT_SPAN_CONTEXT=(.*)\*/`) @@ -182,6 +192,13 @@ func startSpan(ctx context.Context, query, label string) (trace.Span, context.Co } func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + session := vh.session(c) + if vh.IsShutdown() { + if !session.InTransaction { + return mysql.NewSQLError(mysql.ERServerShutdown, mysql.SSNetError, "VTGate shutdown in progress") + } + } + ctx := context.Background() var cancel context.CancelFunc if *mysqlQueryTimeout != 0 { @@ -209,7 +226,6 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq "VTGate MySQL Connector" /* subcomponent: part of the client */) ctx = callerid.NewContext(ctx, ef, im) - session := vh.session(c) if !session.InTransaction { atomic.AddInt32(&busyConnections, 1) } @@ -535,12 +551,14 @@ func newMysqlUnixSocket(address string, authServer mysql.AuthServer, handler mys } func shutdownMysqlProtocolAndDrain() { + vtgateHandle.Shutdown() + if mysqlListener != nil { - mysqlListener.Close() + mysqlListener.Shutdown() mysqlListener = nil } if mysqlUnixListener != nil { - mysqlUnixListener.Close() + mysqlUnixListener.Shutdown() mysqlUnixListener = nil } if sigChan != nil {