Skip to content

Commit

Permalink
evalengine: Add support for enum and set
Browse files Browse the repository at this point in the history
The evalengine currently doesn't handle enum and set types properly.
The comparison function always returns 0 at the moment and we don't
consider ordering of elements etc.

Here we add two new native types to the evalengine for set and enum and
instantiate those appropriately. We also ensure we can compare them
correctly.

In case we don't have the schema information with the values, we do a
best effort case of depending on the string representation. This is not
correct always of course, but at least makes equality comparison work
for those cases and only ordering is off in that scenario.

Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
dbussink committed Apr 23, 2024
1 parent 5e2a873 commit 4cb5f7d
Show file tree
Hide file tree
Showing 49 changed files with 699 additions and 220 deletions.
4 changes: 4 additions & 0 deletions go/mysql/json/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func NewFromSQL(v sqltypes.Value) (*Value, error) {
return NewDate(v.RawStr()), nil
case v.IsTime():
return NewTime(v.RawStr()), nil
case v.IsEnum():
return NewString(v.RawStr()), nil
case v.IsSet():
return NewString(v.RawStr()), nil
default:
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot coerce %v as a JSON type", v)
}
Expand Down
36 changes: 36 additions & 0 deletions go/sqltypes/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ var RandomGenerators = map[Type]RandomGenerator{
}
return v
},
Enum: func() Value {
return MakeTrusted(Enum, randEnum())
},
Set: func() Value {
return MakeTrusted(Set, randSet())
},
}

func randTime() time.Time {
Expand All @@ -289,3 +295,33 @@ func randTime() time.Time {
sec := rand.Int64N(delta) + min
return time.Unix(sec, 0)
}

func randEnum() []byte {
enums := []string{
"xxsmall",
"xsmall",
"small",
"medium",
"large",
"xlarge",
"xxlarge",
}
return []byte(enums[rand.IntN(len(enums))])
}

func randSet() []byte {
set := []string{
"a",
"b",
"c",
"d",
"e",
"f",
"g",
}
rand.Shuffle(len(set), func(i, j int) {
set[i], set[j] = set[j], set[i]
})
set = set[:rand.IntN(len(set))]
return []byte(strings.Join(set, ","))
}
10 changes: 10 additions & 0 deletions go/sqltypes/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ func IsNull(t querypb.Type) bool {
return t == Null
}

// IsEnum returns true if the type is Enum type
func IsEnum(t querypb.Type) bool {
return t == Enum
}

// IsSet returns true if the type is Set type
func IsSet(t querypb.Type) bool {
return t == Set
}

// Vitess data types. These are idiomatically named synonyms for the querypb.Type values.
// Although these constants are interchangeable, they should be treated as different from querypb.Type.
// Use the synonyms only to refer to the type in Value. For proto variables, use the querypb.Type constants instead.
Expand Down
10 changes: 10 additions & 0 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,16 @@ func (v Value) IsDecimal() bool {
return IsDecimal(v.Type())
}

// IsEnum returns true if Value is time.
func (v Value) IsEnum() bool {
return v.Type() == querypb.Type_ENUM
}

// IsSet returns true if Value is time.
func (v Value) IsSet() bool {
return v.Type() == querypb.Type_SET
}

// IsComparable returns true if the Value is null safe comparable without collation information.
func (v *Value) IsComparable() bool {
if v.Type() == Null || IsNumber(v.Type()) || IsBinary(v.Type()) {
Expand Down
15 changes: 14 additions & 1 deletion go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) {
require.NoError(t, err)

deleteAll := func() {
tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "uks.unsharded"}
tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "tbl_enum_set", "uks.unsharded"}
for _, table := range tables {
_, _ = mcmp.ExecAndIgnore("delete from " + table)
}
Expand Down Expand Up @@ -452,3 +452,16 @@ func TestStraightJoin(t *testing.T) {
require.NoError(t, err)
require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl")
}

func TestEnumSetVals(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")

mcmp, closer := start(t)
defer closer()
require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "tbl_enum_set", clusterInstance.VtgateProcess.ReadVSchema))

mcmp.Exec("insert into tbl_enum_set(id, enum_col, set_col) values (1, 'medium', 'a,b,e'), (2, 'small', 'e,f,g'), (3, 'large', 'c'), (4, 'xsmall', 'a,b'), (5, 'medium', 'a,d')")

mcmp.AssertMatches("select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id", `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`)
mcmp.AssertMatches("select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id", `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`)
}
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ create table tbl
primary key (id),
unique (unq_col)
) Engine = InnoDB;

