diff --git a/go/vt/vtgateproxy/mysql_server.go b/go/vt/vtgateproxy/mysql_server.go index 559dea49608..0dc4a7ba862 100644 --- a/go/vt/vtgateproxy/mysql_server.go +++ b/go/vt/vtgateproxy/mysql_server.go @@ -99,17 +99,7 @@ func (ph *proxyHandler) NewConnection(c *mysql.Conn) { func (ph *proxyHandler) ComResetConnection(c *mysql.Conn) { ctx := context.Background() - session, err := ph.getSession(ctx, c) - if err != nil { - return - } - if session.SessionPb().InTransaction { - defer atomic.AddInt32(&busyConnections, -1) - } - err = ph.proxy.CloseSession(ctx, session) - if err != nil { - log.Errorf("Error happened in transaction rollback: %v", err) - } + ph.closeSession(ctx, c) } func (ph *proxyHandler) ConnectionClosed(c *mysql.Conn) { @@ -127,14 +117,7 @@ func (ph *proxyHandler) ConnectionClosed(c *mysql.Conn) { } else { ctx = context.Background() } - session, err := ph.getSession(ctx, c) - if err != nil { - return - } - if session.SessionPb().InTransaction { - defer atomic.AddInt32(&busyConnections, -1) - } - _ = ph.proxy.CloseSession(ctx, session) + ph.closeSession(ctx, c) } // Regexp to extract parent span id over the sql query @@ -377,6 +360,23 @@ func (ph *proxyHandler) getSession(ctx context.Context, c *mysql.Conn) (*vtgatec return session, nil } +func (ph *proxyHandler) closeSession(ctx context.Context, c *mysql.Conn) { + session, _ := c.ClientData.(*vtgateconn.VTGateSession) + if session == nil { + return // no active session + } + + if session.SessionPb().InTransaction { + defer atomic.AddInt32(&busyConnections, -1) + } + err := ph.proxy.CloseSession(ctx, session) + if err != nil { + log.Errorf("Error happened in transaction rollback: %v", err) + } + + c.ClientData = nil +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener var sigChan chan os.Signal