Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
systay committed Sep 5, 2024
1 parent d19c52d commit 90ef5f8
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 35 deletions.
28 changes: 10 additions & 18 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,32 +155,24 @@ func tryConvertApplyToValuesJoin(ctx *plancontext.PlanningContext, in *ApplyJoin
return in, NoRewrite
}

valuesTable := "values"
rhsID := TableID(in.LHS)
if rhsID.NumberOfTables() == 1 {
tbl, err := ctx.SemTable.TableInfoFor(rhsID)
if err != nil {
return in, NoRewrite
}
name, err := tbl.Name()
if err == nil {
valuesTable = sqlparser.String(name)
}
}
valuesTable = ctx.GetReservedArgumentForString(valuesTable)

vj := newValuesJoin(ctx, in.LHS, in.RHS, in.JoinType)
newRouteSrc := &ValuesTable{
unaryOperator: newUnaryOp(r.Source),
ListArgName: valuesTable,
ListArgName: vj.ListArg,
TableName: vj.TableName,
}
r.Source = newRouteSrc
var vj Operator = newValuesJoin(in.LHS, in.RHS, in.JoinType, valuesTable)

// we need to add the join predicates to the new ValuesJoin
var op Operator = vj
for _, column := range in.JoinPredicates.columns {
vj = vj.AddPredicate(ctx, column.Original)
op = op.AddPredicate(ctx, column.Original)
}

// TODO: Figure out what to do about the routing
routing := r.Routing.(*ShardedRouting)
routing.RouteOpCode = engine.IN
return vj, Rewrote("ApplyJoin to ValuesJoin")
return op, Rewrote("ApplyJoin to ValuesJoin")
}

func tryPushDelete(in *Delete) (Operator, *ApplyResult) {
Expand Down
16 changes: 14 additions & 2 deletions go/vt/vtgate/planbuilder/operators/route_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package operators
import (
"bytes"
"io"
"vitess.io/vitess/go/mysql/capabilities"

querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
Expand Down Expand Up @@ -287,6 +288,17 @@ func requiresSwitchingSides(ctx *plancontext.PlanningContext, op Operator) (requ
return
}

func createVersionJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType) JoinOp {
ok, err := capabilities.MySQLVersionHasCapability(ctx.VSchema.Environment().MySQLVersion(), capabilities.ValuesRow)
if err != nil {
panic(err)
}
if !ok {
return NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, joinType)
}
return newValuesJoin(ctx, lhs, rhs, joinType)
}

func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredicates []sqlparser.Expr, joinType sqlparser.JoinType) (Operator, *ApplyResult) {
jm := newJoinMerge(joinPredicates, joinType)
newPlan := jm.mergeJoinInputs(ctx, lhs, rhs, joinPredicates)
Expand All @@ -305,14 +317,14 @@ func mergeOrJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinPredic
return join, Rewrote("use a hash join because we have LIMIT on the LHS")
}

join := NewApplyJoin(ctx, Clone(rhs), Clone(lhs), nil, joinType)
join := createVersionJoin(ctx, Clone(rhs), Clone(lhs), joinType)
for _, pred := range joinPredicates {
join.AddJoinPredicate(ctx, pred)
}
return join, Rewrote("logical join to applyJoin, switching side because LIMIT")
}

