Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Aug 28, 2024
1 parent b0c5192 commit aa788bd
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 48 deletions.
11 changes: 3 additions & 8 deletions go/vt/vtctl/workflow/stream_migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,15 +1130,15 @@ func (sm *StreamMigrator) templatizeRule(ctx context.Context, rule *binlogdatapb
case rule.Filter == vreplication.ExcludeStr:
return StreamTypeUnknown, fmt.Errorf("unexpected rule in vreplication: %v", rule)
default:
if err := sm.templatizeKeyRange(ctx, rule); err != nil {
if err := sm.templatizeKeyRange(rule); err != nil {
return StreamTypeUnknown, err
}

return StreamTypeSharded, nil
}
}

func (sm *StreamMigrator) templatizeKeyRange(ctx context.Context, rule *binlogdatapb.Rule) error {
func (sm *StreamMigrator) templatizeKeyRange(rule *binlogdatapb.Rule) error {
statement, err := sm.parser.Parse(rule.Filter)
if err != nil {
return err
Expand All @@ -1149,12 +1149,7 @@ func (sm *StreamMigrator) templatizeKeyRange(ctx context.Context, rule *binlogda
return fmt.Errorf("unexpected query: %v", rule.Filter)
}

var expr sqlparser.Expr
if sel.Where != nil {
expr = sel.Where.Expr
}

exprs := sqlparser.SplitAndExpression(nil, expr)
exprs := sqlparser.SplitAndExpression(nil, sel.GetWherePredicate())
for _, subexpr := range exprs {
funcExpr, ok := subexpr.(*sqlparser.FuncExpr)
if !ok || !funcExpr.Name.EqualString("in_keyrange") {
Expand Down
48 changes: 11 additions & 37 deletions go/vt/vtctl/workflow/vexec/query_planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (planner *VReplicationQueryPlanner) planDelete(del *sqlparser.Delete) (*Fix
)
}

del.Where = addDefaultWheres(planner, del.Where)
addDefaultWheres(planner, del)

buf := sqlparser.NewTrackedBuffer(nil)
buf.Myprintf("%v", del)
Expand All @@ -194,7 +194,7 @@ func (planner *VReplicationQueryPlanner) planDelete(del *sqlparser.Delete) (*Fix
}

func (planner *VReplicationQueryPlanner) planSelect(sel *sqlparser.Select) (*FixedQueryPlan, error) {
sel.Where = addDefaultWheres(planner, sel.Where)
addDefaultWheres(planner, sel)

buf := sqlparser.NewTrackedBuffer(nil)
buf.Myprintf("%v", sel)
Expand Down Expand Up @@ -230,7 +230,7 @@ func (planner *VReplicationQueryPlanner) planUpdate(upd *sqlparser.Update) (*Fix
}
}

upd.Where = addDefaultWheres(planner, upd.Where)
addDefaultWheres(planner, upd)

buf := sqlparser.NewTrackedBuffer(nil)
buf.Myprintf("%v", upd)
Expand Down Expand Up @@ -289,8 +289,7 @@ func (planner *VReplicationLogQueryPlanner) QueryParams() QueryParams {
}

func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (QueryPlan, error) {
where := sel.Where
cols := extractWhereComparisonColumns(where)
cols := extractWhereComparisonColumns(sel.GetWherePredicate())
hasVReplIDCol := false

for _, col := range cols {
Expand All @@ -313,10 +312,6 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q
// streamIDs.
queriesByTarget := make(map[string]*sqlparser.ParsedQuery, len(planner.tabletStreamIDs))
for target, streamIDs := range planner.tabletStreamIDs {
targetWhere := &sqlparser.Where{
Type: sqlparser.WhereClause,
}

var expr sqlparser.Expr
switch len(streamIDs) {
case 0: // WHERE vreplication_log.vrepl_id IN () => WHERE 1 != 1
Expand Down Expand Up @@ -349,15 +344,7 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q
Right: tuple,
}
}

switch where {
case nil:
targetWhere.Expr = expr
default:
targetWhere.Expr = sqlparser.CreateAndExpr(expr, where.Expr)
}

sel.Where = targetWhere
sel.AddWhere(expr)

buf := sqlparser.NewTrackedBuffer(nil)
buf.Myprintf("%v", sel)
Expand All @@ -371,8 +358,8 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q
}, nil
}

func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.Where {
cols := extractWhereComparisonColumns(where)
func addDefaultWheres(planner QueryPlanner, stmt sqlparser.WhereAble) {
cols := extractWhereComparisonColumns(stmt.GetWherePredicate())

params := planner.QueryParams()
hasDBNameCol := false
Expand All @@ -387,8 +374,6 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W
}
}

newWhere := where

if !hasDBNameCol {
expr := &sqlparser.ComparisonExpr{
Left: &sqlparser.ColName{
Expand All @@ -398,15 +383,7 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W
Right: sqlparser.NewStrLiteral(params.DBName),
}

switch newWhere {
case nil:
newWhere = &sqlparser.Where{
Type: sqlparser.WhereClause,
Expr: expr,
}
default:
newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr)
}
stmt.AddWhere(expr)
}

if !hasWorkflowCol && params.Workflow != "" {
Expand All @@ -417,23 +394,20 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W
Operator: sqlparser.EqualOp,
Right: sqlparser.NewStrLiteral(params.Workflow),
}

newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr)
stmt.AddWhere(expr)
}

return newWhere
}

// extractWhereComparisonColumns extracts the column names used in AND-ed
// comparison expressions in a where clause, given the following assumptions:
// - (1) The column name is always the left-hand side of the comparison.
// - (2) There are no compound expressions within the where clause involving OR.
func extractWhereComparisonColumns(where *sqlparser.Where) []string {
func extractWhereComparisonColumns(where sqlparser.Expr) []string {
if where == nil {
return nil
}

exprs := sqlparser.SplitAndExpression(nil, where.Expr)
exprs := sqlparser.SplitAndExpression(nil, where)
cols := make([]string, 0, len(exprs))

for _, expr := range exprs {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func translateQueryToOp(ctx *plancontext.PlanningContext, selStmt sqlparser.Stat
func createOperatorFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) Operator {
op := crossJoin(ctx, sel.From)

if sel.Where != nil {
op = addWherePredicates(ctx, sel.Where.Expr, op)
if expr := sel.GetWherePredicate(); expr != nil {
op = addWherePredicates(ctx, expr, op)
}

if sel.Comments != nil || sel.Lock != sqlparser.NoLock {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/phases.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func createDMLWithInput(ctx *plancontext.PlanningContext, op, src Operator, in *

if in.OwnedVindexQuery != nil {
in.OwnedVindexQuery.From = sqlparser.TableExprs{targetQT.Alias}
in.OwnedVindexQuery.Where = sqlparser.NewWhere(sqlparser.WhereClause, compExpr)
in.OwnedVindexQuery.AddWhere(compExpr)
in.OwnedVindexQuery.OrderBy = nil
in.OwnedVindexQuery.Limit = nil
}
Expand Down

0 comments on commit aa788bd

Please sign in to comment.