Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove logical plan interface #16006

Merged
merged 18 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions go/vt/sqlparser/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,8 @@ const (
NotRegexpOp
)

func Inverse(in ComparisonExprOperator) ComparisonExprOperator {
switch in {
func (op ComparisonExprOperator) Inverse() ComparisonExprOperator {
switch op {
case EqualOp:
return NotEqualOp
case LessThanOp:
Expand Down Expand Up @@ -709,6 +709,15 @@ func Inverse(in ComparisonExprOperator) ComparisonExprOperator {
panic("unreachable")
}

func (op ComparisonExprOperator) IsCommutative() bool {
switch op {
case EqualOp, NotEqualOp, NullSafeEqualOp:
return true
default:
return false
}
}

// Constant for Enum Type - IsExprOperator
const (
IsNullOp IsExprOperator = iota
Expand Down
8 changes: 1 addition & 7 deletions go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 0 additions & 5 deletions go/vt/vtgate/engine/memory_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ func (ms *MemorySort) GetTableName() string {
return ms.Input.GetTableName()
}

// SetTruncateColumnCount sets the truncate column count.
func (ms *MemorySort) SetTruncateColumnCount(count int) {
ms.TruncateColumnCount = count
}

// TryExecute satisfies the Primitive interface.
func (ms *MemorySort) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
count, err := ms.fetchCount(ctx, vcursor, bindVars)
Expand Down
11 changes: 2 additions & 9 deletions go/vt/vtgate/engine/ordered_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,6 @@ type OrderedAggregate struct {

// Input is the primitive that will feed into this Primitive.
Input Primitive

CollationEnv *collations.Environment
}

// GroupByParams specify the grouping key to be used.
Expand Down Expand Up @@ -96,11 +94,6 @@ func (oa *OrderedAggregate) GetTableName() string {
return oa.Input.GetTableName()
}

// SetTruncateColumnCount sets the truncate column count.
func (oa *OrderedAggregate) SetTruncateColumnCount(count int) {
oa.TruncateColumnCount = count
}

// TryExecute is a Primitive function.
func (oa *OrderedAggregate) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool) (*sqltypes.Result, error) {
qr, err := oa.execute(ctx, vcursor, bindVars)
Expand Down Expand Up @@ -344,14 +337,14 @@ func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (n
return nextRow, true, nil
}

cmp, err := evalengine.NullsafeCompare(v1, v2, oa.CollationEnv, gb.Type.Collation(), gb.Type.Values())
cmp, err := evalengine.NullsafeCompare(v1, v2, gb.CollationEnv, gb.Type.Collation(), gb.Type.Values())
if err != nil {
_, isCollationErr := err.(evalengine.UnsupportedCollationError)
if !isCollationErr || gb.WeightStringCol == -1 {
return nil, false, err
}
gb.KeyCol = gb.WeightStringCol
cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], oa.CollationEnv, gb.Type.Collation(), gb.Type.Values())
cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], gb.CollationEnv, gb.Type.Collation(), gb.Type.Values())
if err != nil {
return nil, false, err
}
Expand Down
5 changes: 0 additions & 5 deletions go/vt/vtgate/engine/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,6 @@ func (route *Route) GetTableName() string {
return route.TableName
}

// SetTruncateColumnCount sets the truncate column count.
func (route *Route) SetTruncateColumnCount(count int) {
route.TruncateColumnCount = count
}

// TryExecute performs a non-streaming exec.
func (route *Route) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
ctx, cancelFunc := addQueryTimeout(ctx, vcursor, route.QueryTimeout)
Expand Down
48 changes: 5 additions & 43 deletions go/vt/vtgate/engine/semi_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package engine

import (
"context"
"fmt"
"strings"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand All @@ -33,14 +31,6 @@ type SemiJoin struct {
// of the SemiJoin. They can be any primitive.
Left, Right Primitive `json:",omitempty"`

// Cols defines which columns from the left
// results should be used to build the
// return result. For results coming from the
// left query, the index values go as -1, -2, etc.
// If Cols is {-1, -2}, it means that
// the returned result will be {Left0, Left1}.
Cols []int `json:",omitempty"`

// Vars defines the list of SemiJoinVars that need to
// be built from the LHS result before invoking
// the RHS subquery.
Expand All @@ -54,7 +44,7 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
if err != nil {
return nil, err
}
result := &sqltypes.Result{Fields: projectFields(lresult.Fields, jn.Cols)}
result := &sqltypes.Result{Fields: lresult.Fields}
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
Expand All @@ -64,7 +54,7 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
return nil, err
}
if len(rresult.Rows) > 0 {
result.Rows = append(result.Rows, projectRows(lrow, jn.Cols))
result.Rows = append(result.Rows, lrow)
}
}
return result, nil
Expand All @@ -74,15 +64,15 @@ func (jn *SemiJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars ma
func (jn *SemiJoin) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
joinVars := make(map[string]*querypb.BindVariable)
err := vcursor.StreamExecutePrimitive(ctx, jn.Left, bindVars, wantfields, func(lresult *sqltypes.Result) error {
result := &sqltypes.Result{Fields: projectFields(lresult.Fields, jn.Cols)}
result := &sqltypes.Result{Fields: lresult.Fields}
for _, lrow := range lresult.Rows {
for k, col := range jn.Vars {
joinVars[k] = sqltypes.ValueBindVariable(lrow[col])
}
rowAdded := false
err := vcursor.StreamExecutePrimitive(ctx, jn.Right, combineVars(bindVars, joinVars), false, func(rresult *sqltypes.Result) error {
if len(rresult.Rows) > 0 && !rowAdded {
result.Rows = append(result.Rows, projectRows(lrow, jn.Cols))
result.Rows = append(result.Rows, lrow)
rowAdded = true
}
return nil
Expand Down Expand Up @@ -135,8 +125,7 @@ func (jn *SemiJoin) NeedsTransaction() bool {

func (jn *SemiJoin) description() PrimitiveDescription {
other := map[string]any{
"TableName": jn.GetTableName(),
"ProjectedIndexes": strings.Trim(strings.Join(strings.Fields(fmt.Sprint(jn.Cols)), ","), "[]"),
"TableName": jn.GetTableName(),
}
if len(jn.Vars) > 0 {
other["JoinVars"] = orderedStringIntMap(jn.Vars)
Expand All @@ -146,30 +135,3 @@ func (jn *SemiJoin) description() PrimitiveDescription {
Other: other,
}
}

func projectFields(lfields []*querypb.Field, cols []int) []*querypb.Field {
if lfields == nil {
return nil
}
if len(cols) == 0 {
return lfields
}
fields := make([]*querypb.Field, len(cols))
for i, index := range cols {
fields[i] = lfields[-index-1]
}
return fields
}

func projectRows(lrow []sqltypes.Value, cols []int) []sqltypes.Value {
if len(cols) == 0 {
return lrow
}
row := make([]sqltypes.Value, len(cols))
for i, index := range cols {
if index < 0 {
row[i] = lrow[-index-1]
}
}
return row
}
2 changes: 0 additions & 2 deletions go/vt/vtgate/engine/semi_join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func TestSemiJoinExecute(t *testing.T) {
Vars: map[string]int{
"bv": 1,
},
Cols: []int{-1, -2, -3},
}
r, err := jn.TryExecute(context.Background(), &noopVCursor{}, bv, true)
require.NoError(t, err)
Expand Down Expand Up @@ -139,7 +138,6 @@ func TestSemiJoinStreamExecute(t *testing.T) {
Vars: map[string]int{
"bv": 1,
},
Cols: []int{-1, -2, -3},
}
r, err := wrapStreamExecute(jn, &noopVCursor{}, map[string]*querypb.BindVariable{}, true)
require.NoError(t, err)
Expand Down
18 changes: 9 additions & 9 deletions go/vt/vtgate/engine/sql_calc_found_rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ type SQLCalcFoundRows struct {
}

// RouteType implements the Primitive interface
func (s SQLCalcFoundRows) RouteType() string {
func (s *SQLCalcFoundRows) RouteType() string {
return "SQLCalcFoundRows"
}

// GetKeyspaceName implements the Primitive interface
func (s SQLCalcFoundRows) GetKeyspaceName() string {
func (s *SQLCalcFoundRows) GetKeyspaceName() string {
return s.LimitPrimitive.GetKeyspaceName()
}

// GetTableName implements the Primitive interface
func (s SQLCalcFoundRows) GetTableName() string {
func (s *SQLCalcFoundRows) GetTableName() string {
return s.LimitPrimitive.GetTableName()
}

// TryExecute implements the Primitive interface
func (s SQLCalcFoundRows) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
func (s *SQLCalcFoundRows) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
limitQr, err := vcursor.ExecutePrimitive(ctx, s.LimitPrimitive, bindVars, wantfields)
if err != nil {
return nil, err
Expand All @@ -70,7 +70,7 @@ func (s SQLCalcFoundRows) TryExecute(ctx context.Context, vcursor VCursor, bindV
}

// TryStreamExecute implements the Primitive interface
func (s SQLCalcFoundRows) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
func (s *SQLCalcFoundRows) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
err := vcursor.StreamExecutePrimitive(ctx, s.LimitPrimitive, bindVars, wantfields, callback)
if err != nil {
return err
Expand Down Expand Up @@ -104,21 +104,21 @@ func (s SQLCalcFoundRows) TryStreamExecute(ctx context.Context, vcursor VCursor,
}

// GetFields implements the Primitive interface
func (s SQLCalcFoundRows) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
func (s *SQLCalcFoundRows) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
return s.LimitPrimitive.GetFields(ctx, vcursor, bindVars)
}

// NeedsTransaction implements the Primitive interface
func (s SQLCalcFoundRows) NeedsTransaction() bool {
func (s *SQLCalcFoundRows) NeedsTransaction() bool {
return s.LimitPrimitive.NeedsTransaction()
}

// Inputs implements the Primitive interface
func (s SQLCalcFoundRows) Inputs() ([]Primitive, []map[string]any) {
func (s *SQLCalcFoundRows) Inputs() ([]Primitive, []map[string]any) {
return []Primitive{s.LimitPrimitive, s.CountPrimitive}, nil
}

func (s SQLCalcFoundRows) description() PrimitiveDescription {
func (s *SQLCalcFoundRows) description() PrimitiveDescription {
return PrimitiveDescription{
OperatorType: "SQL_CALC_FOUND_ROWS",
}
Expand Down
28 changes: 0 additions & 28 deletions go/vt/vtgate/engine/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,34 +209,6 @@ func TestUpdateEqualNoRoute(t *testing.T) {
})
}

func TestUpdateEqualNoScatter(t *testing.T) {
t.Skip("planner does not produces this plan anymore")
vindex, _ := vindexes.CreateVindex("lookup_unique", "", map[string]string{
"table": "lkp",
"from": "from",
"to": "toc",
"write_only": "true",
})
upd := &Update{
DML: &DML{
RoutingParameters: &RoutingParameters{
Opcode: Equal,
Keyspace: &vindexes.Keyspace{
Name: "ks",
Sharded: true,
},
Vindex: vindex,
Values: []evalengine.Expr{evalengine.NewLiteralInt(1)},
},
Query: "dummy_update",
},
}

vc := newDMLTestVCursor("0")
_, err := upd.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false)
require.EqualError(t, err, `cannot map vindex to unique keyspace id: DestinationKeyRange(-)`)
}

func TestUpdateEqualChangedVindex(t *testing.T) {
ks := buildTestVSchema().Keyspaces["sharded"]
upd := &Update{
Expand Down
4 changes: 0 additions & 4 deletions go/vt/vtgate/planbuilder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ var (
)

type (
truncater interface {
SetTruncateColumnCount(int)
}

planResult struct {
primitive engine.Primitive
tables []string
Expand Down
41 changes: 0 additions & 41 deletions go/vt/vtgate/planbuilder/concatenate.go

This file was deleted.

3 changes: 0 additions & 3 deletions go/vt/vtgate/planbuilder/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,6 @@ func buildCreateViewCommon(
sqlparser.RemoveKeyspace(ddl)

if vschema.IsViewsEnabled() {
if keyspace == nil {
return nil, nil, vterrors.VT09005()
}
return destination, keyspace, nil
}
isRoutePlan, opCode := tryToGetRoutePlan(selectPlan.primitive)
Expand Down
Loading
Loading