create table tbl_enum_set
(
id bigint,
enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'),
set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'),
primary key (id)
) Engine = InnoDB;
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/vschema.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@
}
]
},
"tbl_enum_set": {
"column_vindexes": [
{
"column": "id",
"name": "hash"
}
]
},
"unq_idx": {
"column_vindexes": [
{
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ func (t *explainTablet) analyzeWhere(selStmt *sqlparser.Select, tableColumnMap m
// Check if we have a duplicate value
isNewValue := true
for _, v := range inVal {
result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset())
result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset(), nil)
if err != nil {
return "", nil, 0, nil, err
}
Expand Down
9 changes: 6 additions & 3 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ type aggregatorDistinct struct {
last sqltypes.Value
coll collations.ID
collationEnv *collations.Environment
values []string
}

func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) {
Expand All @@ -115,7 +116,7 @@ func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) {
next := row[a.column]
if !last.IsNull() {
if last.TinyWeightCmp(next) == 0 {
cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll)
cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll, a.values)
if err != nil {
return true, err
}
Expand Down Expand Up @@ -386,6 +387,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
}

Expand All @@ -405,22 +407,23 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
}

case AggregateMin:
ag = &aggregatorMin{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()),
},
}

case AggregateMax:
ag = &aggregatorMax{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()),
},
}

Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (vthash.Hash, error)
return vthash.Hash{}, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code")
}
col := inputRow[checkCol.Col]
err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode)
err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values())
if err != nil {
if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil {
return vthash.Hash{}, err
}
checkCol = checkCol.SwitchToWeightString()
pt.checkCols[i] = checkCol
err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode)
err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values())
if err != nil {
return vthash.Hash{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestDistinct(t *testing.T) {
}
checkCols = append(checkCols, CheckCol{
Col: i,
Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0),
Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0, nil),
CollationEnv: collations.MySQL8(),
})
}
Expand Down
13 changes: 9 additions & 4 deletions go/vt/vtgate/engine/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type (
ComparisonType querypb.Type

CollationEnv *collations.Environment

// Values for enum and set types
Values []string
}

hashJoinProbeTable struct {
Expand All @@ -78,6 +81,7 @@ type (
cols []int
hasher vthash.Hasher
sqlmode evalengine.SQLMode
values []string
}

probeTableEntry struct {
Expand All @@ -94,7 +98,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
return nil, err
}

pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols)
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values)
// build the probe table from the LHS result
for _, row := range lresult.Rows {
err := pt.addLeftRow(row)
Expand Down Expand Up @@ -130,7 +134,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
// TryStreamExecute implements the Primitive interface
func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
// build the probe table from the LHS result
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols)
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values)
var lfields []*querypb.Field
var mu sync.Mutex
err := vcursor.StreamExecutePrimitive(ctx, hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error {
Expand Down Expand Up @@ -260,7 +264,7 @@ func (hj *HashJoin) description() PrimitiveDescription {
}
}

func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int) *hashJoinProbeTable {
func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int, values []string) *hashJoinProbeTable {
return &hashJoinProbeTable{
innerMap: map[vthash.Hash]*probeTableEntry{},
coll: coll,
Expand All @@ -269,6 +273,7 @@ func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey
rhsKey: rhsKey,
cols: cols,
hasher: vthash.New(),
values: values,
}
}

