Skip to content

Commit

Permalink
refactor: introduce evalengine type and use it (#14292)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay authored Oct 19, 2023
1 parent 9d64432 commit 8cade46
Show file tree
Hide file tree
Showing 37 changed files with 330 additions and 321 deletions.
21 changes: 10 additions & 11 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ type AggregateParams struct {
Col int

// These are used only for distinct opcodes.
KeyCol int
WCol int
Type sqltypes.Type
CollationID collations.ID
KeyCol int
WCol int
Type evalengine.Type

Alias string `json:",omitempty"`
Expr sqlparser.Expr
Expand All @@ -58,7 +57,7 @@ func NewAggregateParam(opcode AggregateOpcode, col int, alias string) *Aggregate
Col: col,
Alias: alias,
WCol: -1,
Type: sqltypes.Unknown,
Type: evalengine.UnknownType(),
}
if opcode.NeedsComparableValues() {
out.KeyCol = col
Expand All @@ -75,8 +74,8 @@ func (ap *AggregateParams) String() string {
if ap.WAssigned() {
keyCol = fmt.Sprintf("%s|%d", keyCol, ap.WCol)
}
if sqltypes.IsText(ap.Type) && ap.CollationID != collations.Unknown {
keyCol += " COLLATE " + collations.Local().LookupName(ap.CollationID)
if sqltypes.IsText(ap.Type.Type) && ap.Type.Coll != collations.Unknown {
keyCol += " COLLATE " + collations.Local().LookupName(ap.Type.Coll)
}
dispOrigOp := ""
if ap.OrigOpcode != AggregateUnassigned && ap.OrigOpcode != ap.Opcode {
Expand Down Expand Up @@ -378,7 +377,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
from: aggr.Col,
distinct: aggregatorDistinct{
column: distinct,
coll: aggr.CollationID,
coll: aggr.Type.Coll,
},
}

Expand All @@ -396,23 +395,23 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
sum: sum,
distinct: aggregatorDistinct{
column: distinct,
coll: aggr.CollationID,
coll: aggr.Type.Coll,
},
}

case AggregateMin:
ag = &aggregatorMin{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationID),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.Type.Coll),
},
}

case AggregateMax:
ag = &aggregatorMax{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationID),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.Type.Coll),
},
}

Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/comparer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func extractSlices(input []OrderByParams) []*comparer {
weightString: order.WeightStringCol,
desc: order.Desc,
starColFixedIndex: order.StarColFixedIndex,
collationID: order.CollationID,
collationID: order.Type.Coll,
})
}
return result
Expand Down
5 changes: 2 additions & 3 deletions go/vt/vtgate/engine/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"errors"
"testing"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/vt/vtgate/evalengine"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -90,7 +89,7 @@ func TestDeleteEqual(t *testing.T) {
})

// Failure case
expr := evalengine.NewBindVar("aa", sqltypes.Unknown, collations.Unknown)
expr := evalengine.NewBindVar("aa", evalengine.UnknownType())
del.Values = []evalengine.Expr{expr}
_, err = del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false)
require.EqualError(t, err, "query arguments missing for aa")
Expand Down Expand Up @@ -122,7 +121,7 @@ func TestDeleteEqualMultiCol(t *testing.T) {
})

// Failure case
expr := evalengine.NewBindVar("aa", sqltypes.Unknown, collations.Unknown)
expr := evalengine.NewBindVar("aa", evalengine.UnknownType())
del.Values = []evalengine.Expr{expr}
_, err = del.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false)
require.EqualError(t, err, "query arguments missing for aa")
Expand Down
26 changes: 12 additions & 14 deletions go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ type (
Truncate int
}
CheckCol struct {
Col int
WsCol *int
Type sqltypes.Type
Collation collations.ID
Col int
WsCol *int
Type evalengine.Type
}
probeTable struct {
seenRows map[evalengine.HashCode][]sqltypes.Row
Expand Down Expand Up @@ -119,14 +118,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode
return 0, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code")
}
col := inputRow[checkCol.Col]
hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Collation, col.Type())
hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Coll, col.Type())
if err != nil {
if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil {
return 0, err
}
checkCol = checkCol.SwitchToWeightString()
pt.checkCols[i] = checkCol
hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Collation, col.Type())
hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Coll, col.Type())
if err != nil {
return 0, err
}
Expand All @@ -138,15 +137,15 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode

