Skip to content

Commit

Permalink
evalengine: Add support for enum and set (#15783)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
dbussink authored Apr 24, 2024
1 parent 5f47800 commit 4c2df48
Show file tree
Hide file tree
Showing 51 changed files with 799 additions and 238 deletions.
4 changes: 4 additions & 0 deletions go/mysql/json/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func NewFromSQL(v sqltypes.Value) (*Value, error) {
return NewDate(v.RawStr()), nil
case v.IsTime():
return NewTime(v.RawStr()), nil
case v.IsEnum():
return NewString(v.RawStr()), nil
case v.IsSet():
return NewString(v.RawStr()), nil
default:
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot coerce %v as a JSON type", v)
}
Expand Down
36 changes: 36 additions & 0 deletions go/sqltypes/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ var RandomGenerators = map[Type]RandomGenerator{
}
return v
},
Enum: func() Value {
return MakeTrusted(Enum, randEnum())
},
Set: func() Value {
return MakeTrusted(Set, randSet())
},
}

func randTime() time.Time {
Expand All @@ -289,3 +295,33 @@ func randTime() time.Time {
sec := rand.Int64N(delta) + min
return time.Unix(sec, 0)
}

func randEnum() []byte {
enums := []string{
"xxsmall",
"xsmall",
"small",
"medium",
"large",
"xlarge",
"xxlarge",
}
return []byte(enums[rand.IntN(len(enums))])
}

func randSet() []byte {
set := []string{
"a",
"b",
"c",
"d",
"e",
"f",
"g",
}
rand.Shuffle(len(set), func(i, j int) {
set[i], set[j] = set[j], set[i]
})
set = set[:rand.IntN(len(set))]
return []byte(strings.Join(set, ","))
}
10 changes: 10 additions & 0 deletions go/sqltypes/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ func IsNull(t querypb.Type) bool {
return t == Null
}

// IsEnum reports whether t is the Enum type.
func IsEnum(t querypb.Type) bool {
	switch t {
	case Enum:
		return true
	default:
		return false
	}
}

// IsSet reports whether t is the Set type.
func IsSet(t querypb.Type) bool {
	switch t {
	case Set:
		return true
	default:
		return false
	}
}

// Vitess data types. These are idiomatically named synonyms for the querypb.Type values.
// Although these constants are interchangeable, they should be treated as different from querypb.Type.
// Use the synonyms only to refer to the type in Value. For proto variables, use the querypb.Type constants instead.
Expand Down
10 changes: 10 additions & 0 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,16 @@ func (v Value) IsDecimal() bool {
return IsDecimal(v.Type())
}

// IsEnum reports whether the Value's type is querypb.Type_ENUM.
func (v Value) IsEnum() bool {
	typ := v.Type()
	return typ == querypb.Type_ENUM
}

// IsSet reports whether the Value's type is querypb.Type_SET.
func (v Value) IsSet() bool {
	typ := v.Type()
	return typ == querypb.Type_SET
}

// IsComparable returns true if the Value is null safe comparable without collation information.
func (v *Value) IsComparable() bool {
if v.Type() == Null || IsNumber(v.Type()) || IsBinary(v.Type()) {
Expand Down
15 changes: 14 additions & 1 deletion go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) {
require.NoError(t, err)

deleteAll := func() {
tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "uks.unsharded"}
tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "tbl_enum_set", "uks.unsharded"}
for _, table := range tables {
_, _ = mcmp.ExecAndIgnore("delete from " + table)
}
Expand Down Expand Up @@ -452,3 +452,16 @@ func TestStraightJoin(t *testing.T) {
require.NoError(t, err)
require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl")
}

// TestEnumSetVals inserts rows into tbl_enum_set and asserts on the rows
// returned when selecting and ordering by the ENUM and SET columns, together
// with their integer casts. The expected rows are ordered by the numeric value
// of the column (e.g. "xsmall" before "small"), not alphabetically.
func TestEnumSetVals(t *testing.T) {
	utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")

	mcmp, closer := start(t)
	defer closer()

	// The table must be known to the vschema before the queries below run.
	err := utils.WaitForAuthoritative(t, keyspaceName, "tbl_enum_set", clusterInstance.VtgateProcess.ReadVSchema)
	require.NoError(t, err)

	mcmp.Exec("insert into tbl_enum_set(id, enum_col, set_col) values (1, 'medium', 'a,b,e'), (2, 'small', 'e,f,g'), (3, 'large', 'c'), (4, 'xsmall', 'a,b'), (5, 'medium', 'a,d')")

	checks := []struct {
		query string
		want  string
	}{
		{
			query: "select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id",
			want:  `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`,
		},
		{
			query: "select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id",
			want:  `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`,
		},
	}
	for _, check := range checks {
		mcmp.AssertMatches(check.query, check.want)
	}
}
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ create table tbl
primary key (id),
unique (unq_col)
) Engine = InnoDB;

-- tbl_enum_set has one ENUM and one SET column. TestEnumSetVals in misc_test.go
-- inserts into it and asserts on the ordering and integer casts of both columns.
create table tbl_enum_set
(
id bigint,
-- Member order matters: the test expects cast(enum_col as signed) to be
-- 1 for 'xsmall', 2 for 'small', 3 for 'medium' and 4 for 'large'.
enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'),
-- Member order matters: the test expects cast(set_col as unsigned) to be
-- 3 for 'a,b', 4 for 'c', 9 for 'a,d', 19 for 'a,b,e' and 112 for 'e,f,g'.
set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'),
primary key (id)
) Engine = InnoDB;
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/vschema.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@
}
]
},
"tbl_enum_set": {
"column_vindexes": [
{
"column": "id",
"name": "hash"
}
]
},
"unq_idx": {
"column_vindexes": [
{
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ func (t *explainTablet) analyzeWhere(selStmt *sqlparser.Select, tableColumnMap m
// Check if we have a duplicate value
isNewValue := true
for _, v := range inVal {
result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset())
result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset(), nil)
if err != nil {
return "", nil, 0, nil, err
}
Expand Down
9 changes: 6 additions & 3 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ type aggregatorDistinct struct {
last sqltypes.Value
coll collations.ID
collationEnv *collations.Environment
values *evalengine.EnumSetValues
}

func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) {
Expand All @@ -115,7 +116,7 @@ func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) {
next := row[a.column]
if !last.IsNull() {
if last.TinyWeightCmp(next) == 0 {
cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll)
cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll, a.values)
if err != nil {
return true, err
}
Expand Down Expand Up @@ -386,6 +387,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
}

Expand All @@ -405,22 +407,23 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
}

case AggregateMin:
ag = &aggregatorMin{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()),
},
}

