diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go index 1d09d3aab51..242b2813efc 100644 --- a/go/test/endtoend/vtgate/queries/dml/insert_test.go +++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go @@ -54,6 +54,27 @@ func TestSimpleInsertSelect(t *testing.T) { utils.AssertMatches(t, mcmp.VtConn, `select num from num_vdx_tbl order by num`, `[[INT64(2)] [INT64(4)] [INT64(40)] [INT64(42)] [INT64(80)] [INT64(84)]]`) } +// TestInsertOnDup test the insert on duplicate key update feature with argument and list argument. +func TestInsertOnDup(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into order_tbl(oid, region_id, cust_no) values (1,2,3),(3,4,5)") + + for _, mode := range []string{"oltp", "olap"} { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { + utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) + + mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (2,2,3),(4,4,5) on duplicate key update cust_no = if(values(cust_no) in (1, 2, 3), region_id, values(cust_no))`) + mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`) + mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (7,2,2) on duplicate key update cust_no = 10 + values(cust_no)`) + mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`) + }) + } +} + func TestFailureInsertSelect(t *testing.T) { if clusterInstance.HasPartialKeyspaces { t.Skip("don't run on partial keyspaces") diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 58580a78701..58dbfed440c 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -379,6 +379,15 @@ func TestNormalize(t *testing.T) { "v1": sqltypes.HexValBindVariable([]byte("x'31'")), "v2": sqltypes.Int64BindVariable(31), }, + }, { + // list in on duplicate key update + in: "insert into t(a, b) values (1, 2) on duplicate key update b = if(values(b) in (1, 2), b, values(b))", + outstmt: "insert into t(a, b) values (:bv1 /* INT64 */, :bv2 /* INT64 */) on duplicate key update b = if(values(b) in ::bv3, b, values(b))", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + "bv2": sqltypes.Int64BindVariable(2), + "bv3": sqltypes.TestBindVariable([]any{1, 2}), + }, }} for _, tc := range testcases { t.Run(tc.in, func(t *testing.T) { diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 83f07983f23..18dd6c1255a 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -765,13 +765,20 @@ func (ins *Insert) getInsertShardedRoute( index, _ := strconv.ParseInt(string(indexValue.Value), 0, 64) if keyspaceIDs[index] != nil { walkFunc := func(node sqlparser.SQLNode) (kontinue bool, err error) { - if arg, ok := node.(*sqlparser.Argument); ok { - bv, exists := bindVars[arg.Name] - if !exists { - return false, vterrors.VT03026(arg.Name) - } - shardBindVars[arg.Name] = bv + var arg string + switch argType := node.(type) { + case *sqlparser.Argument: + arg = argType.Name + case sqlparser.ListArg: + arg = string(argType) + default: + return true, nil } + bv, exists := bindVars[arg] + if !exists { + return false, vterrors.VT03026(arg) + } + shardBindVars[arg] = bv return true, nil } mids = append(mids, sqlparser.String(ins.Mid[index])) diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index 00b470f399b..5846f4ca97c 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -360,13 +360,22 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) { {&sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, }, sqlparser.OnDup{ - &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, - }, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix1"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix2"), Expr: &sqlparser.FuncExpr{ + Name: sqlparser.NewIdentifierCI("if"), + Exprs: sqlparser.SelectExprs{ + sqlparser.NewAliasedExpr(sqlparser.NewComparisonExpr(sqlparser.InOp, &sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")}, sqlparser.ListArg("_id_1"), nil), ""), + sqlparser.NewAliasedExpr(sqlparser.NewColName("col"), ""), + sqlparser.NewAliasedExpr(&sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")}, ""), + }, + }}}, ) vc := newDMLTestVCursor("-20", "20-") vc.shardForKsid = []string{"20-", "-20", "20-"} - _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{ + "_id_1": sqltypes.TestBindVariable([]int{1, 2}), + }, false) if err != nil { t.Fatal(err) } @@ -375,7 +384,10 @@ func TestInsertShardWithONDuplicateKey(t *testing.T) { `ResolveDestinations sharded [value:"0"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6)`, // Row 2 will go to -20, rows 1 & 3 will go to 20- `ExecuteMultiShard ` + - `sharded.20-: prefix(:_id_0 /* INT64 */) on duplicate key update suffix = :_id_0 /* INT64 */ {_id_0: type:INT64 value:"1"} ` + + `sharded.20-: prefix(:_id_0 /* INT64 */) on duplicate key update ` + + `suffix1 = :_id_0 /* INT64 */, suffix2 = if(values(col) in ::_id_1, col, values(col)) ` + + `{_id_0: type:INT64 value:"1" ` + + `_id_1: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"}} ` + `true true`, })