Skip to content

Commit

Permalink
Merge pull request #23 from chrisseto/errors
Browse files Browse the repository at this point in the history
return errors in favor of panicking, when possible
  • Loading branch information
chrisseto authored Aug 24, 2022
2 parents eacfd65 + 5b4bb61 commit b3b04d8
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 50 deletions.
28 changes: 20 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ func (c *proxyConn) ExecContext(
return &proxyResult{res: res}, nil
}

rec := currentSession.VerifyRecordWithStringArg(ConnExec, query)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecordWithStringArg(ConnExec, query)
if err != nil {
return nil, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -130,8 +133,11 @@ func (c *proxyConn) PrepareContext(ctx context.Context, query string) (driver.St
return &proxyStmt{stmt: stmt}, nil
}

rec := currentSession.VerifyRecordWithStringArg(ConnPrepare, query)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecordWithStringArg(ConnPrepare, query)
if err != nil {
return nil, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -169,8 +175,11 @@ func (c *proxyConn) QueryContext(
return &proxyRows{rows: rows}, nil
}

rec := currentSession.VerifyRecordWithStringArg(ConnQuery, query)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecordWithStringArg(ConnQuery, query)
if err != nil {
return nil, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -232,8 +241,11 @@ func (c *proxyConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.
return &proxyTx{tx: tx}, nil
}

rec := currentSession.VerifyRecord(ConnBegin)
err, _ := rec.Args[0].(error)
rec, err := currentSession.VerifyRecord(ConnBegin)
if err != nil {
return nil, err
}
err, _ = rec.Args[0].(error)
if err != nil {
return nil, err
}
Expand Down
26 changes: 15 additions & 11 deletions copyist.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ var registered map[string]*proxyDriver
// The Register method takes the name of the SQL driver to be wrapped (e.g.
// "postgres"). Below is an example of how copyist.Register should be invoked.
//
// copyist.Register("postgres")
// copyist.Register("postgres")
//
// Note that Register can only be called once for a given driver; subsequent
// attempts will fail with an error. In addition, the same copyist driver must
Expand Down Expand Up @@ -139,22 +139,22 @@ func SetSessionInit(callback SessionInitCallback) {
// alongside the calling test file. If playing back, then the recording will
// be fetched from that recording file. Here is a typical calling pattern:
//
// func init() {
// copyist.Register("postgres")
// }
// func init() {
// copyist.Register("postgres")
// }
//
// func TestMyStuff(t *testing.T) {
// defer copyist.Open(t).Close()
// ...
// }
// func TestMyStuff(t *testing.T) {
// defer copyist.Open(t).Close()
// ...
// }
//
// The call to Open will initiate a new recording session. The deferred call to
// Close will complete the recording session and write the recording to a file
// in the testdata/ directory, like:
//
// mystuff_test.go
// testdata/
// mystuff_test.copyist
// mystuff_test.go
// testdata/
// mystuff_test.copyist
//
// Each test or sub-test that needs to be executed independently needs to record
// its own session.
Expand Down Expand Up @@ -215,6 +215,10 @@ func OpenSource(t testingT, source Source, recordingName string) io.Closer {
panic(r)
}

if currentSession.verificationErr != nil {
t.Fatalf("%+v\n", currentSession.verificationErr.error)
}

currentSession.Close()
currentSession = nil
return nil
Expand Down
8 changes: 4 additions & 4 deletions copyist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (t *mockTestingT) Fatalf(format string, args ...interface{}) {
fmt.Fprintf(&t.buf, format, args...)
}

func TestSessionPanicsAreCaught(t *testing.T) {
func TestSessionFailuresAreFatalfd(t *testing.T) {
// Enter playback mode.
*recordFlag = false
visitedRecording = true
Expand All @@ -102,16 +102,16 @@ func TestSessionPanicsAreCaught(t *testing.T) {

m := &mockTestingT{T: t}
defer func() {
require.Equal(t, "no recording exists with this name: TestSessionPanicsAreCaught\n",
require.Equal(t, "no recording exists with this name: TestSessionFailuresAreFatalfd\n",
m.buf.String())
}()

defer Open(m).Close()

db, err := sql.Open("copyist_postgres2", "")
require.NoError(t, err)
// NB: This will panic, but the panic will be caught by the copyist closer and
// converted into a call to testing.T.Fatalf.
// NB: This will return an error, but the the copyist closer will track the
// first error and convert it into a call to testing.T.Fatalf.
db.Query("SELECT 1")
}

Expand Down
7 changes: 5 additions & 2 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,11 @@ func (d *proxyDriver) Open(name string) (driver.Conn, error) {
return &proxyConn{driver: d, conn: conn, name: name, session: currentSession}, nil
}

rec := currentSession.VerifyRecord(DriverOpen)
err, _ := rec.Args[0].(error)
rec, err := currentSession.VerifyRecord(DriverOpen)
if err != nil {
return nil, err
}
err, _ = rec.Args[0].(error)
if err != nil {
return nil, err
}
Expand Down
88 changes: 88 additions & 0 deletions drivertest/commontest/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package commontest_test
import (
"bytes"
"database/sql"
"fmt"
"io"
"testing"

Expand Down Expand Up @@ -94,6 +95,84 @@ func TestOpenReadWriteCloser(t *testing.T) {
rows.Next()
}

func TestRollbackWithRecover(t *testing.T) {
// This bug is only present in playback mode, short circuit if we're
// recording.
if copyist.IsRecording() {
return
}

// This is a regression test for a deadlock when copyist would panic upon
// recording failures. We mount an intentionally out of date source that
// will fail on any action after opening our transaction. Our transaction
// helper will attempt a rollback in the case of an error or a panic, which
// would catch copyist's old behavior of panicking upon out of date
// recordings.
// We assert that we hit an out of date error and that rollback is called
// and returns.
source := CopyistSource(bytes.NewBuffer([]byte(`
1=DriverOpen 1:nil
2=ConnBegin 1:nil
"TestRollbackWithRecover"=1,2`)))

defer leaktest.Check(t)()

var mt mockT

closer := copyist.OpenSource(&mt, source, t.Name())

// Open database.
db, err := sql.Open("copyist_postgres", dataSourceName)
require.NoError(t, err)
defer db.Close()

fnErr, txErr := execTransaction(db, func(tx *sql.Tx) error {
_, err := tx.Query("SELECT 1")
return err
})

require.EqualError(t, fnErr, "too many calls to ConnQuery\n\nDo you need to regenerate the recording with the -record flag?")
require.EqualError(t, txErr, "too many calls to TxRollback\n\nDo you need to regenerate the recording with the -record flag?")

require.NoError(t, closer.Close()) // closer never errors.

// Verify that the call to .Close invokes t.Fatalf with the first session
// error that we encountered.
require.Contains(t, mt.failure, "too many calls to ConnQuery")
// Verify that t.Fatalf includes the stacktrace leading to the call that
// triggered the first error. In this case, we look for the error coming
// from the first closure defined within this test function.
require.Contains(t, mt.failure, fmt.Sprintf("commontest_test.%s.func1", t.Name()))
}

// execTransaction is a transaction helper function that attempts a rollback in
// the case of panics of errors. It returns both the closure error and the
// error of either commiting or rolling back.
// It is intended to mimic the behavior of
// https://github.com/cockroachdb/cockroach-go/blob/21a237074d6c3c35b68ec43e8d0c9e9ed714d21a/crdb/common.go#L38
func execTransaction(db *sql.DB, fn func(*sql.Tx) error) (fnErr error, txErr error) {
tx, err := db.Begin()
if err != nil {
return nil, err
}

defer func() {
if r := recover(); r != nil {
txErr = tx.Rollback()
panic(r)
}

if fnErr == nil {
txErr = tx.Commit()
} else {
txErr = tx.Rollback()
}
}()

return fn(tx), nil
}

func TestIsOpen(t *testing.T) {
require.False(t, copyist.IsOpen())

Expand Down Expand Up @@ -164,3 +243,12 @@ func (s copyistSource) WriteAll([]byte) error {
func CopyistSource(r io.Reader) copyist.Source {
return copyistSource{r}
}

type mockT struct {
failure string
}

func (mockT) Name() string { return "" }
func (t *mockT) Fatalf(format string, args ...interface{}) {
t.failure = fmt.Sprintf(format, args...)
}
1 change: 1 addition & 0 deletions drivertest/pqtestold/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ require (
github.com/jackc/pgx/v4 v4.13.0
github.com/jmoiron/sqlx v1.3.4
github.com/lib/pq v1.10.2
github.com/pkg/errors v0.8.1
github.com/stretchr/testify v1.7.0
)
14 changes: 10 additions & 4 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@ func (r *proxyResult) LastInsertId() (int64, error) {
return id, err
}

rec := currentSession.VerifyRecord(ResultLastInsertId)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecord(ResultLastInsertId)
if err != nil {
return 0, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return 0, err
}
Expand All @@ -51,8 +54,11 @@ func (r *proxyResult) RowsAffected() (int64, error) {
return affected, err
}

rec := currentSession.VerifyRecord(ResultRowsAffected)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecord(ResultRowsAffected)
if err != nil {
return 0, err
}
err, _ = rec.Args[1].(error)
if err != nil {
return 0, err
}
Expand Down
12 changes: 9 additions & 3 deletions rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ func (r *proxyRows) Columns() []string {
return cols
}

rec := currentSession.VerifyRecord(RowsColumns)
rec, err := currentSession.VerifyRecord(RowsColumns)
if err != nil {
panic(err)
}
return rec.Args[0].([]string)
}

Expand Down Expand Up @@ -70,8 +73,11 @@ func (r *proxyRows) Next(dest []driver.Value) error {
return err
}

rec := currentSession.VerifyRecord(RowsNext)
err, _ := rec.Args[1].(error)
rec, err := currentSession.VerifyRecord(RowsNext)
if err != nil {
return err
}
err, _ = rec.Args[1].(error)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit b3b04d8

Please sign in to comment.