From 332c761719a38cb55bba30a22e10abc1fd37615e Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Sun, 12 May 2024 16:45:18 +0530 Subject: [PATCH 1/3] test: Add missing/required tests for sqltypes and mathstats Signed-off-by: Noble Mittal --- go/mathstats/sample_test.go | 22 ++ go/sqltypes/named_result_test.go | 47 ++++ go/sqltypes/result_test.go | 163 ++++++++++++ go/sqltypes/value_test.go | 424 ++++++++++++++++++++----------- 4 files changed, 513 insertions(+), 143 deletions(-) diff --git a/go/mathstats/sample_test.go b/go/mathstats/sample_test.go index c0da3c2b7f4..7b2b5101bcf 100644 --- a/go/mathstats/sample_test.go +++ b/go/mathstats/sample_test.go @@ -168,6 +168,28 @@ func TestSampleClear(t *testing.T) { assert.False(t, s.Sorted, "Sorting status should be false after clearing") } +func TestIQR(t *testing.T) { + tt := []struct { + sample Sample + expected float64 + }{ + {Sample{Xs: []float64{15, 20, 35, 40, 50}}, 24.999999999999996}, + {Sample{Xs: []float64{}, Sorted: false}, math.NaN()}, + {Sample{Xs: []float64{15, 0, 0, 40, 50}}, 43.33333333333333}, + {Sample{Xs: []float64{10, 2, 1, 0, 23}}, 13.666666666666663}, + } + + for _, tc := range tt { + iqr := tc.sample.IQR() + + if math.IsNaN(tc.expected) { + assert.True(t, math.IsNaN(iqr)) + } else { + assert.Equal(t, tc.expected, iqr) + } + } +} + func TestSampleSort(t *testing.T) { tt := []struct { sample Sample diff --git a/go/sqltypes/named_result_test.go b/go/sqltypes/named_result_test.go index ae42d4257dd..0d7651ba2ac 100644 --- a/go/sqltypes/named_result_test.go +++ b/go/sqltypes/named_result_test.go @@ -52,6 +52,7 @@ func TestToNamedResult(t *testing.T) { for i := range in.Rows { require.Equal(t, in.Rows[i][0], named.Rows[i]["id"]) require.Equal(t, int64(i), named.Rows[i].AsInt64("id", 0)) + require.Equal(t, int32(i), named.Rows[i].AsInt32("id", 0)) require.Equal(t, in.Rows[i][1], named.Rows[i]["status"]) require.Equal(t, fmt.Sprintf("s%d", i), named.Rows[i].AsString("status", "notfound")) @@ -173,3 +174,49 @@ func TestRow(t *testing.T) { }) } } + +func TestAsBool(t *testing.T) { + row := RowNamedValues{ + "testFalse": MakeTrusted(Int64, []byte("0")), + "testTrue": MakeTrusted(Int64, []byte("1")), + } + + r := row.AsBool("testFalse", true) + assert.False(t, r) + + r = row.AsBool("testTrue", false) + assert.True(t, r) + + r = row.AsBool("invalidField", true) + assert.True(t, r) +} + +func TestAsBytes(t *testing.T) { + row := RowNamedValues{ + "testField": MakeTrusted(Int64, []byte("1002")), + } + + r := row.AsBytes("testField", []byte("default")) + assert.Equal(t, []byte("1002"), r) + + r = row.AsBytes("invalidField", []byte("default")) + assert.Equal(t, []byte("default"), r) + +} + +func TestAsFloat64(t *testing.T) { + row := RowNamedValues{ + "testField": MakeTrusted(Int64, []byte("1002")), + "testField2": MakeTrusted(Float64, []byte("10.02")), + } + + r := row.AsFloat64("testField", 23.12) + assert.Equal(t, float64(1002), r) + + r = row.AsFloat64("testField2", 23.12) + assert.Equal(t, 10.02, r) + + r = row.AsFloat64("invalidField", 23.12) + assert.Equal(t, 23.12, r) + +} diff --git a/go/sqltypes/result_test.go b/go/sqltypes/result_test.go index 90d2eb9af65..fba543e35d7 100644 --- a/go/sqltypes/result_test.go +++ b/go/sqltypes/result_test.go @@ -19,6 +19,7 @@ package sqltypes import ( "testing" + "github.com/stretchr/testify/assert" "vitess.io/vitess/go/test/utils" querypb "vitess.io/vitess/go/vt/proto/query" @@ -345,3 +346,165 @@ func TestAppendResult(t *testing.T) { t.Errorf("Got:\n%#v, want:\n%#v", result, want) } } + +func TestReplaceKeyspace(t *testing.T) { + result := &Result{ + Fields: []*querypb.Field{{ + Type: Int64, + Database: "vttest", + }, { + Type: VarChar, + Database: "vttest", + }, { + Type: VarBinary, + }}, + } + + result.ReplaceKeyspace("keyspace-name") + assert.Equal(t, "keyspace-name", result.Fields[0].Database) + assert.Equal(t, "keyspace-name", result.Fields[1].Database) + // Expect empty database identifiers to remain empty + assert.Equal(t, "", result.Fields[2].Database) +} + +func TestShallowCopy(t *testing.T) { + result := &Result{ + Fields: []*querypb.Field{{ + Type: Int64, + Database: "vttest", + }, { + Type: VarChar, + Database: "vttest", + }}, + Rows: [][]Value{ + { + MakeTrusted(querypb.Type_INT32, []byte("10")), + MakeTrusted(querypb.Type_VARCHAR, []byte("name")), + }, + }, + } + + res := result.ShallowCopy() + assert.Equal(t, result, res) +} + +func TestMetadata(t *testing.T) { + result := &Result{ + Fields: []*querypb.Field{{ + Type: Int64, + Database: "vttest", + }, { + Type: VarChar, + Database: "vttest", + }}, + Rows: [][]Value{ + { + MakeTrusted(querypb.Type_INT32, []byte("10")), + MakeTrusted(querypb.Type_VARCHAR, []byte("name")), + }, + }, + } + + res := result.Metadata() + assert.Nil(t, res.Rows) + assert.Equal(t, result.Fields, res.Fields) +} + +func TestResultsEqualUnordered(t *testing.T) { + result1 := &Result{ + Fields: []*querypb.Field{{ + Type: Int64, + Database: "vttest", + }, { + Type: VarChar, + Database: "vttest", + }}, + Rows: [][]Value{ + { + MakeTrusted(querypb.Type_INT32, []byte("24")), + MakeTrusted(querypb.Type_VARCHAR, []byte("test-name1")), + }, + }, + RowsAffected: 2, + } + + result2 := &Result{ + Fields: []*querypb.Field{{ + Type: Int64, + Database: "vttest", + }, { + Type: VarChar, + Database: "vttest", + }}, + Rows: [][]Value{ + { + MakeTrusted(querypb.Type_INT32, []byte("10")), + MakeTrusted(querypb.Type_VARCHAR, []byte("test-name2")), + }, + }, + RowsAffected: 2, + } + + result3 := &Result{ + Fields: []*querypb.Field{{ + Type: Int64, + Database: "vttest", + }, { + Type: VarChar, + Database: "vttest", + }}, + Rows: [][]Value{ + { + MakeTrusted(querypb.Type_INT32, []byte("10")), + MakeTrusted(querypb.Type_VARCHAR, []byte("test-name2")), + }, + { + MakeTrusted(querypb.Type_INT32, []byte("24")), + MakeTrusted(querypb.Type_VARCHAR, []byte("test-name1")), + }, + }, + RowsAffected: 3, + } + + eq := ResultsEqualUnordered([]Result{*result1, *result2}, []Result{*result2, *result1}) + assert.True(t, eq) + + eq = ResultsEqualUnordered([]Result{*result1}, []Result{*result2, *result1}) + assert.False(t, eq) + + eq = ResultsEqualUnordered([]Result{*result1}, []Result{*result2}) + assert.False(t, eq) + + eq = ResultsEqualUnordered([]Result{*result1, *result3}, []Result{*result2, *result1}) + assert.False(t, eq) +} + +func TestStatusFlags(t *testing.T) { + result := &Result{ + Fields: []*querypb.Field{{ + Type: Int64, + Database: "vttest", + }, { + Type: VarChar, + Database: "vttest", + }}, + StatusFlags: ServerMoreResultsExists, + } + + assert.True(t, result.IsMoreResultsExists()) + assert.False(t, result.IsInTransaction()) + + result.StatusFlags = ServerStatusInTrans + + assert.False(t, result.IsMoreResultsExists()) + assert.True(t, result.IsInTransaction()) +} + +func TestIncludeFieldsOrDefault(t *testing.T) { + // Should return default if nil is passed + r := IncludeFieldsOrDefault(nil) + assert.Equal(t, querypb.ExecuteOptions_TYPE_AND_NAME, r) + + r = IncludeFieldsOrDefault(&querypb.ExecuteOptions{IncludedFields: querypb.ExecuteOptions_TYPE_ONLY}) + assert.Equal(t, querypb.ExecuteOptions_TYPE_ONLY, r) +} diff --git a/go/sqltypes/value_test.go b/go/sqltypes/value_test.go index d6a9b510b9e..36a0f5a5090 100644 --- a/go/sqltypes/value_test.go +++ b/go/sqltypes/value_test.go @@ -17,8 +17,7 @@ limitations under the License. package sqltypes import ( - "bytes" - "reflect" + "math" "strings" "testing" @@ -26,6 +25,7 @@ import ( "github.com/stretchr/testify/require" + "vitess.io/vitess/go/bytes2" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -190,18 +190,12 @@ func TestNewValue(t *testing.T) { for _, tcase := range testcases { v, err := NewValue(tcase.inType, []byte(tcase.inVal)) if tcase.outErr != "" { - if err == nil || !strings.Contains(err.Error(), tcase.outErr) { - t.Errorf("ValueFromBytes(%v, %v) error: %v, must contain %v", tcase.inType, tcase.inVal, err, tcase.outErr) - } - continue - } - if err != nil { - t.Errorf("ValueFromBytes(%v, %v) error: %v", tcase.inType, tcase.inVal, err) + assert.ErrorContains(t, err, tcase.outErr) continue } - if !reflect.DeepEqual(v, tcase.outVal) { - t.Errorf("ValueFromBytes(%v, %v) = %v, want %v", tcase.inType, tcase.inVal, v, tcase.outVal) - } + + assert.NoError(t, err) + assert.Equal(t, tcase.outVal, v) } } @@ -210,27 +204,24 @@ func TestNewValue(t *testing.T) { func TestNew(t *testing.T) { got := NewInt32(1) want := MakeTrusted(Int32, []byte("1")) - if !reflect.DeepEqual(got, want) { - t.Errorf("NewInt32(aa): %v, want %v", got, want) - } + assert.Equal(t, want, got) got = NewVarBinary("aa") want = MakeTrusted(VarBinary, []byte("aa")) - if !reflect.DeepEqual(got, want) { - t.Errorf("NewVarBinary(aa): %v, want %v", got, want) - } + assert.Equal(t, want, got) + + got, err := NewJSON("invalid-json") + assert.Empty(t, got) + assert.ErrorContains(t, err, "invalid JSON value") } func TestMakeTrusted(t *testing.T) { v := MakeTrusted(Null, []byte("abcd")) - if !reflect.DeepEqual(v, NULL) { - t.Errorf("MakeTrusted(Null...) = %v, want null", v) - } + assert.Equal(t, NULL, v) + v = MakeTrusted(Int64, []byte("1")) want := TestValue(Int64, "1") - if !reflect.DeepEqual(v, want) { - t.Errorf("MakeTrusted(Int64, \"1\") = %v, want %v", v, want) - } + assert.Equal(t, want, v) } func TestIntegralValue(t *testing.T) { @@ -254,18 +245,12 @@ func TestIntegralValue(t *testing.T) { for _, tcase := range testcases { v, err := NewIntegral(tcase.in) if tcase.outErr != "" { - if err == nil || !strings.Contains(err.Error(), tcase.outErr) { - t.Errorf("BuildIntegral(%v) error: %v, must contain %v", tcase.in, err, tcase.outErr) - } + assert.ErrorContains(t, err, tcase.outErr) continue } - if err != nil { - t.Errorf("BuildIntegral(%v) error: %v", tcase.in, err) - continue - } - if !reflect.DeepEqual(v, tcase.outVal) { - t.Errorf("BuildIntegral(%v) = %v, want %v", tcase.in, v, tcase.outVal) - } + + assert.NoError(t, err) + assert.Equal(t, tcase.outVal, v) } } @@ -294,118 +279,66 @@ func TestInterfaceValue(t *testing.T) { }} for _, tcase := range testcases { v, err := InterfaceToValue(tcase.in) - if err != nil { - t.Errorf("BuildValue(%#v) error: %v", tcase.in, err) - continue - } - if !reflect.DeepEqual(v, tcase.out) { - t.Errorf("BuildValue(%#v) = %v, want %v", tcase.in, v, tcase.out) - } + + assert.NoError(t, err) + assert.Equal(t, tcase.out, v) } _, err := InterfaceToValue(make(chan bool)) want := "unexpected" - if err == nil || !strings.Contains(err.Error(), want) { - t.Errorf("BuildValue(chan): %v, want %v", err, want) - } + assert.ErrorContains(t, err, want) } func TestAccessors(t *testing.T) { v := TestValue(Int64, "1") - if v.Type() != Int64 { - t.Errorf("v.Type=%v, want Int64", v.Type()) - } - if !bytes.Equal(v.Raw(), []byte("1")) { - t.Errorf("v.Raw=%s, want 1", v.Raw()) - } - if v.Len() != 1 { - t.Errorf("v.Len=%d, want 1", v.Len()) - } - if v.ToString() != "1" { - t.Errorf("v.String=%s, want 1", v.ToString()) - } - if v.IsNull() { - t.Error("v.IsNull: true, want false") - } - if !v.IsIntegral() { - t.Error("v.IsIntegral: false, want true") - } - if !v.IsSigned() { - t.Error("v.IsSigned: false, want true") - } - if v.IsUnsigned() { - t.Error("v.IsUnsigned: true, want false") - } - if v.IsFloat() { - t.Error("v.IsFloat: true, want false") - } - if v.IsQuoted() { - t.Error("v.IsQuoted: true, want false") - } - if v.IsText() { - t.Error("v.IsText: true, want false") - } - if v.IsBinary() { - t.Error("v.IsBinary: true, want false") - } + assert.Equal(t, Int64, v.Type()) + assert.Equal(t, []byte("1"), v.Raw()) + assert.Equal(t, 1, v.Len()) + assert.Equal(t, "1", v.ToString()) + assert.False(t, v.IsNull()) + assert.True(t, v.IsIntegral()) + assert.True(t, v.IsSigned()) + assert.False(t, v.IsUnsigned()) + assert.False(t, v.IsFloat()) + assert.False(t, v.IsQuoted()) + assert.False(t, v.IsText()) + assert.False(t, v.IsBinary()) + { i, err := v.ToInt64() - if err != nil { - t.Errorf("v.ToInt64: got error: %+v, want no error", err) - } - if i != 1 { - t.Errorf("v.ToInt64=%+v, want 1", i) - } + assert.NoError(t, err) + assert.Equal(t, int64(1), i) } { i, err := v.ToUint64() - if err != nil { - t.Errorf("v.ToUint64: got error: %+v, want no error", err) - } - if i != 1 { - t.Errorf("v.ToUint64=%+v, want 1", i) - } + assert.NoError(t, err) + assert.Equal(t, uint64(1), i) } { b, err := v.ToBool() - if err != nil { - t.Errorf("v.ToBool: got error: %+v, want no error", err) - } - if !b { - t.Errorf("v.ToBool=%+v, want true", b) - } + assert.NoError(t, err) + assert.True(t, b) } } func TestAccessorsNegative(t *testing.T) { v := TestValue(Int64, "-1") - if v.ToString() != "-1" { - t.Errorf("v.String=%s, want -1", v.ToString()) - } - if v.IsNull() { - t.Error("v.IsNull: true, want false") - } - if !v.IsIntegral() { - t.Error("v.IsIntegral: false, want true") - } + assert.Equal(t, "-1", v.ToString()) + assert.False(t, v.IsNull()) + assert.True(t, v.IsIntegral()) + { i, err := v.ToInt64() - if err != nil { - t.Errorf("v.ToInt64: got error: %+v, want no error", err) - } - if i != -1 { - t.Errorf("v.ToInt64=%+v, want -1", i) - } + assert.NoError(t, err) + assert.Equal(t, int64(-1), i) } { - if _, err := v.ToUint64(); err == nil { - t.Error("v.ToUint64: got no error, want error") - } + _, err := v.ToUint64() + assert.Error(t, err) } { - if _, err := v.ToBool(); err == nil { - t.Error("v.ToUint64: got no error, want error") - } + _, err := v.ToBool() + assert.Error(t, err) } } @@ -417,23 +350,15 @@ func TestToBytesAndString(t *testing.T) { } { vBytes, err := v.ToBytes() require.NoError(t, err) - if b := vBytes; !bytes.Equal(b, v.Raw()) { - t.Errorf("%v.ToBytes: %s, want %s", v, b, v.Raw()) - } - if s := v.ToString(); s != string(v.Raw()) { - t.Errorf("%v.ToString: %s, want %s", v, s, v.Raw()) - } + assert.Equal(t, v.Raw(), vBytes) + assert.Equal(t, string(v.Raw()), v.ToString()) } tv := TestValue(Expression, "aa") tvBytes, err := tv.ToBytes() require.EqualError(t, err, "expression cannot be converted to bytes") - if b := tvBytes; b != nil { - t.Errorf("%v.ToBytes: %s, want nil", tv, b) - } - if s := tv.ToString(); s != "" { - t.Errorf("%v.ToString: %s, want \"\"", tv, s) - } + assert.Nil(t, tvBytes) + assert.Empty(t, tv.ToString()) } func TestEncode(t *testing.T) { @@ -465,25 +390,18 @@ func TestEncode(t *testing.T) { for _, tcase := range testcases { var buf strings.Builder tcase.in.EncodeSQL(&buf) - if tcase.outSQL != buf.String() { - t.Errorf("%v.EncodeSQL = %q, want %q", tcase.in, buf.String(), tcase.outSQL) - } + assert.Equal(t, tcase.outSQL, buf.String()) + buf.Reset() tcase.in.EncodeASCII(&buf) - if tcase.outASCII != buf.String() { - t.Errorf("%v.EncodeASCII = %q, want %q", tcase.in, buf.String(), tcase.outASCII) - } + assert.Equal(t, tcase.outASCII, buf.String()) } } // TestEncodeMap ensures DontEscape is not escaped func TestEncodeMap(t *testing.T) { - if SQLEncodeMap[DontEscape] != DontEscape { - t.Errorf("SQLEncodeMap[DontEscape] = %v, want %v", SQLEncodeMap[DontEscape], DontEscape) - } - if SQLDecodeMap[DontEscape] != DontEscape { - t.Errorf("SQLDecodeMap[DontEscape] = %v, want %v", SQLEncodeMap[DontEscape], DontEscape) - } + assert.Equal(t, DontEscape, SQLEncodeMap[DontEscape]) + assert.Equal(t, DontEscape, SQLDecodeMap[DontEscape]) } func TestHexAndBitToBytes(t *testing.T) { @@ -569,3 +487,223 @@ func TestDecodeStringSQL(t *testing.T) { } } } + +func TestTinyWeightCmp(t *testing.T) { + val1 := TestValue(Int64, "12") + val2 := TestValue(VarChar, "aa") + + val1.SetTinyWeight(10) + + // Test TinyWeight + assert.Equal(t, uint32(10), val1.TinyWeight()) + + cmp := val1.TinyWeightCmp(val2) + assert.Equal(t, 0, cmp) + + val2.SetTinyWeight(10) + cmp = val1.TinyWeightCmp(val2) + assert.Equal(t, 0, cmp) + + val2.SetTinyWeight(20) + cmp = val1.TinyWeightCmp(val2) + assert.Equal(t, -1, cmp) + + val2.SetTinyWeight(5) + cmp = val1.TinyWeightCmp(val2) + assert.Equal(t, 1, cmp) +} + +func TestToCastInt64(t *testing.T) { + tcases := []struct { + in Value + want int64 + err string + }{ + {TestValue(Int64, "213"), 213, ""}, + {TestValue(Int64, "-213"), -213, ""}, + {TestValue(VarChar, "9223372036854775808a"), math.MaxInt64, `cannot parse int64 from "9223372036854775808a": overflow`}, + {TestValue(Time, "12:23:59"), 12, `unparsed tail left after parsing int64 from "12:23:59": ":23:59"`}, + } + + for _, tcase := range tcases { + t.Run(tcase.in.String(), func(t *testing.T) { + got, err := tcase.in.ToCastInt64() + assert.Equal(t, tcase.want, got) + + if tcase.err != "" { + assert.ErrorContains(t, err, tcase.err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestToCastUint64(t *testing.T) { + tcases := []struct { + in Value + want uint64 + err string + }{ + {TestValue(Int64, "213"), 213, ""}, + {TestValue(Int64, "-213"), 0, `cannot parse uint64 from "-213"`}, + {TestValue(VarChar, "9223372036854775808a"), 9223372036854775808, `unparsed tail left after parsing uint64 from "9223372036854775808a": "a"`}, + {TestValue(Time, "12:23:59"), 12, `unparsed tail left after parsing uint64 from "12:23:59": ":23:59"`}, + } + + for _, tcase := range tcases { + t.Run(tcase.in.String(), func(t *testing.T) { + got, err := tcase.in.ToCastUint64() + assert.Equal(t, tcase.want, got) + + if tcase.err != "" { + assert.ErrorContains(t, err, tcase.err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestToUint16(t *testing.T) { + tcases := []struct { + in Value + want uint16 + err string + }{ + {TestValue(Int64, "213"), 213, ""}, + {TestValue(Int64, "-213"), 0, `parsing "-213": invalid syntax`}, + {TestValue(VarChar, "9223372036854775808a"), 0, ErrIncompatibleTypeCast.Error()}, + {TestValue(Time, "12:23:59"), 0, ErrIncompatibleTypeCast.Error()}, + } + + for _, tcase := range tcases { + t.Run(tcase.in.String(), func(t *testing.T) { + got, err := tcase.in.ToUint16() + assert.Equal(t, tcase.want, got) + + if tcase.err != "" { + assert.ErrorContains(t, err, tcase.err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestToUint32(t *testing.T) { + tcases := []struct { + in Value + want uint32 + err string + }{ + {TestValue(Int64, "213"), 213, ""}, + {TestValue(Int64, "-213"), 0, `parsing "-213": invalid syntax`}, + {TestValue(VarChar, "9223372036854775808a"), 0, ErrIncompatibleTypeCast.Error()}, + {TestValue(Time, "12:23:59"), 0, ErrIncompatibleTypeCast.Error()}, + } + + for _, tcase := range tcases { + t.Run(tcase.in.String(), func(t *testing.T) { + got, err := tcase.in.ToUint32() + assert.Equal(t, tcase.want, got) + + if tcase.err != "" { + assert.ErrorContains(t, err, tcase.err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestEncodeSQLStringBuilder(t *testing.T) { + testcases := []struct { + in Value + outSQL string + }{{ + in: NULL, + outSQL: "null", + }, { + in: TestValue(Int64, "1"), + outSQL: "1", + }, { + in: TestValue(VarChar, "foo"), + outSQL: "'foo'", + }, { + in: TestValue(VarChar, "\x00'\"\b\n\r\t\x1A\\"), + outSQL: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'", + }, { + in: TestValue(Bit, "a"), + outSQL: "b'01100001'", + }, { + in: TestTuple(TestValue(Int64, "1"), TestValue(VarChar, "foo")), + outSQL: "(1, 'foo')", + }} + for _, tcase := range testcases { + var buf strings.Builder + + tcase.in.EncodeSQLStringBuilder(&buf) + assert.Equal(t, tcase.outSQL, buf.String()) + } +} + +func TestEncodeSQLBytes2(t *testing.T) { + testcases := []struct { + in Value + outSQL string + }{{ + in: NULL, + outSQL: "null", + }, { + in: TestValue(Int64, "1"), + outSQL: "1", + }, { + in: TestValue(VarChar, "foo"), + outSQL: "'foo'", + }, { + in: TestValue(VarChar, "\x00'\"\b\n\r\t\x1A\\"), + outSQL: "'\\0\\'\\\"\\b\\n\\r\\t\\Z\\\\'", + }, { + in: TestValue(Bit, "a"), + outSQL: "b'01100001'", + }, { + in: TestTuple(TestValue(Int64, "1"), TestValue(VarChar, "foo")), + outSQL: "\x89\x02\x011\x950\x03foo", + }} + for _, tcase := range testcases { + buf := bytes2.NewBuffer([]byte{}) + + tcase.in.EncodeSQLBytes2(buf) + assert.Equal(t, tcase.outSQL, buf.String()) + } +} + +func TestIsComparable(t *testing.T) { + testcases := []struct { + in Value + isCmp bool + }{{ + in: NULL, + isCmp: true, + }, { + in: TestValue(Int64, "1"), + isCmp: true, + }, { + in: TestValue(VarChar, "foo"), + }, { + in: TestValue(VarChar, "\x00'\"\b\n\r\t\x1A\\"), + }, { + in: TestValue(Bit, "a"), + isCmp: true, + }, { + in: TestValue(Time, "12:21:11"), + isCmp: true, + }, { + in: TestTuple(TestValue(Int64, "1"), TestValue(VarChar, "foo")), + }} + for _, tcase := range testcases { + isCmp := tcase.in.IsComparable() + assert.Equal(t, tcase.isCmp, isCmp) + } +} From 2e9807156783068d93755cbdd413775e159c4ed5 Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Sun, 12 May 2024 16:46:45 +0530 Subject: [PATCH 2/3] Run goimports lint for result_test.go Signed-off-by: Noble Mittal --- go/sqltypes/result_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/sqltypes/result_test.go b/go/sqltypes/result_test.go index fba543e35d7..d8075ec0633 100644 --- a/go/sqltypes/result_test.go +++ b/go/sqltypes/result_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/test/utils" querypb "vitess.io/vitess/go/vt/proto/query" From ab8b4b1ffa0cf7a974486822ee8c379b86b445b8 Mon Sep 17 00:00:00 2001 From: Noble Mittal Date: Sun, 12 May 2024 17:36:53 +0530 Subject: [PATCH 3/3] refactor: Use testify instead of t.error/fatal in bind_variables_test.go Signed-off-by: Noble Mittal --- go/sqltypes/bind_variables_test.go | 126 +++++++++-------------------- 1 file changed, 37 insertions(+), 89 deletions(-) diff --git a/go/sqltypes/bind_variables_test.go b/go/sqltypes/bind_variables_test.go index 77b3381f751..99a9e2f2ef3 100644 --- a/go/sqltypes/bind_variables_test.go +++ b/go/sqltypes/bind_variables_test.go @@ -18,7 +18,6 @@ package sqltypes import ( "fmt" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -92,15 +91,10 @@ func TestBuildBindVariables(t *testing.T) { }} for _, tcase := range tcases { bindVars, err := BuildBindVariables(tcase.in) - if err != nil { - if err.Error() != tcase.err { - t.Errorf("MapToBindVars(%v) error: %v, want %s", tcase.in, err, tcase.err) - } - continue - } - if tcase.err != "" { - t.Errorf("MapToBindVars(%v) error: nil, want %s", tcase.in, tcase.err) - continue + if tcase.err == "" { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, tcase.err) } if !BindVariablesEqual(bindVars, tcase.out) { t.Errorf("MapToBindVars(%v): %v, want %s", tcase.in, bindVars, tcase.out) @@ -371,14 +365,10 @@ func TestValidateBindVarables(t *testing.T) { for _, tcase := range tcases { err := ValidateBindVariables(tcase.in) if tcase.err != "" { - if err == nil || err.Error() != tcase.err { - t.Errorf("ValidateBindVars(%v): %v, want %s", tcase.in, err, tcase.err) - } + assert.ErrorContains(t, err, tcase.err) continue } - if err != nil { - t.Errorf("ValidateBindVars(%v): %v, want nil", tcase.in, err) - } + assert.NoError(t, err) } } @@ -582,22 +572,16 @@ func TestValidateBindVariable(t *testing.T) { for _, tcase := range testcases { err := ValidateBindVariable(tcase.in) if tcase.err != "" { - if err == nil || !strings.Contains(err.Error(), tcase.err) { - t.Errorf("ValidateBindVar(%v) error: %v, must contain %v", tcase.in, err, tcase.err) - } + assert.ErrorContains(t, err, tcase.err) continue } - if err != nil { - t.Errorf("ValidateBindVar(%v) error: %v", tcase.in, err) - } + assert.NoError(t, err) } // Special case: nil bind var. err := ValidateBindVariable(nil) want := "bind variable is nil" - if err == nil || err.Error() != want { - t.Errorf("ValidateBindVar(nil) error: %v, want %s", err, want) - } + assert.ErrorContains(t, err, want) } func TestBindVariableToValue(t *testing.T) { @@ -633,19 +617,13 @@ func TestBindVariablesEqual(t *testing.T) { Value: []byte("1"), }, } - if !BindVariablesEqual(bv1, bv2) { - t.Errorf("%v != %v, want equal", bv1, bv2) - } - if !BindVariablesEqual(bv1, bv3) { - t.Errorf("%v = %v, want not equal", bv1, bv3) - } + assert.True(t, BindVariablesEqual(bv1, bv2)) + assert.True(t, BindVariablesEqual(bv1, bv3)) } func TestBindVariablesFormat(t *testing.T) { tupleBindVar, err := BuildBindVariable([]int64{1, 2}) - if err != nil { - t.Fatalf("failed to create a tuple bind var: %v", err) - } + require.NoError(t, err, "failed to create a tuple bind var") bindVariables := map[string]*querypb.BindVariable{ "key_1": StringBindVariable("val_1"), @@ -655,68 +633,38 @@ func TestBindVariablesFormat(t *testing.T) { } formattedStr := FormatBindVariables(bindVariables, true /* full */, false /* asJSON */) - if !strings.Contains(formattedStr, "key_1") || - !strings.Contains(formattedStr, "val_1") { - t.Fatalf("bind variable 'key_1': 'val_1' is not formatted") - } - if !strings.Contains(formattedStr, "key_2") || - !strings.Contains(formattedStr, "789") { - t.Fatalf("bind variable 'key_2': '789' is not formatted") - } - if !strings.Contains(formattedStr, "key_3") || !strings.Contains(formattedStr, "val_3") { - t.Fatalf("bind variable 'key_3': 'val_3' is not formatted") - } - if !strings.Contains(formattedStr, "key_4:type:TUPLE") { - t.Fatalf("bind variable 'key_4': (1, 2) is not formatted") - } + assert.Contains(t, formattedStr, "key_1") + assert.Contains(t, formattedStr, "val_1") - formattedStr = FormatBindVariables(bindVariables, false /* full */, false /* asJSON */) - if !strings.Contains(formattedStr, "key_1") { - t.Fatalf("bind variable 'key_1' is not formatted") - } - if !strings.Contains(formattedStr, "key_2") || - !strings.Contains(formattedStr, "789") { - t.Fatalf("bind variable 'key_2': '789' is not formatted") - } - if !strings.Contains(formattedStr, "key_3") || !strings.Contains(formattedStr, "5 bytes") { - t.Fatalf("bind variable 'key_3' is not formatted") - } - if !strings.Contains(formattedStr, "key_4") || !strings.Contains(formattedStr, "2 items") { - t.Fatalf("bind variable 'key_4' is not formatted") - } + assert.Contains(t, formattedStr, "key_2") + assert.Contains(t, formattedStr, "789") - formattedStr = FormatBindVariables(bindVariables, true /* full */, true /* asJSON */) - t.Logf("%q", formattedStr) - if !strings.Contains(formattedStr, "\"key_1\": {\"type\": \"VARCHAR\", \"value\": \"val_1\"}") { - t.Fatalf("bind variable 'key_1' is not formatted") - } + assert.Contains(t, formattedStr, "key_3") + assert.Contains(t, formattedStr, "val_3") - if !strings.Contains(formattedStr, "\"key_2\": {\"type\": \"INT64\", \"value\": 789}") { - t.Fatalf("bind variable 'key_2' is not formatted") - } + assert.Contains(t, formattedStr, "key_4:type:TUPLE") - if !strings.Contains(formattedStr, "\"key_3\": {\"type\": \"VARBINARY\", \"value\": \"val_3\"}") { - t.Fatalf("bind variable 'key_3' is not formatted") - } + formattedStr = FormatBindVariables(bindVariables, false /* full */, false /* asJSON */) + assert.Contains(t, formattedStr, "key_1") - if !strings.Contains(formattedStr, "\"key_4\": {\"type\": \"TUPLE\", \"value\": \"\"}") { - t.Fatalf("bind variable 'key_4' is not formatted") - } + assert.Contains(t, formattedStr, "key_2") + assert.Contains(t, formattedStr, "789") - formattedStr = FormatBindVariables(bindVariables, false /* full */, true /* asJSON */) - if !strings.Contains(formattedStr, "\"key_1\": {\"type\": \"VARCHAR\", \"value\": \"5 bytes\"}") { - t.Fatalf("bind variable 'key_1' is not formatted") - } + assert.Contains(t, formattedStr, "key_3") + assert.Contains(t, formattedStr, "5 bytes") - if !strings.Contains(formattedStr, "\"key_2\": {\"type\": \"INT64\", \"value\": 789}") { - t.Fatalf("bind variable 'key_2' is not formatted") - } + assert.Contains(t, formattedStr, "key_4") + assert.Contains(t, formattedStr, "2 items") - if !strings.Contains(formattedStr, "\"key_3\": {\"type\": \"VARCHAR\", \"value\": \"5 bytes\"}") { - t.Fatalf("bind variable 'key_3' is not formatted") - } + formattedStr = FormatBindVariables(bindVariables, true /* full */, true /* asJSON */) + assert.Contains(t, formattedStr, "\"key_1\": {\"type\": \"VARCHAR\", \"value\": \"val_1\"}") + assert.Contains(t, formattedStr, "\"key_2\": {\"type\": \"INT64\", \"value\": 789}") + assert.Contains(t, formattedStr, "\"key_3\": {\"type\": \"VARBINARY\", \"value\": \"val_3\"}") + assert.Contains(t, formattedStr, "\"key_4\": {\"type\": \"TUPLE\", \"value\": \"\"}") - if !strings.Contains(formattedStr, "\"key_4\": {\"type\": \"VARCHAR\", \"value\": \"2 items\"}") { - t.Fatalf("bind variable 'key_4' is not formatted") - } + formattedStr = FormatBindVariables(bindVariables, false /* full */, true /* asJSON */) + assert.Contains(t, formattedStr, "\"key_1\": {\"type\": \"VARCHAR\", \"value\": \"5 bytes\"}") + assert.Contains(t, formattedStr, "\"key_2\": {\"type\": \"INT64\", \"value\": 789}") + assert.Contains(t, formattedStr, "\"key_3\": {\"type\": \"VARCHAR\", \"value\": \"5 bytes\"}") + assert.Contains(t, formattedStr, "\"key_4\": {\"type\": \"VARCHAR\", \"value\": \"2 items\"}") }