case AggregateMax:
ag = &aggregatorMax{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()),
},
}

Expand Down
26 changes: 20 additions & 6 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (vthash.Hash, error)
return vthash.Hash{}, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code")
}
col := inputRow[checkCol.Col]
err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode)
err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values())
if err != nil {
if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil {
return vthash.Hash{}, err
}
checkCol = checkCol.SwitchToWeightString()
pt.checkCols[i] = checkCol
err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode)
err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values())
if err != nil {
return vthash.Hash{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestDistinct(t *testing.T) {
}
checkCols = append(checkCols, CheckCol{
Col: i,
Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0),
Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0, nil),
CollationEnv: collations.MySQL8(),
})
}
Expand Down
13 changes: 9 additions & 4 deletions go/vt/vtgate/engine/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type (
ComparisonType querypb.Type

CollationEnv *collations.Environment

// Values for enum and set types
Values *evalengine.EnumSetValues
}

hashJoinProbeTable struct {
Expand All @@ -78,6 +81,7 @@ type (
cols []int
hasher vthash.Hasher
sqlmode evalengine.SQLMode
values *evalengine.EnumSetValues
}

probeTableEntry struct {
Expand All @@ -94,7 +98,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
return nil, err
}

pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols)
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values)
// build the probe table from the LHS result
for _, row := range lresult.Rows {
err := pt.addLeftRow(row)
Expand Down Expand Up @@ -130,7 +134,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
// TryStreamExecute implements the Primitive interface
func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
// build the probe table from the LHS result
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols)
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values)
var lfields []*querypb.Field
var mu sync.Mutex
err := vcursor.StreamExecutePrimitive(ctx, hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error {
Expand Down Expand Up @@ -260,7 +264,7 @@ func (hj *HashJoin) description() PrimitiveDescription {
}
}

func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int) *hashJoinProbeTable {
func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int, values *evalengine.EnumSetValues) *hashJoinProbeTable {
return &hashJoinProbeTable{
innerMap: map[vthash.Hash]*probeTableEntry{},
coll: coll,
Expand All @@ -269,6 +273,7 @@ func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey
rhsKey: rhsKey,
cols: cols,
hasher: vthash.New(),
values: values,
}
}

Expand All @@ -286,7 +291,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error {
}

func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) {
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode)
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode, pt.values)
if err != nil {
return vthash.Hash{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Envir
if code == AggregateAvg {
scale += 4
}
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale)
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale, t.Values())
}

func (code AggregateOpcode) NeedsComparableValues() bool {
Expand Down
Loading

0 comments on commit 4c2df48

Please sign in to comment.