func (pt *probeTable) equal(a, b sqltypes.Row) (bool, error) {
for i, checkCol := range pt.checkCols {
cmp, err := evalengine.NullsafeCompare(a[i], b[i], checkCol.Collation)
cmp, err := evalengine.NullsafeCompare(a[i], b[i], checkCol.Type.Coll)
if err != nil {
_, isComparisonErr := err.(evalengine.UnsupportedComparisonError)
if !isComparisonErr || checkCol.WsCol == nil {
return false, err
}
checkCol = checkCol.SwitchToWeightString()
pt.checkCols[i] = checkCol
cmp, err = evalengine.NullsafeCompare(a[i], b[i], checkCol.Collation)
cmp, err = evalengine.NullsafeCompare(a[i], b[i], checkCol.Type.Coll)
if err != nil {
return false, err
}
Expand Down Expand Up @@ -273,17 +272,16 @@ func (d *Distinct) description() PrimitiveDescription {
// SwitchToWeightString returns a new CheckCol that works on the weight string column instead
func (cc CheckCol) SwitchToWeightString() CheckCol {
return CheckCol{
Col: *cc.WsCol,
WsCol: nil,
Type: sqltypes.VarBinary,
Collation: collations.CollationBinaryID,
Col: *cc.WsCol,
WsCol: nil,
Type: evalengine.Type{Type: sqltypes.VarBinary, Coll: collations.CollationBinaryID},
}
}

func (cc CheckCol) String() string {
var collation string
if sqltypes.IsText(cc.Type) && cc.Collation != collations.Unknown {
collation = ": " + collations.Local().LookupName(cc.Collation)
if sqltypes.IsText(cc.Type.Type) && cc.Type.Coll != collations.Unknown {
collation = ": " + collations.Local().LookupName(cc.Type.Coll)
}

var column string
Expand Down
26 changes: 15 additions & 11 deletions go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"fmt"
"testing"

"vitess.io/vitess/go/vt/vtgate/evalengine"

"vitess.io/vitess/go/mysql/collations"

"vitess.io/vitess/go/test/utils"
Expand Down Expand Up @@ -86,10 +88,14 @@ func TestDistinct(t *testing.T) {
if sqltypes.IsNumber(tc.inputs.Fields[i].Type) {
collID = collations.CollationBinaryID
}
t := evalengine.Type{
Type: tc.inputs.Fields[i].Type,
Coll: collID,
Nullable: false,
}
checkCols = append(checkCols, CheckCol{
Col: i,
Type: tc.inputs.Fields[i].Type,
Collation: collID,
Col: i,
Type: t,
})
}
}
Expand Down Expand Up @@ -132,10 +138,9 @@ func TestDistinct(t *testing.T) {
func TestWeightStringFallBack(t *testing.T) {
offsetOne := 1
checkCols := []CheckCol{{
Col: 0,
WsCol: &offsetOne,
Type: sqltypes.Unknown,
Collation: collations.Unknown,
Col: 0,
WsCol: &offsetOne,
Type: evalengine.UnknownType(),
}}
input := r("myid|weightstring(myid)",
"varchar|varbinary",
Expand All @@ -158,9 +163,8 @@ func TestWeightStringFallBack(t *testing.T) {

// the primitive must not change just because one run needed weight strings
utils.MustMatch(t, []CheckCol{{
Col: 0,
WsCol: &offsetOne,
Type: sqltypes.Unknown,
Collation: collations.Unknown,
Col: 0,
WsCol: &offsetOne,
Type: evalengine.UnknownType(),
}}, distinct.CheckCols, "checkCols should not be updated")
}
10 changes: 5 additions & 5 deletions go/vt/vtgate/engine/limit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestLimitExecute(t *testing.T) {
results: []*sqltypes.Result{inputResult},
}
l = &Limit{
Count: evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID),
Count: evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}),
Input: fp,
}

Expand Down Expand Up @@ -343,8 +343,8 @@ func TestLimitOffsetExecute(t *testing.T) {
}

l = &Limit{
Count: evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID),
Offset: evalengine.NewBindVar("o", sqltypes.Int64, collations.CollationBinaryID),
Count: evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}),
Offset: evalengine.NewBindVar("o", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}),
Input: fp,
}
result, err = l.TryExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(1), "o": sqltypes.Int64BindVariable(1)}, false)
Expand Down Expand Up @@ -396,7 +396,7 @@ func TestLimitStreamExecute(t *testing.T) {

// Test with bind vars.
fp.rewind()
l.Count = evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID)
l.Count = evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
results = nil
err = l.TryStreamExecute(context.Background(), &noopVCursor{}, map[string]*querypb.BindVariable{"l": sqltypes.Int64BindVariable(2)}, true, func(qr *sqltypes.Result) error {
results = append(results, qr)
Expand Down Expand Up @@ -540,7 +540,7 @@ func TestLimitInputFail(t *testing.T) {

func TestLimitInvalidCount(t *testing.T) {
l := &Limit{
Count: evalengine.NewBindVar("l", sqltypes.Int64, collations.CollationBinaryID),
Count: evalengine.NewBindVar("l", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID}),
}
_, _, err := l.getCountAndOffset(context.Background(), &noopVCursor{}, nil)
assert.EqualError(t, err, "query arguments missing for l")
Expand Down
24 changes: 11 additions & 13 deletions go/vt/vtgate/engine/memory_sort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestMemorySortExecute(t *testing.T) {
utils.MustMatch(t, wantResult, result)

fp.rewind()
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID)
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}

result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false)
Expand Down Expand Up @@ -136,7 +136,7 @@ func TestMemorySortStreamExecuteWeightString(t *testing.T) {

t.Run("Limit test", func(t *testing.T) {
fp.rewind()
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID)
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}

results = nil
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestMemorySortExecuteWeightString(t *testing.T) {
utils.MustMatch(t, wantResult, result)

fp.rewind()
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID)
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}

result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false)
Expand Down Expand Up @@ -228,9 +228,8 @@ func TestMemorySortStreamExecuteCollation(t *testing.T) {
collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci")
ms := &MemorySort{
OrderBy: []OrderByParams{{
Col: 0,
Type: sqltypes.VarChar,
CollationID: collationID,
Col: 0,
Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID},
}},
Input: fp,
}
Expand Down Expand Up @@ -278,7 +277,7 @@ func TestMemorySortStreamExecuteCollation(t *testing.T) {

t.Run("Limit test", func(t *testing.T) {
fp.rewind()
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID)
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}

results = nil
Expand Down Expand Up @@ -317,9 +316,8 @@ func TestMemorySortExecuteCollation(t *testing.T) {
collationID, _ := collations.Local().LookupID("utf8mb4_hu_0900_ai_ci")
ms := &MemorySort{
OrderBy: []OrderByParams{{
Col: 0,
Type: sqltypes.VarChar,
CollationID: collationID,
Col: 0,
Type: evalengine.Type{Type: sqltypes.VarChar, Coll: collationID},
}},
Input: fp,
}
Expand All @@ -338,7 +336,7 @@ func TestMemorySortExecuteCollation(t *testing.T) {
utils.MustMatch(t, wantResult, result)

fp.rewind()
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID)
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}

result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false)
Expand Down Expand Up @@ -395,7 +393,7 @@ func TestMemorySortStreamExecute(t *testing.T) {
utils.MustMatch(t, wantResults, results)

fp.rewind()
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID)
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}

results = nil
Expand Down Expand Up @@ -554,7 +552,7 @@ func TestMemorySortMultiColumn(t *testing.T) {
utils.MustMatch(t, wantResult, result)

fp.rewind()
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", sqltypes.Int64, collations.CollationBinaryID)
ms.UpperLimit = evalengine.NewBindVar("__upper_limit", evalengine.Type{Type: sqltypes.Int64, Coll: collations.CollationBinaryID})
bv := map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}

result, err = ms.TryExecute(context.Background(), &noopVCursor{}, bv, false)
Expand Down
Loading

0 comments on commit 8cade46

Please sign in to comment.