Expand All @@ -286,7 +291,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error {
}

func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) {
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode)
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode, pt.values)
if err != nil {
return vthash.Hash{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Envir
if code == AggregateAvg {
scale += 4
}
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale)
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale, t.Values())
}

func (code AggregateOpcode) NeedsComparableValues() bool {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,14 @@ func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (n
return nextRow, true, nil
}

cmp, err := evalengine.NullsafeCompare(v1, v2, oa.CollationEnv, gb.Type.Collation())
cmp, err := evalengine.NullsafeCompare(v1, v2, oa.CollationEnv, gb.Type.Collation(), gb.Type.Values())
if err != nil {
_, isCollationErr := err.(evalengine.UnsupportedCollationError)
if !isCollationErr || gb.WeightStringCol == -1 {
return nil, false, err
}
gb.KeyCol = gb.WeightStringCol
cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], oa.CollationEnv, gb.Type.Collation())
cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], oa.CollationEnv, gb.Type.Collation(), gb.Type.Values())
if err != nil {
return nil, false, err
}
Expand Down
7 changes: 4 additions & 3 deletions go/vt/vtgate/evalengine/api_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ type aggregationMinMax struct {
current sqltypes.Value
collation collations.ID
collationEnv *collations.Environment
values []string
}

func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) {
Expand All @@ -458,7 +459,7 @@ func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) {
a.current = value
return nil
}
n, err := compare(a.current, value, a.collationEnv, a.collation)
n, err := compare(a.current, value, a.collationEnv, a.collation, a.values)
if err != nil {
return err
}
Expand All @@ -484,7 +485,7 @@ func (a *aggregationMinMax) Reset() {
a.current = sqltypes.NULL
}

func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID) MinMax {
func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environment, collation collations.ID, values []string) MinMax {
switch {
case sqltypes.IsSigned(typ):
return &aggregationInt{t: typ}
Expand All @@ -495,6 +496,6 @@ func NewAggregationMinMax(typ sqltypes.Type, collationEnv *collations.Environmen
case sqltypes.IsDecimal(typ):
return &aggregationDecimal{}
default:
return &aggregationMinMax{collation: collation, collationEnv: collationEnv}
return &aggregationMinMax{collation: collation, collationEnv: collationEnv, values: values}
}
}
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/api_aggregation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func TestMinMax(t *testing.T) {
for i, tcase := range tcases {
t.Run(strconv.Itoa(i), func(t *testing.T) {
t.Run("Min", func(t *testing.T) {
agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll)
agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll, nil)

for _, v := range tcase.values {
err := agg.Min(v)
Expand All @@ -153,7 +153,7 @@ func TestMinMax(t *testing.T) {
})

t.Run("Max", func(t *testing.T) {
agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll)
agg := NewAggregationMinMax(tcase.type_, collations.MySQL8(), tcase.coll, nil)

for _, v := range tcase.values {
err := agg.Max(v)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/api_coerce.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

func CoerceTo(value sqltypes.Value, typ Type, sqlmode SQLMode) (sqltypes.Value, error) {
cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, sqlmode)
cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, typ.values, sqlmode)
if err != nil {
return sqltypes.Value{}, err
}
Expand All @@ -33,7 +33,7 @@ func CoerceTo(value sqltypes.Value, typ Type, sqlmode SQLMode) (sqltypes.Value,

// CoerceTypes takes two input types, and decides how they should be coerced before compared
func CoerceTypes(v1, v2 Type, collationEnv *collations.Environment) (out Type, err error) {
if v1 == v2 {
if v1.Equal(&v2) {
return v1, nil
}
if sqltypes.IsNull(v1.Type()) || sqltypes.IsNull(v2.Type()) {
Expand Down
Loading

0 comments on commit 4cb5f7d

Please sign in to comment.