Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

evalengine: Add support for enum and set #15783

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions go/mysql/json/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func NewFromSQL(v sqltypes.Value) (*Value, error) {
return NewDate(v.RawStr()), nil
case v.IsTime():
return NewTime(v.RawStr()), nil
case v.IsEnum():
return NewString(v.RawStr()), nil
case v.IsSet():
return NewString(v.RawStr()), nil
default:
return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot coerce %v as a JSON type", v)
}
Expand Down
36 changes: 36 additions & 0 deletions go/sqltypes/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,12 @@ var RandomGenerators = map[Type]RandomGenerator{
}
return v
},
Enum: func() Value {
return MakeTrusted(Enum, randEnum())
},
Set: func() Value {
return MakeTrusted(Set, randSet())
},
}

func randTime() time.Time {
Expand All @@ -289,3 +295,33 @@ func randTime() time.Time {
sec := rand.Int64N(delta) + min
return time.Unix(sec, 0)
}

func randEnum() []byte {
enums := []string{
"xxsmall",
"xsmall",
"small",
"medium",
"large",
"xlarge",
"xxlarge",
}
return []byte(enums[rand.IntN(len(enums))])
}

func randSet() []byte {
set := []string{
"a",
"b",
"c",
"d",
"e",
"f",
"g",
}
rand.Shuffle(len(set), func(i, j int) {
set[i], set[j] = set[j], set[i]
})
set = set[:rand.IntN(len(set))]
return []byte(strings.Join(set, ","))
}
10 changes: 10 additions & 0 deletions go/sqltypes/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ func IsNull(t querypb.Type) bool {
return t == Null
}

// IsEnum returns true if the type is Enum type
func IsEnum(t querypb.Type) bool {
return t == Enum
}

// IsSet returns true if the type is Set type
func IsSet(t querypb.Type) bool {
return t == Set
}

// Vitess data types. These are idiomatically named synonyms for the querypb.Type values.
// Although these constants are interchangeable, they should be treated as different from querypb.Type.
// Use the synonyms only to refer to the type in Value. For proto variables, use the querypb.Type constants instead.
Expand Down
10 changes: 10 additions & 0 deletions go/sqltypes/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,16 @@ func (v Value) IsDecimal() bool {
return IsDecimal(v.Type())
}

// IsEnum returns true if Value is time.
func (v Value) IsEnum() bool {
return v.Type() == querypb.Type_ENUM
}

// IsSet returns true if Value is time.
dbussink marked this conversation as resolved.
Show resolved Hide resolved
func (v Value) IsSet() bool {
return v.Type() == querypb.Type_SET
}

// IsComparable returns true if the Value is null safe comparable without collation information.
func (v *Value) IsComparable() bool {
if v.Type() == Null || IsNumber(v.Type()) || IsBinary(v.Type()) {
Expand Down
15 changes: 14 additions & 1 deletion go/test/endtoend/vtgate/queries/misc/misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func start(t *testing.T) (utils.MySQLCompare, func()) {
require.NoError(t, err)

deleteAll := func() {
tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "uks.unsharded"}
tables := []string{"t1", "tbl", "unq_idx", "nonunq_idx", "tbl_enum_set", "uks.unsharded"}
for _, table := range tables {
_, _ = mcmp.ExecAndIgnore("delete from " + table)
}
Expand Down Expand Up @@ -452,3 +452,16 @@ func TestStraightJoin(t *testing.T) {
require.NoError(t, err)
require.Contains(t, fmt.Sprintf("%v", res.Rows), "t1_tbl")
}

func TestEnumSetVals(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")

mcmp, closer := start(t)
defer closer()
require.NoError(t, utils.WaitForAuthoritative(t, keyspaceName, "tbl_enum_set", clusterInstance.VtgateProcess.ReadVSchema))

mcmp.Exec("insert into tbl_enum_set(id, enum_col, set_col) values (1, 'medium', 'a,b,e'), (2, 'small', 'e,f,g'), (3, 'large', 'c'), (4, 'xsmall', 'a,b'), (5, 'medium', 'a,d')")

mcmp.AssertMatches("select id, enum_col, cast(enum_col as signed) from tbl_enum_set order by enum_col, id", `[[INT64(4) ENUM("xsmall") INT64(1)] [INT64(2) ENUM("small") INT64(2)] [INT64(1) ENUM("medium") INT64(3)] [INT64(5) ENUM("medium") INT64(3)] [INT64(3) ENUM("large") INT64(4)]]`)
mcmp.AssertMatches("select id, set_col, cast(set_col as unsigned) from tbl_enum_set order by set_col, id", `[[INT64(4) SET("a,b") UINT64(3)] [INT64(3) SET("c") UINT64(4)] [INT64(5) SET("a,d") UINT64(9)] [INT64(1) SET("a,b,e") UINT64(19)] [INT64(2) SET("e,f,g") UINT64(112)]]`)
}
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,11 @@ create table tbl
primary key (id),
unique (unq_col)
) Engine = InnoDB;

create table tbl_enum_set
(
id bigint,
enum_col enum('xsmall', 'small', 'medium', 'large', 'xlarge'),
set_col set('a', 'b', 'c', 'd', 'e', 'f', 'g'),
primary key (id)
) Engine = InnoDB;
8 changes: 8 additions & 0 deletions go/test/endtoend/vtgate/queries/misc/vschema.json
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@
}
]
},
"tbl_enum_set": {
"column_vindexes": [
{
"column": "id",
"name": "hash"
}
]
},
"unq_idx": {
"column_vindexes": [
{
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtexplain/vtexplain_vttablet.go
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ func (t *explainTablet) analyzeWhere(selStmt *sqlparser.Select, tableColumnMap m
// Check if we have a duplicate value
isNewValue := true
for _, v := range inVal {
result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset())
result, err := evalengine.NullsafeCompare(v, value, t.collationEnv, t.collationEnv.DefaultConnectionCharset(), nil)
if err != nil {
return "", nil, 0, nil, err
}
Expand Down
9 changes: 6 additions & 3 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ type aggregatorDistinct struct {
last sqltypes.Value
coll collations.ID
collationEnv *collations.Environment
values *evalengine.EnumSetValues
}

func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) {
Expand All @@ -115,7 +116,7 @@ func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) {
next := row[a.column]
if !last.IsNull() {
if last.TinyWeightCmp(next) == 0 {
cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll)
cmp, err := evalengine.NullsafeCompare(last, next, a.collationEnv, a.coll, a.values)
if err != nil {
return true, err
}
Expand Down Expand Up @@ -386,6 +387,7 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
}

Expand All @@ -405,22 +407,23 @@ func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (agg
column: distinct,
coll: aggr.Type.Collation(),
collationEnv: aggr.CollationEnv,
values: aggr.Type.Values(),
},
}

case AggregateMin:
ag = &aggregatorMin{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()),
},
}

case AggregateMax:
ag = &aggregatorMax{
aggregatorMinMax{
from: aggr.Col,
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation()),
minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationEnv, aggr.Type.Collation(), aggr.Type.Values()),
},
}

