diff --git a/sql/database.go b/sql/database.go index e527ef27a4..3006e1cda5 100644 --- a/sql/database.go +++ b/sql/database.go @@ -609,11 +609,13 @@ func (db *sqliteDatabase) getTx(ctx context.Context, initstmt string) (*sqliteTx if db.closed { return nil, ErrClosed } - conn := db.getConn(ctx) + conCtx, cancel := context.WithCancel(ctx) + conn := db.getConn(conCtx) if conn == nil { + cancel() return nil, ErrNoConnection } - tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn} + tx := &sqliteTx{queryCache: db.queryCache, db: db, conn: conn, freeConn: cancel} if err := tx.begin(initstmt); err != nil { return nil, err } @@ -998,6 +1000,7 @@ func exec(conn *sqlite.Conn, query string, encoder Encoder, decoder Decoder) (in encoder(stmt) } defer stmt.ClearBindings() + defer stmt.Reset() rows := 0 for { @@ -1027,6 +1030,7 @@ type sqliteTx struct { *queryCache db *sqliteDatabase conn *sqlite.Conn + freeConn func() committed bool err error } @@ -1055,10 +1059,12 @@ func (tx *sqliteTx) Commit() error { func (tx *sqliteTx) Release() error { defer tx.db.pool.Put(tx.conn) if tx.committed { + tx.freeConn() return nil } stmt := tx.conn.Prep("ROLLBACK") _, tx.err = stmt.Step() + tx.freeConn() return mapSqliteError(tx.err) } diff --git a/sql/database_test.go b/sql/database_test.go index 0690f618a9..c97df7798b 100644 --- a/sql/database_test.go +++ b/sql/database_test.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strings" "testing" + "time" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" @@ -16,6 +17,31 @@ import ( "go.uber.org/zap/zaptest/observer" ) +func Test_ConReturnedToPool(t *testing.T) { + db := InMemory( + WithLogger(zaptest.NewLogger(t)), + WithConnections(1), + WithDatabaseSchema(&Schema{ + Script: `CREATE TABLE testing1 ( + id varchar primary key, + field int + );`, + }), + WithNoCheckSchemaDrift(), + ) + + require.Panics(t, func() { + db.Exec("select 1", nil, func(stmt *Statement) bool { + panic("decoder panic") + }) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + con := db.pool.Get(ctx) + require.NotNil(t, con, "connection was not returned") +} + func Test_Transaction_Isolation(t *testing.T) { db := InMemory( WithLogger(zaptest.NewLogger(t)),