join := NewApplyJoin(ctx, Clone(lhs), Clone(rhs), nil, joinType)
join := createVersionJoin(ctx, Clone(lhs), Clone(rhs), joinType)
for _, pred := range joinPredicates {
join.AddJoinPredicate(ctx, pred)
}
Expand Down
33 changes: 23 additions & 10 deletions go/vt/vtgate/planbuilder/operators/values_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ import (
type ValuesJoin struct {
binaryOperator

JoinType sqlparser.JoinType
LHSExprs sqlparser.Exprs
ListArg string // the bindvar name for the list of values
JoinType sqlparser.JoinType
LHSExprs sqlparser.Exprs
ListArg string // the bindvar name for the list of values
TableName string // the name of the derived table that will be created

// 👇Done at offset planning time 👇

Expand All @@ -39,10 +40,24 @@ type ValuesJoin struct {
Columns []string
}

func newValuesJoin(lhs, rhs Operator, joinType sqlparser.JoinType, listArg string) *ValuesJoin {
func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType) *ValuesJoin {
// If the RHS is a single table, we'll use its name as the derived table name, otherwise we'll use "v"
tblName := "v"
rhsID := TableID(lhs)
if rhsID.NumberOfTables() == 1 {
tbl, err := ctx.SemTable.TableInfoFor(rhsID)
if err == nil {
name, err := tbl.Name()
if err == nil {
tblName = sqlparser.String(name)
}
}
}
listArg := ctx.GetReservedArgumentForString(tblName)
return &ValuesJoin{
binaryOperator: newBinaryOp(lhs, rhs),
JoinType: joinType,
TableName: tblName,
ListArg: listArg,
}
}
Expand Down Expand Up @@ -85,7 +100,7 @@ func (vj *ValuesJoin) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser
}

func (vj *ValuesJoin) ShortDescription() string {
return ""
return vj.TableName
}

func (vj *ValuesJoin) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy {
Expand All @@ -111,14 +126,12 @@ func (vj *ValuesJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sq
for _, pred := range predicates {
col := breakExpressionInLHSandRHS(ctx, pred, TableID(vj.LHS))

if col.IsPureLeft() {
if col.IsPureLeft() && vj.JoinType.IsInner() {
// If the predicate doesn't reference the RHS, we can add it to the LHS
// This is only valid for inner joins
vj.LHS = vj.LHS.AddPredicate(ctx, pred)
} else {
vj.addLHSExprs(col.LHSExprs)
err := ctx.SkipJoinPredicates(pred)
if err != nil {
panic(err)
}
vj.RHS = vj.RHS.AddPredicate(ctx, pred)
}
}
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/operators/values_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type ValuesTable struct {
unaryOperator

ListArgName string
TableName string
}

func (v *ValuesTable) Clone(inputs []Operator) Operator {
Expand Down
8 changes: 4 additions & 4 deletions go/vt/vtgate/planbuilder/plancontext/planning_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func (ctx *PlanningContext) AddJoinPredicates(joinPred sqlparser.Expr, predicate
// the original predicate will be used.
func (ctx *PlanningContext) SkipJoinPredicates(joinPred sqlparser.Expr) error {
fn := func(_ sqlparser.Expr, rhsExprs []sqlparser.Expr) {
ctx.skipThesePredicates(rhsExprs...)
ctx.SkipThesePredicates(rhsExprs...)
}
if ctx.execOnJoinPredicateEqual(joinPred, fn) {
return nil
Expand All @@ -191,12 +191,12 @@ func (ctx *PlanningContext) KeepPredicateInfo(other *PlanningContext) {
ctx.AddJoinPredicates(k, v...)
}
for expr := range other.skipPredicates {
ctx.skipThesePredicates(expr)
ctx.SkipThesePredicates(expr)
}
}

// skipThesePredicates is a utility function to exclude certain predicates from SQL building
func (ctx *PlanningContext) skipThesePredicates(preds ...sqlparser.Expr) {
// SkipThesePredicates marks the given predicates as irrelevant for the current planning stage.
func (ctx *PlanningContext) SkipThesePredicates(preds ...sqlparser.Expr) {
outer:
for _, expr := range preds {
for k := range ctx.skipPredicates {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/testdata/onecase.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[
{
"comment": "Add your test case here for debugging and run go test -run=One.",
"query": "",
"query": "select u.foo+ue.bar from user u join user_extra ue on u.foo = ue.user_id",
"plan": {
}
}
Expand Down

0 comments on commit 90ef5f8

Please sign in to comment.