From aa788bdda14345f67c54410512b47d83913dff05 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 28 Aug 2024 07:58:15 +0200 Subject: [PATCH] refactoring Signed-off-by: Andres Taylor --- go/vt/vtctl/workflow/stream_migrator.go | 11 ++--- go/vt/vtctl/workflow/vexec/query_planner.go | 48 +++++-------------- .../vtgate/planbuilder/operators/ast_to_op.go | 4 +- go/vt/vtgate/planbuilder/operators/phases.go | 2 +- 4 files changed, 17 insertions(+), 48 deletions(-) diff --git a/go/vt/vtctl/workflow/stream_migrator.go b/go/vt/vtctl/workflow/stream_migrator.go index b294ba1fcd0..724413e6030 100644 --- a/go/vt/vtctl/workflow/stream_migrator.go +++ b/go/vt/vtctl/workflow/stream_migrator.go @@ -1130,7 +1130,7 @@ func (sm *StreamMigrator) templatizeRule(ctx context.Context, rule *binlogdatapb case rule.Filter == vreplication.ExcludeStr: return StreamTypeUnknown, fmt.Errorf("unexpected rule in vreplication: %v", rule) default: - if err := sm.templatizeKeyRange(ctx, rule); err != nil { + if err := sm.templatizeKeyRange(rule); err != nil { return StreamTypeUnknown, err } @@ -1138,7 +1138,7 @@ func (sm *StreamMigrator) templatizeRule(ctx context.Context, rule *binlogdatapb } } -func (sm *StreamMigrator) templatizeKeyRange(ctx context.Context, rule *binlogdatapb.Rule) error { +func (sm *StreamMigrator) templatizeKeyRange(rule *binlogdatapb.Rule) error { statement, err := sm.parser.Parse(rule.Filter) if err != nil { return err @@ -1149,12 +1149,7 @@ func (sm *StreamMigrator) templatizeKeyRange(ctx context.Context, rule *binlogda return fmt.Errorf("unexpected query: %v", rule.Filter) } - var expr sqlparser.Expr - if sel.Where != nil { - expr = sel.Where.Expr - } - - exprs := sqlparser.SplitAndExpression(nil, expr) + exprs := sqlparser.SplitAndExpression(nil, sel.GetWherePredicate()) for _, subexpr := range exprs { funcExpr, ok := subexpr.(*sqlparser.FuncExpr) if !ok || !funcExpr.Name.EqualString("in_keyrange") { diff --git a/go/vt/vtctl/workflow/vexec/query_planner.go b/go/vt/vtctl/workflow/vexec/query_planner.go index 9d16dc72f55..3d3541fafce 100644 --- a/go/vt/vtctl/workflow/vexec/query_planner.go +++ b/go/vt/vtctl/workflow/vexec/query_planner.go @@ -181,7 +181,7 @@ func (planner *VReplicationQueryPlanner) planDelete(del *sqlparser.Delete) (*Fix ) } - del.Where = addDefaultWheres(planner, del.Where) + addDefaultWheres(planner, del) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", del) @@ -194,7 +194,7 @@ func (planner *VReplicationQueryPlanner) planDelete(del *sqlparser.Delete) (*Fix } func (planner *VReplicationQueryPlanner) planSelect(sel *sqlparser.Select) (*FixedQueryPlan, error) { - sel.Where = addDefaultWheres(planner, sel.Where) + addDefaultWheres(planner, sel) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", sel) @@ -230,7 +230,7 @@ func (planner *VReplicationQueryPlanner) planUpdate(upd *sqlparser.Update) (*Fix } } - upd.Where = addDefaultWheres(planner, upd.Where) + addDefaultWheres(planner, upd) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", upd) @@ -289,8 +289,7 @@ func (planner *VReplicationLogQueryPlanner) QueryParams() QueryParams { } func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (QueryPlan, error) { - where := sel.Where - cols := extractWhereComparisonColumns(where) + cols := extractWhereComparisonColumns(sel.GetWherePredicate()) hasVReplIDCol := false for _, col := range cols { @@ -313,10 +312,6 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q // streamIDs. queriesByTarget := make(map[string]*sqlparser.ParsedQuery, len(planner.tabletStreamIDs)) for target, streamIDs := range planner.tabletStreamIDs { - targetWhere := &sqlparser.Where{ - Type: sqlparser.WhereClause, - } - var expr sqlparser.Expr switch len(streamIDs) { case 0: // WHERE vreplication_log.vrepl_id IN () => WHERE 1 != 1 @@ -349,15 +344,7 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q Right: tuple, } } - - switch where { - case nil: - targetWhere.Expr = expr - default: - targetWhere.Expr = sqlparser.CreateAndExpr(expr, where.Expr) - } - - sel.Where = targetWhere + sel.AddWhere(expr) buf := sqlparser.NewTrackedBuffer(nil) buf.Myprintf("%v", sel) @@ -371,8 +358,8 @@ func (planner *VReplicationLogQueryPlanner) planSelect(sel *sqlparser.Select) (Q }, nil } -func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.Where { - cols := extractWhereComparisonColumns(where) +func addDefaultWheres(planner QueryPlanner, stmt sqlparser.WhereAble) { + cols := extractWhereComparisonColumns(stmt.GetWherePredicate()) params := planner.QueryParams() hasDBNameCol := false @@ -387,8 +374,6 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W } } - newWhere := where - if !hasDBNameCol { expr := &sqlparser.ComparisonExpr{ Left: &sqlparser.ColName{ @@ -398,15 +383,7 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W Right: sqlparser.NewStrLiteral(params.DBName), } - switch newWhere { - case nil: - newWhere = &sqlparser.Where{ - Type: sqlparser.WhereClause, - Expr: expr, - } - default: - newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr) - } + stmt.AddWhere(expr) } if !hasWorkflowCol && params.Workflow != "" { @@ -417,23 +394,20 @@ func addDefaultWheres(planner QueryPlanner, where *sqlparser.Where) *sqlparser.W Operator: sqlparser.EqualOp, Right: sqlparser.NewStrLiteral(params.Workflow), } - - newWhere.Expr = sqlparser.CreateAndExpr(newWhere.Expr, expr) + stmt.AddWhere(expr) } - - return newWhere } // extractWhereComparisonColumns extracts the column names used in AND-ed // comparison expressions in a where clause, given the following assumptions: // - (1) The column name is always the left-hand side of the comparison. // - (2) There are no compound expressions within the where clause involving OR. -func extractWhereComparisonColumns(where *sqlparser.Where) []string { +func extractWhereComparisonColumns(where sqlparser.Expr) []string { if where == nil { return nil } - exprs := sqlparser.SplitAndExpression(nil, where.Expr) + exprs := sqlparser.SplitAndExpression(nil, where) cols := make([]string, 0, len(exprs)) for _, expr := range exprs { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 4c075f480d3..a9903edcc79 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -50,8 +50,8 @@ func translateQueryToOp(ctx *plancontext.PlanningContext, selStmt sqlparser.Stat func createOperatorFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) Operator { op := crossJoin(ctx, sel.From) - if sel.Where != nil { - op = addWherePredicates(ctx, sel.Where.Expr, op) + if expr := sel.GetWherePredicate(); expr != nil { + op = addWherePredicates(ctx, expr, op) } if sel.Comments != nil || sel.Lock != sqlparser.NoLock { diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index d5354e9548f..cf126236a74 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -193,7 +193,7 @@ func createDMLWithInput(ctx *plancontext.PlanningContext, op, src Operator, in * if in.OwnedVindexQuery != nil { in.OwnedVindexQuery.From = sqlparser.TableExprs{targetQT.Alias} - in.OwnedVindexQuery.Where = sqlparser.NewWhere(sqlparser.WhereClause, compExpr) + in.OwnedVindexQuery.AddWhere(compExpr) in.OwnedVindexQuery.OrderBy = nil in.OwnedVindexQuery.Limit = nil }