Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multi table update with non literal value #15980

Merged
merged 7 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions go/test/endtoend/vtgate/queries/dml/dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,37 @@ func TestMultiTargetUpdate(t *testing.T) {
`[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("a")]]`)
}

// TestMultiTargetNonLiteralUpdate executed multi-target update queries with non-literal values.
func TestMultiTargetNonLiteralUpdate(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")

mcmp, closer := start(t)
defer closer()

// initial rows
mcmp.Exec("insert into order_tbl(region_id, oid, cust_no) values (1,1,4), (1,2,2), (2,3,5), (2,4,55)")
mcmp.Exec("insert into oevent_tbl(oid, ename) values (1,'a'), (2,'b'), (3,'a'), (4,'c')")

// multi target update
qr := mcmp.Exec(`update order_tbl o join oevent_tbl ev on o.oid = ev.oid set ev.ename = o.cust_no where ev.oid > 3`)
assert.EqualValues(t, 1, qr.RowsAffected)

// check rows
mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid`,
`[[INT64(1) INT64(1) INT64(4)] [INT64(1) INT64(2) INT64(2)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)]]`)
mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`,
`[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("b")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("55")]]`)

qr = mcmp.Exec(`update order_tbl o, oevent_tbl ev set ev.ename = 'xyz', o.oid = ev.oid + 40 where o.cust_no = ev.oid and ev.ename = 'b'`)
assert.EqualValues(t, 2, qr.RowsAffected)

// check rows
mcmp.AssertMatches(`select region_id, oid, cust_no from order_tbl order by oid, region_id`,
`[[INT64(1) INT64(1) INT64(4)] [INT64(2) INT64(3) INT64(5)] [INT64(2) INT64(4) INT64(55)] [INT64(1) INT64(42) INT64(2)]]`)
mcmp.AssertMatches(`select oid, ename from oevent_tbl order by oid`,
`[[INT64(1) VARCHAR("a")] [INT64(2) VARCHAR("xyz")] [INT64(3) VARCHAR("a")] [INT64(4) VARCHAR("55")]]`)
}

