Skip to content

Commit

Permalink
chore: use interface instead of struct for tests (#15581)
Browse files Browse the repository at this point in the history
  • Loading branch information
systay authored Mar 27, 2024
1 parent 7aec15f commit 2b478cd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
17 changes: 13 additions & 4 deletions go/test/endtoend/utils/cmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,17 @@ import (
"vitess.io/vitess/go/sqltypes"
)

type TestingT interface {
require.TestingT
Helper()
}

type MySQLCompare struct {
t *testing.T
t TestingT
MySQLConn, VtConn *mysql.Conn
}

func NewMySQLCompare(t *testing.T, vtParams, mysqlParams mysql.ConnParams) (MySQLCompare, error) {
func NewMySQLCompare(t TestingT, vtParams, mysqlParams mysql.ConnParams) (MySQLCompare, error) {
ctx := context.Background()
vtConn, err := mysql.Connect(ctx, &vtParams)
if err != nil {
Expand All @@ -53,6 +58,10 @@ func NewMySQLCompare(t *testing.T, vtParams, mysqlParams mysql.ConnParams) (MySQ
}, nil
}

func (mcmp *MySQLCompare) AsT() *testing.T {
return mcmp.t.(*testing.T)
}

func (mcmp *MySQLCompare) Close() {
mcmp.VtConn.Close()
mcmp.MySQLConn.Close()
Expand All @@ -73,7 +82,7 @@ func (mcmp *MySQLCompare) AssertMatches(query, expected string) {
// SkipIfBinaryIsBelowVersion should be used instead of using utils.SkipIfBinaryIsBelowVersion(t,
// This is because we might be inside a Run block that has a different `t` variable
func (mcmp *MySQLCompare) SkipIfBinaryIsBelowVersion(majorVersion int, binary string) {
SkipIfBinaryIsBelowVersion(mcmp.t, majorVersion, binary)
SkipIfBinaryIsBelowVersion(mcmp.t.(*testing.T), majorVersion, binary)
}

// AssertMatchesAny ensures the given query produces any one of the expected results.
Expand Down Expand Up @@ -264,7 +273,7 @@ func (mcmp *MySQLCompare) ExecAndIgnore(query string) (*sqltypes.Result, error)
}

func (mcmp *MySQLCompare) Run(query string, f func(mcmp *MySQLCompare)) {
mcmp.t.Run(query, func(t *testing.T) {
mcmp.AsT().Run(query, func(t *testing.T) {
inner := &MySQLCompare{
t: t,
MySQLConn: mcmp.MySQLConn,
Expand Down
18 changes: 8 additions & 10 deletions go/test/endtoend/utils/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"fmt"
"os"
"path"
"testing"
"time"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -169,18 +168,18 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error {
return nil
}

func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error {
func compareVitessAndMySQLResults(t TestingT, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error {
t.Helper()

if vtQr == nil && mysqlQr == nil {
return nil
}
if vtQr == nil {
t.Error("Vitess result is 'nil' while MySQL's is not.")
t.Errorf("Vitess result is 'nil' while MySQL's is not.")
return errors.New("Vitess result is 'nil' while MySQL's is not.\n")
}
if mysqlQr == nil {
t.Error("MySQL result is 'nil' while Vitess' is not.")
t.Errorf("MySQL result is 'nil' while Vitess' is not.")
return errors.New("MySQL result is 'nil' while Vitess' is not.\n")
}

Expand Down Expand Up @@ -209,7 +208,7 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn

stmt, err := sqlparser.NewTestParser().Parse(query)
if err != nil {
t.Error(err)
t.Errorf(err.Error())
return err
}
orderBy := false
Expand Down Expand Up @@ -237,11 +236,11 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
errStr += fmt.Sprintf("query plan: \n%s\n", qr.Rows[0][0].ToString())
}
}
t.Error(errStr)
t.Errorf(errStr)
return errors.New(errStr)
}

func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Field) {
func checkFields(t TestingT, columnName string, vtField, myField *querypb.Field) {
t.Helper()
if vtField.Type != myField.Type {
t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String())
Expand All @@ -255,10 +254,9 @@ func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Fiel
}
}

func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) {
func compareVitessAndMySQLErrors(t TestingT, vtErr, mysqlErr error) {
if vtErr != nil && mysqlErr != nil || vtErr == nil && mysqlErr == nil {
return
}
out := fmt.Sprintf("Vitess and MySQL are not erroring the same way.\nVitess error: %v\nMySQL error: %v", vtErr, mysqlErr)
t.Error(out)
t.Errorf("Vitess and MySQL are not erroring the same way.\nVitess error: %v\nMySQL error: %v", vtErr, mysqlErr)
}
2 changes: 1 addition & 1 deletion go/test/endtoend/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func ExecCompareMySQL(t *testing.T, vtConn, mysqlConn *mysql.Conn, query string)

// ExecAllowError executes the given query without failing the test if it produces
// an error. The error is returned to the client, along with the result set.
func ExecAllowError(t testing.TB, conn *mysql.Conn, query string) (*sqltypes.Result, error) {
func ExecAllowError(t TestingT, conn *mysql.Conn, query string) (*sqltypes.Result, error) {
t.Helper()
return conn.ExecuteFetch(query, 1000, true)
}
Expand Down

0 comments on commit 2b478cd

Please sign in to comment.