Expand Down
26 changes: 20 additions & 6 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/distinct.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (vthash.Hash, error)
return vthash.Hash{}, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code")
}
col := inputRow[checkCol.Col]
err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode)
err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values())
if err != nil {
if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil {
return vthash.Hash{}, err
}
checkCol = checkCol.SwitchToWeightString()
pt.checkCols[i] = checkCol
err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode)
err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode, checkCol.Type.Values())
if err != nil {
return vthash.Hash{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/distinct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestDistinct(t *testing.T) {
}
checkCols = append(checkCols, CheckCol{
Col: i,
Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0),
Type: evalengine.NewTypeEx(tc.inputs.Fields[i].Type, collID, false, 0, 0, nil),
CollationEnv: collations.MySQL8(),
})
}
Expand Down
13 changes: 9 additions & 4 deletions go/vt/vtgate/engine/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type (
ComparisonType querypb.Type

CollationEnv *collations.Environment

// Values for enum and set types
Values *evalengine.EnumSetValues
}

hashJoinProbeTable struct {
Expand All @@ -78,6 +81,7 @@ type (
cols []int
hasher vthash.Hasher
sqlmode evalengine.SQLMode
values *evalengine.EnumSetValues
}

probeTableEntry struct {
Expand All @@ -94,7 +98,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
return nil, err
}

pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols)
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values)
// build the probe table from the LHS result
for _, row := range lresult.Rows {
err := pt.addLeftRow(row)
Expand Down Expand Up @@ -130,7 +134,7 @@ func (hj *HashJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
// TryStreamExecute implements the Primitive interface
func (hj *HashJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
// build the probe table from the LHS result
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols)
pt := newHashJoinProbeTable(hj.Collation, hj.ComparisonType, hj.LHSKey, hj.RHSKey, hj.Cols, hj.Values)
var lfields []*querypb.Field
var mu sync.Mutex
err := vcursor.StreamExecutePrimitive(ctx, hj.Left, bindVars, wantfields, func(result *sqltypes.Result) error {
Expand Down Expand Up @@ -260,7 +264,7 @@ func (hj *HashJoin) description() PrimitiveDescription {
}
}

func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int) *hashJoinProbeTable {
func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey int, cols []int, values *evalengine.EnumSetValues) *hashJoinProbeTable {
return &hashJoinProbeTable{
innerMap: map[vthash.Hash]*probeTableEntry{},
coll: coll,
Expand All @@ -269,6 +273,7 @@ func newHashJoinProbeTable(coll collations.ID, typ querypb.Type, lhsKey, rhsKey
rhsKey: rhsKey,
cols: cols,
hasher: vthash.New(),
values: values,
}
}

Expand All @@ -286,7 +291,7 @@ func (pt *hashJoinProbeTable) addLeftRow(r sqltypes.Row) error {
}

func (pt *hashJoinProbeTable) hash(val sqltypes.Value) (vthash.Hash, error) {
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode)
err := evalengine.NullsafeHashcode128(&pt.hasher, val, pt.coll, pt.typ, pt.sqlmode, pt.values)
if err != nil {
return vthash.Hash{}, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Envir
if code == AggregateAvg {
scale += 4
}
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale)
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale, t.Values())
}

func (code AggregateOpcode) NeedsComparableValues() bool {
Expand Down
Loading
Loading