// TestDMLInUnique for update/delete statement using an IN clause with the Vindexes,
// the query is correctly split according to the corresponding values in the IN list.
func TestDMLInUnique(t *testing.T) {
Expand Down
23 changes: 22 additions & 1 deletion go/vt/vtgate/engine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 62 additions & 11 deletions go/vt/vtgate/engine/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type DMLWithInput struct {

DMLs []Primitive
OutputCols [][]int
BVList []map[string]int
}

func (dml *DMLWithInput) RouteType() string {
Expand Down Expand Up @@ -69,18 +70,16 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa

var res *sqltypes.Result
for idx, prim := range dml.DMLs {
var bv *querypb.BindVariable
if len(dml.OutputCols[idx]) == 1 {
bv = getBVSingle(inputRes, dml.OutputCols[idx][0])
var qr *sqltypes.Result
if len(dml.BVList) == 0 || len(dml.BVList[idx]) == 0 {
qr, err = executeLiteralUpdate(ctx, vcursor, bindVars, prim, inputRes, dml.OutputCols[idx])
harshit-gangal marked this conversation as resolved.
Show resolved Hide resolved
} else {
bv = getBVMulti(inputRes, dml.OutputCols[idx])
qr, err = executeNonLiteralUpdate(ctx, vcursor, bindVars, prim, inputRes, dml.OutputCols[idx], dml.BVList[idx])
}

bindVars[DmlVals] = bv
qr, err := vcursor.ExecutePrimitive(ctx, prim, bindVars, false)
if err != nil {
return nil, err
}

if res == nil {
res = qr
} else {
Expand All @@ -90,18 +89,32 @@ func (dml *DMLWithInput) TryExecute(ctx context.Context, vcursor VCursor, bindVa
return res, nil
}

func getBVSingle(res *sqltypes.Result, offset int) *querypb.BindVariable {
// executeLiteralUpdate executes the primitive that can be executed with a single bind variable from the input result.
// The column updated have same value for all rows in the input result.
func executeLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int) (*sqltypes.Result, error) {
harshit-gangal marked this conversation as resolved.
Show resolved Hide resolved
var bv *querypb.BindVariable
if len(outputCols) == 1 {
bv = getBVSingle(inputRes.Rows, outputCols[0])
} else {
bv = getBVMulti(inputRes.Rows, outputCols)
}

bindVars[DmlVals] = bv
return vcursor.ExecutePrimitive(ctx, prim, bindVars, false)
}

func getBVSingle(rows []sqltypes.Row, offset int) *querypb.BindVariable {
bv := &querypb.BindVariable{Type: querypb.Type_TUPLE}
for _, row := range res.Rows {
for _, row := range rows {
bv.Values = append(bv.Values, sqltypes.ValueToProto(row[offset]))
}
return bv
}

func getBVMulti(res *sqltypes.Result, offsets []int) *querypb.BindVariable {
func getBVMulti(rows []sqltypes.Row, offsets []int) *querypb.BindVariable {
bv := &querypb.BindVariable{Type: querypb.Type_TUPLE}
outputVals := make([]sqltypes.Value, 0, len(offsets))
for _, row := range res.Rows {
for _, row := range rows {
for _, offset := range offsets {
outputVals = append(outputVals, row[offset])
}
Expand All @@ -111,6 +124,34 @@ func getBVMulti(res *sqltypes.Result, offsets []int) *querypb.BindVariable {
return bv
}

// executeNonLiteralUpdate executes the primitive that needs to be executed per row from the input result.
// The column updated might have different value for each row in the input result.
func executeNonLiteralUpdate(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, prim Primitive, inputRes *sqltypes.Result, outputCols []int, vars map[string]int) (qr *sqltypes.Result, err error) {
harshit-gangal marked this conversation as resolved.
Show resolved Hide resolved
var res *sqltypes.Result
for _, row := range inputRes.Rows {
var bv *querypb.BindVariable
if len(outputCols) == 1 {
bv = getBVSingle([]sqltypes.Row{row}, outputCols[0])
} else {
bv = getBVMulti([]sqltypes.Row{row}, outputCols)
}
bindVars[DmlVals] = bv
for k, v := range vars {
bindVars[k] = sqltypes.ValueBindVariable(row[v])
}
qr, err = vcursor.ExecutePrimitive(ctx, prim, bindVars, false)
if err != nil {
return nil, err
}
if res == nil {
res = qr
} else {
res.RowsAffected += res.RowsAffected
}
}
return res, nil
}

// TryStreamExecute performs a streaming exec.
func (dml *DMLWithInput) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
res, err := dml.TryExecute(ctx, vcursor, bindVars, wantfields)
Expand All @@ -133,6 +174,16 @@ func (dml *DMLWithInput) description() PrimitiveDescription {
other := map[string]any{
"Offset": offsets,
}
var bvList []string
for idx, vars := range dml.BVList {
if len(vars) == 0 {
continue
}
bvList = append(bvList, fmt.Sprintf("%d:[%s]", idx, orderedStringIntMap(vars)))
}
if len(bvList) > 0 {
other["BindVars"] = bvList
}
return PrimitiveDescription{
OperatorType: "DMLWithInput",
TargetTabletType: topodatapb.TabletType_PRIMARY,
Expand Down
9 changes: 9 additions & 0 deletions go/vt/vtgate/engine/plan_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/json"
"fmt"
"sort"
"strings"

"vitess.io/vitess/go/tools/graphviz"
"vitess.io/vitess/go/vt/key"
Expand Down Expand Up @@ -266,3 +267,11 @@ func (m orderedMap) MarshalJSON() ([]byte, error) {
buf.WriteString("}")
return buf.Bytes(), nil
}

func (m orderedMap) String() string {
var output []string
for _, val := range m {
output = append(output, fmt.Sprintf("%s:%v", val.key, val.val))
}
return strings.Join(output, " ")
}
2 changes: 2 additions & 0 deletions go/vt/vtgate/planbuilder/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type dmlWithInput struct {
dmls []logicalPlan

outputCols [][]int
bvList []map[string]int
}

var _ logicalPlan = (*dmlWithInput)(nil)
Expand All @@ -40,5 +41,6 @@ func (d *dmlWithInput) Primitive() engine.Primitive {
DMLs: dels,
Input: inp,
OutputCols: d.outputCols,
BVList: d.bvList,
}
}
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func transformDMLWithInput(ctx *plancontext.PlanningContext, op *operators.DMLWi
input: input,
dmls: dmls,
outputCols: op.Offsets,
bvList: op.BvList,
}, nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ func splitGroupingToLeftAndRight(
rhs.addGrouping(ctx, groupBy)
columns.addRight(groupBy.Inner)
case deps.IsSolvedBy(lhs.tableID.Merge(rhs.tableID)):
jc := breakExpressionInLHSandRHSForApplyJoin(ctx, groupBy.Inner, lhs.tableID)
jc := breakExpressionInLHSandRHS(ctx, groupBy.Inner, lhs.tableID)
for _, lhsExpr := range jc.LHSExprs {
e := lhsExpr.Expr
lhs.addGrouping(ctx, NewGroupBy(e))
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql
rhs := aj.RHS
predicates := sqlparser.SplitAndExpression(nil, expr)
for _, pred := range predicates {
col := breakExpressionInLHSandRHSForApplyJoin(ctx, pred, TableID(aj.LHS))
col := breakExpressionInLHSandRHS(ctx, pred, TableID(aj.LHS))
aj.JoinPredicates.add(col)
ctx.AddJoinPredicates(pred, col.RHSExpr)
rhs = rhs.AddPredicate(ctx, col.RHSExpr)
Expand Down Expand Up @@ -202,7 +202,7 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq
case deps.IsSolvedBy(rhs):
col.RHSExpr = e
case deps.IsSolvedBy(both):
col = breakExpressionInLHSandRHSForApplyJoin(ctx, e, TableID(aj.LHS))
col = breakExpressionInLHSandRHS(ctx, e, TableID(aj.LHS))
default:
panic(vterrors.VT13002(sqlparser.String(e)))
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/ast_to_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (jpc *joinPredicateCollector) inspectPredicate(
// then we can use this predicate to connect the subquery to the outer query
if !deps.IsSolvedBy(jpc.subqID) && deps.IsSolvedBy(jpc.totalID) {
jpc.predicates = append(jpc.predicates, predicate)
jc := breakExpressionInLHSandRHSForApplyJoin(ctx, predicate, jpc.outerID)
jc := breakExpressionInLHSandRHS(ctx, predicate, jpc.outerID)
jpc.joinColumns = append(jpc.joinColumns, jc)
pred = jc.RHSExpr
}
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/planbuilder/operators/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ func createDeleteOpWithTarget(ctx *plancontext.PlanningContext, target semantics
Where: sqlparser.NewWhere(sqlparser.WhereClause, compExpr),
}
return dmlOp{
createOperatorFromDelete(ctx, del),
vTbl,
cols,
op: createOperatorFromDelete(ctx, del),
vTbl: vTbl,
cols: cols,
}
}

Expand Down
9 changes: 5 additions & 4 deletions go/vt/vtgate/planbuilder/operators/dml_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ type TargetTable struct {
Name sqlparser.TableName
}

// dmlOp stores intermediary value for Update/Delete Operator with the vindexes.Table for ordering.
// dmlOp stores intermediary value for Update/Delete Operator with the vindexes. Table for ordering.
type dmlOp struct {
op Operator
vTbl *vindexes.Table
cols []*sqlparser.ColName
op Operator
vTbl *vindexes.Table
cols []*sqlparser.ColName
updList updList
}

// sortDmlOps sort the operator based on sharding vindex type.
Expand Down
20 changes: 20 additions & 0 deletions go/vt/vtgate/planbuilder/operators/dml_with_input.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ type DMLWithInput struct {
cols [][]*sqlparser.ColName
Offsets [][]int

updList []updList
BvList []map[string]int

noColumns
noPredicates
}
Expand Down Expand Up @@ -86,6 +89,7 @@ func (d *DMLWithInput) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy {
}

func (d *DMLWithInput) planOffsets(ctx *plancontext.PlanningContext) Operator {
// go through the primary key columns to get offset from the input
offsets := make([][]int, len(d.cols))
for idx, columns := range d.cols {
for _, col := range columns {
Expand All @@ -94,6 +98,22 @@ func (d *DMLWithInput) planOffsets(ctx *plancontext.PlanningContext) Operator {
}
}
d.Offsets = offsets

// go through the update list and get offset for input columns
bvList := make([]map[string]int, len(d.updList))
for idx, ul := range d.updList {
vars := make(map[string]int)
for _, updCol := range ul {
for _, bvExpr := range updCol.jc.LHSExprs {
offset := d.Source.AddColumn(ctx, true, false, aeWrap(bvExpr.Expr))
vars[bvExpr.Name] = offset
}
}
if len(vars) > 0 {
bvList[idx] = vars
}
}
d.BvList = bvList
return d
}

Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
"vitess.io/vitess/go/vt/vtgate/semantics"
)

// breakExpressionInLHSandRHSForApplyJoin takes an expression and
// breakExpressionInLHSandRHS takes an expression and
// extracts the parts that are coming from one of the sides into `ColName`s that are needed
func breakExpressionInLHSandRHSForApplyJoin(
func breakExpressionInLHSandRHS(
ctx *plancontext.PlanningContext,
expr sqlparser.Expr,
lhs semantics.TableSet,
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (sq *SubQuery) GetJoinColumns(ctx *plancontext.PlanningContext, outer Opera
}
sq.outerID = outerID
mapper := func(in sqlparser.Expr) (applyJoinColumn, error) {
return breakExpressionInLHSandRHSForApplyJoin(ctx, in, outerID), nil
return breakExpressionInLHSandRHS(ctx, in, outerID), nil
}
joinPredicates, err := slice.MapWithError(sq.Predicates, mapper)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ func extractLHSExpr(
lhs semantics.TableSet,
) func(expr sqlparser.Expr) sqlparser.Expr {
return func(expr sqlparser.Expr) sqlparser.Expr {
col := breakExpressionInLHSandRHSForApplyJoin(ctx, expr, lhs)
col := breakExpressionInLHSandRHS(ctx, expr, lhs)
if col.IsPureLeft() {
panic(vterrors.VT13001("did not expect to find any predicates that do not need data from the inner here"))
}
Expand Down
Loading
Loading