diff --git a/go/test/endtoend/vtgate/queries/dml/dml_test.go b/go/test/endtoend/vtgate/queries/dml/dml_test.go index 8706b37d445..4383f59e6c4 100644 --- a/go/test/endtoend/vtgate/queries/dml/dml_test.go +++ b/go/test/endtoend/vtgate/queries/dml/dml_test.go @@ -366,6 +366,37 @@ func TestMultiTargetUpdate(t *testing.T) { `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("a")]]`) } +// TestMultiTargetNonLiteralUpdate executed multi-target update queries with non-literal values. +func TestMultiTargetNonLiteralUpdate(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + // initial rows + mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)") + mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (4,'c')") + + // multi target update + qr := mcmp.Exec(`update order_tbl o join oevent_tbl ev on o.oid = ev.oid set ev.ename = o.cust_no where ev.oid > 3`) + assert.EqualValues(t, 1, qr.RowsAffected) + + // check rows + mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`, + `[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`) + mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, + `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("55")]]`) + + qr = mcmp.Exec(`update order_tbl o, oevent_tbl ev set ev.ename = 'xyz', o.oid = ev.oid + 40 where o.cust_no = ev.oid and ev.ename = 'b'`) + assert.EqualValues(t, 2, qr.RowsAffected) + + // check rows + mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid, region_id`, + `[[INT64(1) INT64(1) INT64(4)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)] [INT64(1) INT64(42) INT64(2)]]`) + mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`, + `[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("55")]]`) +} + // TestDMLInUnique for update/delete statement using an IN clause with the Vindexes, // the query is correctly split according to the corresponding values in the IN list. func TestDMLInUnique(t *testing.T) { diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index c1b72382461..f0354e9e726 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -181,13 +181,15 @@ func (cached *DML) CachedSize(alloc bool) int64 { size += cached.RoutingParameters.CachedSize(true) return size } + +//go:nocheckptr func (cached *DMLWithInput) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) } size := int64(0) if alloc { - size += int64(64) + size += int64(96) } // field Input vitess.io/vitess/go/vt/vtgate/engine.Primitive if cc, ok := cached.Input.(cachedObject); ok { @@ -211,6 +213,25 @@ func (cached *DMLWithInput) CachedSize(alloc bool) int64 { } } } + // field BVList []map[string]int + { + size += hack.RuntimeAllocSize(int64(cap(cached.BVList)) * int64(8)) + for _, elem := range cached.BVList { + if elem != nil { + size += int64(48) + hmap := reflect.ValueOf(elem) + numBuckets := int(math.Pow(2, float64((*(*uint8)(unsafe.Pointer(hmap.Pointer() + uintptr(9))))))) + numOldBuckets := (*(*uint16)(unsafe.Pointer(hmap.Pointer() + uintptr(10)))) + size += hack.RuntimeAllocSize(int64(numOldBuckets * 208)) + if len(elem) > 0 || numBuckets > 1 { + size += hack.RuntimeAllocSize(int64(numBuckets * 208)) + } + for k := range elem { + size += hack.RuntimeAllocSize(int64(len(k))) + } + } + } + } return size } func (cached *Delete) CachedSize(alloc bool) int64 { diff --git a/go/vt/vtgate/engine/dml_with_input.go b/go/vt/vtgate/engine/dml_with_input.go index 28b306511df..0974f753cef 100644 --- a/go/vt/vtgate/engine/dml_with_input.go +++ b/go/vt/vtgate/engine/dml_with_input.go @@ -39,6 +39,7 @@ type DMLWithInput struct { DMLs []Primitive OutputCols [][]int + BVList []map[string]int } func (dml *DMLWithInput) RouteType() string { @@ -69,18 +70,16 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa var res *sqltypes.Result for idx, prim := range dml.DMLs { - var bv *querypb.BindVariable - if len(dml.OutputCols[idx]) == 1 { - bv = getBVSingle(inputRes, dml.OutputCols[idx][0]) + var qr *sqltypes.Result + if len(dml.BVList) == 0 || len(dml.BVList[idx]) == 0 { + qr, err = executeLiteralUpdate(ctx, vcursor, bindVars, prim, inputRes, dml.OutputCols[idx]) } else { - bv = getBVMulti(inputRes, dml.OutputCols[idx]) + qr, err = executeNonLiteralUpdate(ctx, vcursor, bindVars, prim, inputRes, dml.OutputCols[idx], dml.BVList[idx]) } - - bindVars[DmlVals] = bv - qr, err := vcursor.ExecutePrimitive(ctx, prim, bindVars, false) if err != nil { return nil, err } + if res == nil { res = qr } else { @@ -90,18 +89,32 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa return res, nil } -func getBVSingle(res *sqltypes.Result, offset int) *querypb.BindVariable { +// executeLiteralUpdate executes the primitive that can be executed with a single bind variable from the input result. +// The column updated have same value for all rows in the input result. +func executeLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int) (*sqltypes.Result, error) { + var bv *querypb.BindVariable + if len(outputCols) == 1 { + bv = getBVSingle(inputRes.Rows, outputCols[0]) + } else { + bv = getBVMulti(inputRes.Rows, outputCols) + } + + bindVars[DmlVals] = bv + return vcursor.ExecutePrimitive(ctx, prim, bindVars, false) +} + +func getBVSingle(rows []sqltypes.Row, offset int) *querypb.BindVariable { bv := &querypb.BindVariable{Type: querypb.Type_TUPLE} - for _, row := range res.Rows { + for _, row := range rows { bv.Values = append(bv.Values, sqltypes.ValueToProto(row[offset])) } return bv } -func getBVMulti(res *sqltypes.Result, offsets []int) *querypb.BindVariable { +func getBVMulti(rows []sqltypes.Row, offsets []int) *querypb.BindVariable { bv := &querypb.BindVariable{Type: querypb.Type_TUPLE} outputVals := make([]sqltypes.Value, 0, len(offsets)) - for _, row := range res.Rows { + for _, row := range rows { for _, offset := range offsets { outputVals = append(outputVals, row[offset]) } @@ -111,6 +124,34 @@ func getBVMulti(res *sqltypes.Result, offsets []int) *querypb.BindVariable { return bv } +// executeNonLiteralUpdate executes the primitive that needs to be executed per row from the input result. +// The column updated might have different value for each row in the input result. +func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int, vars map[string]int) (qr *sqltypes.Result, err error) { + var res *sqltypes.Result + for _, row := range inputRes.Rows { + var bv *querypb.BindVariable + if len(outputCols) == 1 { + bv = getBVSingle([]sqltypes.Row{row}, outputCols[0]) + } else { + bv = getBVMulti([]sqltypes.Row{row}, outputCols) + } + bindVars[DmlVals] = bv + for k, v := range vars { + bindVars[k] = sqltypes.ValueBindVariable(row[v]) + } + qr, err = vcursor.ExecutePrimitive(ctx, prim, bindVars, false) + if err != nil { + return nil, err + } + if res == nil { + res = qr + } else { + res.RowsAffected += res.RowsAffected + } + } + return res, nil +} + // TryStreamExecute performs a streaming exec. func (dml *DMLWithInput) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { res, err := dml.TryExecute(ctx, vcursor, bindVars, wantfields) @@ -133,6 +174,16 @@ func (dml *DMLWithInput) description() PrimitiveDescription { other := map[string]any{ "Offset": offsets, } + var bvList []string + for idx, vars := range dml.BVList { + if len(vars) == 0 { + continue + } + bvList = append(bvList, fmt.Sprintf("%d:[%s]", idx, orderedStringIntMap(vars))) + } + if len(bvList) > 0 { + other["BindVars"] = bvList + } return PrimitiveDescription{ OperatorType: "DMLWithInput", TargetTabletType: topodatapb.TabletType_PRIMARY, diff --git a/go/vt/vtgate/engine/plan_description.go b/go/vt/vtgate/engine/plan_description.go index 72220fda460..a8daa25ecd0 100644 --- a/go/vt/vtgate/engine/plan_description.go +++ b/go/vt/vtgate/engine/plan_description.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "sort" + "strings" "vitess.io/vitess/go/tools/graphviz" "vitess.io/vitess/go/vt/key" @@ -266,3 +267,11 @@ func (m orderedMap) MarshalJSON() ([]byte, error) { buf.WriteString("}") return buf.Bytes(), nil } + +func (m orderedMap) String() string { + var output []string + for _, val := range m { + output = append(output, fmt.Sprintf("%s:%v", val.key, val.val)) + } + return strings.Join(output, " ") +} diff --git a/go/vt/vtgate/planbuilder/dml_with_input.go b/go/vt/vtgate/planbuilder/dml_with_input.go index 729314e0fc9..1cf72e5ab17 100644 --- a/go/vt/vtgate/planbuilder/dml_with_input.go +++ b/go/vt/vtgate/planbuilder/dml_with_input.go @@ -25,6 +25,7 @@ type dmlWithInput struct { dmls []logicalPlan outputCols [][]int + bvList []map[string]int } var _ logicalPlan = (*dmlWithInput)(nil) @@ -40,5 +41,6 @@ func (d *dmlWithInput) Primitive() engine.Primitive { DMLs: dels, Input: inp, OutputCols: d.outputCols, + BVList: d.bvList, } } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index c3bbb06f61c..c0935714bbc 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -100,6 +100,7 @@ func transformDMLWithInput(ctx *plancontext.PlanningContext, op *operators.DMLWi input: input, dmls: dmls, outputCols: op.Offsets, + bvList: op.BvList, }, nil } diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 9d5b76b09a0..d12ed5d1e45 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -513,7 +513,7 @@ func splitGroupingToLeftAndRight( rhs.addGrouping(ctx, groupBy) columns.addRight(groupBy.Inner) case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)): - jc := breakExpressionInLHSandRHSForApplyJoin(ctx, groupBy.Inner, lhs.tableID) + jc := breakExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID) for _, lhsExpr := range jc.LHSExprs { e := lhsExpr.Expr lhs.addGrouping(ctx, NewGroupBy(e)) diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index e18169c28b1..894f2af7f7d 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -160,7 +160,7 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql rhs := aj.RHS predicates := sqlparser.SplitAndExpression(nil, expr) for _, pred := range predicates { - col := breakExpressionInLHSandRHSForApplyJoin(ctx, pred, TableID(aj.LHS)) + col := breakExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS)) aj.JoinPredicates.add(col) ctx.AddJoinPredicates(pred, col.RHSExpr) rhs = rhs.AddPredicate(ctx, col.RHSExpr) @@ -202,7 +202,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq case deps.IsSolvedBy(rhs): col.RHSExpr = e case deps.IsSolvedBy(both): - col = breakExpressionInLHSandRHSForApplyJoin(ctx, e, TableID(aj.LHS)) + col = breakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS)) default: panic(vterrors.VT13002(sqlparser.String(e))) } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 5633239346d..0d838610866 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -136,7 +136,7 @@ func (jpc *joinPredicateCollector) inspectPredicate( // then we can use this predicate to connect the subquery to the outer query if !deps.IsSolvedBy(jpc.subqID) && deps.IsSolvedBy(jpc.totalID) { jpc.predicates = append(jpc.predicates, predicate) - jc := breakExpressionInLHSandRHSForApplyJoin(ctx, predicate, jpc.outerID) + jc := breakExpressionInLHSandRHS(ctx, predicate, jpc.outerID) jpc.joinColumns = append(jpc.joinColumns, jc) pred = jc.RHSExpr } diff --git a/go/vt/vtgate/planbuilder/operators/delete.go b/go/vt/vtgate/planbuilder/operators/delete.go index 453a503ced1..a3c45e79135 100644 --- a/go/vt/vtgate/planbuilder/operators/delete.go +++ b/go/vt/vtgate/planbuilder/operators/delete.go @@ -207,9 +207,9 @@ func createDeleteOpWithTarget(ctx *plancontext.PlanningContext, target semantics Where: sqlparser.NewWhere(sqlparser.WhereClause, compExpr), } return dmlOp{ - createOperatorFromDelete(ctx, del), - vTbl, - cols, + op: createOperatorFromDelete(ctx, del), + vTbl: vTbl, + cols: cols, } } diff --git a/go/vt/vtgate/planbuilder/operators/dml_planning.go b/go/vt/vtgate/planbuilder/operators/dml_planning.go index 6d51a33b4aa..866c308956c 100644 --- a/go/vt/vtgate/planbuilder/operators/dml_planning.go +++ b/go/vt/vtgate/planbuilder/operators/dml_planning.go @@ -39,11 +39,12 @@ type TargetTable struct { Name sqlparser.TableName } -// dmlOp stores intermediary value for Update/Delete Operator with the vindexes.Table for ordering. +// dmlOp stores intermediary value for Update/Delete Operator with the vindexes. Table for ordering. type dmlOp struct { - op Operator - vTbl *vindexes.Table - cols []*sqlparser.ColName + op Operator + vTbl *vindexes.Table + cols []*sqlparser.ColName + updList updList } // sortDmlOps sort the operator based on sharding vindex type. diff --git a/go/vt/vtgate/planbuilder/operators/dml_with_input.go b/go/vt/vtgate/planbuilder/operators/dml_with_input.go index 848941b4468..09859b90bac 100644 --- a/go/vt/vtgate/planbuilder/operators/dml_with_input.go +++ b/go/vt/vtgate/planbuilder/operators/dml_with_input.go @@ -32,6 +32,9 @@ type DMLWithInput struct { cols [][]*sqlparser.ColName Offsets [][]int + updList []updList + BvList []map[string]int + noColumns noPredicates } @@ -86,6 +89,7 @@ func (d *DMLWithInput) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { } func (d *DMLWithInput) planOffsets(ctx *plancontext.PlanningContext) Operator { + // go through the primary key columns to get offset from the input offsets := make([][]int, len(d.cols)) for idx, columns := range d.cols { for _, col := range columns { @@ -94,6 +98,22 @@ func (d *DMLWithInput) planOffsets(ctx *plancontext.PlanningContext) Operator { } } d.Offsets = offsets + + // go through the update list and get offset for input columns + bvList := make([]map[string]int, len(d.updList)) + for idx, ul := range d.updList { + vars := make(map[string]int) + for _, updCol := range ul { + for _, bvExpr := range updCol.jc.LHSExprs { + offset := d.Source.AddColumn(ctx, true, false, aeWrap(bvExpr.Expr)) + vars[bvExpr.Name] = offset + } + } + if len(vars) > 0 { + bvList[idx] = vars + } + } + d.BvList = bvList return d } diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 521024ab7c9..a39ae96fa88 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -22,9 +22,9 @@ import ( "vitess.io/vitess/go/vt/vtgate/semantics" ) -// breakExpressionInLHSandRHSForApplyJoin takes an expression and +// breakExpressionInLHSandRHS takes an expression and // extracts the parts that are coming from one of the sides into `ColName`s that are needed -func breakExpressionInLHSandRHSForApplyJoin( +func breakExpressionInLHSandRHS( ctx *plancontext.PlanningContext, expr sqlparser.Expr, lhs semantics.TableSet, diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 0597cbe0f18..03a482185d8 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -100,7 +100,7 @@ func (sq *SubQuery) GetJoinColumns(ctx *plancontext.PlanningContext, outer Opera } sq.outerID = outerID mapper := func(in sqlparser.Expr) (applyJoinColumn, error) { - return breakExpressionInLHSandRHSForApplyJoin(ctx, in, outerID), nil + return breakExpressionInLHSandRHS(ctx, in, outerID), nil } joinPredicates, err := slice.MapWithError(sq.Predicates, mapper) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index d2bbad53212..30d5119ca98 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -263,7 +263,7 @@ func extractLHSExpr( lhs semantics.TableSet, ) func(expr sqlparser.Expr) sqlparser.Expr { return func(expr sqlparser.Expr) sqlparser.Expr { - col := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, lhs) + col := breakExpressionInLHSandRHS(ctx, expr, lhs) if col.IsPureLeft() { panic(vterrors.VT13001("did not expect to find any predicates that do not need data from the inner here")) } diff --git a/go/vt/vtgate/planbuilder/operators/update.go b/go/vt/vtgate/planbuilder/operators/update.go index b4aca5ee5f8..4abf319ad08 100644 --- a/go/vt/vtgate/planbuilder/operators/update.go +++ b/go/vt/vtgate/planbuilder/operators/update.go @@ -151,17 +151,28 @@ func isMultiTargetUpdate(ctx *plancontext.PlanningContext, updateStmt *sqlparser var targetTS semantics.TableSet for _, ue := range updateStmt.Exprs { targetTS = targetTS.Merge(ctx.SemTable.DirectDeps(ue.Name)) + targetTS = targetTS.Merge(ctx.SemTable.RecursiveDeps(ue.Expr)) } return targetTS.NumberOfTables() > 1 } +type updColumn struct { + updCol *sqlparser.ColName + jc applyJoinColumn +} + +type updList []updColumn + func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Update) (op Operator) { updClone := ctx.SemTable.Clone(upd).(*sqlparser.Update) upd.Limit = nil + // Prepare the update expressions list + ueMap := prepareUpdateExpressionList(ctx, upd) + var updOps []dmlOp for _, target := range ctx.SemTable.Targets.Constituents() { - op := createUpdateOpWithTarget(ctx, target, upd) + op := createUpdateOpWithTarget(ctx, upd, target, ueMap[target]) updOps = append(updOps, op) } @@ -175,10 +186,12 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up Lock: sqlparser.ForUpdateLock, } - // now map the operator and column list. + // now map the operator, column list and update list var colsList [][]*sqlparser.ColName + var uList []updList dmls := slice.Map(updOps, func(from dmlOp) Operator { colsList = append(colsList, from.cols) + uList = append(uList, from.updList) for _, col := range from.cols { selectStmt.SelectExprs = append(selectStmt.SelectExprs, aeWrap(col)) } @@ -186,9 +199,10 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up }) op = &DMLWithInput{ - DML: dmls, - Source: createOperatorFromSelect(ctx, selectStmt), - cols: colsList, + DML: dmls, + Source: createOperatorFromSelect(ctx, selectStmt), + cols: colsList, + updList: uList, } if upd.Comments != nil { @@ -200,15 +214,43 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up return op } -func createUpdateOpWithTarget(ctx *plancontext.PlanningContext, target semantics.TableSet, updStmt *sqlparser.Update) dmlOp { - var updExprs sqlparser.UpdateExprs - for _, ue := range updStmt.Exprs { - if ctx.SemTable.DirectDeps(ue.Name) == target { - updExprs = append(updExprs, ue) +func prepareUpdateExpressionList(ctx *plancontext.PlanningContext, upd *sqlparser.Update) map[semantics.TableSet]updList { + // Any update expression requiring column value from any other table is rewritten to take it as bindvar column. + // E.g. UPDATE t1 join t2 on t1.col = t2.col SET t1.col = t2.col + 1 where t2.col = 10; + // SET t1.col = t2.col + 1 -> SET t1.col = :t2_col + 1 (t2_col is the bindvar column which will be provided from the input) + ueMap := make(map[semantics.TableSet]updList) + for _, ue := range upd.Exprs { + target := ctx.SemTable.DirectDeps(ue.Name) + exprDeps := ctx.SemTable.RecursiveDeps(ue.Expr) + jc := breakExpressionInLHSandRHS(ctx, ue.Expr, exprDeps.Remove(target)) + ueMap[target] = append(ueMap[target], updColumn{ue.Name, jc}) + } + + // Check if any of the dependent columns are updated in the same query. + // This can result in a mismatch of rows on how MySQL interprets it and how Vitess would have updated those rows. + // It is safe to fail for those cases. + errIfDependentColumnUpdated(ctx, upd, ueMap) + + return ueMap +} + +func errIfDependentColumnUpdated(ctx *plancontext.PlanningContext, upd *sqlparser.Update, ueMap map[semantics.TableSet]updList) { + for _, ue := range upd.Exprs { + for _, list := range ueMap { + for _, dc := range list { + for _, bvExpr := range dc.jc.LHSExprs { + if ctx.SemTable.EqualsExprWithDeps(ue.Name, bvExpr.Expr) { + panic(vterrors.VT12001( + fmt.Sprintf("'%s' column referenced in update expression '%s' is itself updated", sqlparser.String(ue.Name), sqlparser.String(dc.jc.Original)))) + } + } + } } } +} - if len(updExprs) == 0 { +func createUpdateOpWithTarget(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, target semantics.TableSet, uList updList) dmlOp { + if len(uList) == 0 { panic(vterrors.VT13001("no update expression for the target")) } @@ -237,6 +279,14 @@ func createUpdateOpWithTarget(ctx *plancontext.PlanningContext, target semantics } compExpr := sqlparser.NewComparisonExpr(sqlparser.InOp, lhs, sqlparser.ListArg(engine.DmlVals), nil) + var updExprs sqlparser.UpdateExprs + for _, expr := range uList { + ue := &sqlparser.UpdateExpr{ + Name: expr.updCol, + Expr: expr.jc.RHSExpr, + } + updExprs = append(updExprs, ue) + } upd := &sqlparser.Update{ Ignore: updStmt.Ignore, TableExprs: sqlparser.TableExprs{ti.GetAliasedTableExpr()}, @@ -245,9 +295,10 @@ func createUpdateOpWithTarget(ctx *plancontext.PlanningContext, target semantics OrderBy: updStmt.OrderBy, } return dmlOp{ - createOperatorFromUpdate(ctx, upd), - vTbl, - cols, + op: createOperatorFromUpdate(ctx, upd), + vTbl: vTbl, + cols: cols, + updList: uList, } } diff --git a/go/vt/vtgate/planbuilder/testdata/dml_cases.json b/go/vt/vtgate/planbuilder/testdata/dml_cases.json index 3c1f202bd8d..9c2ed1920ee 100644 --- a/go/vt/vtgate/planbuilder/testdata/dml_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/dml_cases.json @@ -5819,6 +5819,257 @@ ] } }, + { + "comment": "update with multi table join with single target having dependent column update", + "query": "update user as u, user_extra as ue set u.col = ue.col where u.id = ue.id", + "plan": { + "QueryType": "UPDATE", + "Original": "update user as u, user_extra as ue set u.col = ue.col where u.id = ue.id", + "Instructions": { + "OperatorType": "DMLWithInput", + "TargetTabletType": "PRIMARY", + "BindVars": [ + "0:[ue_col:1]" + ], + "Offset": [ + "0:[0]" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0,L:0", + "JoinVars": { + "ue_id": 1 + }, + "TableName": "user_extra_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select ue.col, ue.id from user_extra as ue where 1 != 1", + "Query": "select ue.col, ue.id from user_extra as ue for update", + "Table": "user_extra" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.id from `user` as u where 1 != 1", + "Query": "select u.id from `user` as u where u.id = :ue_id for update", + "Table": "`user`", + "Values": [ + ":ue_id" + ], + "Vindex": "user_index" + } + ] + }, + { + "OperatorType": "Update", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update `user` as u set u.col = :ue_col where u.id in ::dml_vals", + "Table": "user", + "Values": [ + "::dml_vals" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "update with multi table join with single target having multiple dependent column update", + "query": "update user as u, user_extra as ue set u.col = ue.foo + ue.bar + u.baz where u.id = ue.id", + "plan": { + "QueryType": "UPDATE", + "Original": "update user as u, user_extra as ue set u.col = ue.foo + ue.bar + u.baz where u.id = ue.id", + "Instructions": { + "OperatorType": "DMLWithInput", + "TargetTabletType": "PRIMARY", + "BindVars": [ + "0:[ue_bar:2 ue_foo:1]" + ], + "Offset": [ + "0:[0]" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "R:0,L:0,L:1", + "JoinVars": { + "ue_id": 2 + }, + "TableName": "user_extra_`user`", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select ue.foo, ue.bar, ue.id from user_extra as ue where 1 != 1", + "Query": "select ue.foo, ue.bar, ue.id from user_extra as ue for update", + "Table": "user_extra" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select u.id from `user` as u where 1 != 1", + "Query": "select u.id from `user` as u where u.id = :ue_id for update", + "Table": "`user`", + "Values": [ + ":ue_id" + ], + "Vindex": "user_index" + } + ] + }, + { + "OperatorType": "Update", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update `user` as u set u.col = :ue_foo + :ue_bar + u.baz where u.id in ::dml_vals", + "Table": "user", + "Values": [ + "::dml_vals" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "update with multi table join with multi target having dependent column update", + "query": "update user, user_extra ue set user.name = ue.id + 'foo', ue.bar = user.baz where user.id = ue.id and user.id = 1", + "plan": { + "QueryType": "UPDATE", + "Original": "update user, user_extra ue set user.name = ue.id + 'foo', ue.bar = user.baz where user.id = ue.id and user.id = 1", + "Instructions": { + "OperatorType": "DMLWithInput", + "TargetTabletType": "PRIMARY", + "BindVars": [ + "0:[ue_id:1]", + "1:[user_baz:3]" + ], + "Offset": [ + "0:[0]", + "1:[1 2]" + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,R:0,R:1,L:1", + "JoinVars": { + "user_id": 0 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.id, `user`.baz from `user` where 1 != 1", + "Query": "select `user`.id, `user`.baz from `user` where `user`.id = 1 for update", + "Table": "`user`", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select ue.id, ue.user_id from user_extra as ue where 1 != 1", + "Query": "select ue.id, ue.user_id from user_extra as ue where ue.id = :user_id for update", + "Table": "user_extra" + } + ] + }, + { + "OperatorType": "Update", + "Variant": "IN", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "ChangedVindexValues": [ + "name_user_map:3" + ], + "KsidLength": 1, + "KsidVindex": "user_index", + "OwnedVindexQuery": "select Id, `Name`, Costly, `user`.`name` = :ue_id + 'foo' from `user` where `user`.id in ::dml_vals for update", + "Query": "update `user` set `user`.`name` = :ue_id + 'foo' where `user`.id in ::dml_vals", + "Table": "user", + "Values": [ + "::dml_vals" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Update", + "Variant": "MultiEqual", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update user_extra as ue set ue.bar = :user_baz where (ue.id, ue.user_id) in ::dml_vals", + "Table": "user_extra", + "Values": [ + "dml_vals:1" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, { "comment": "update with multi table reference with multi target update on a derived table", "query": "update ignore (select foo, col, bar from user) u, music m set u.foo = 21, u.bar = 'abc' where u.col = m.col", diff --git a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json index 9565db9a035..251af436d27 100644 --- a/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/unsupported_cases.json @@ -44,6 +44,11 @@ "query": "update user_metadata set email = 'juan@vitess.io' where user_id = 1 limit 10", "plan": "VT12001: unsupported: Vindex update should have ORDER BY clause when using LIMIT" }, + { + "comment": "multi table update with dependent column getting updated", + "query": "update user u, user_extra ue set u.name = 'test' + ue.col, ue.col = 5 where u.id = ue.id and u.id = 1;", + "plan": "VT12001: unsupported: 'ue.col' column referenced in update expression ''test' + ue.col' is itself updated" + }, { "comment": "unsharded insert, col list does not match values", "query": "insert into unsharded_auto(id, val) values(1)",