diff --git a/go/test/endtoend/vtgate/queries/dml/insert_test.go b/go/test/endtoend/vtgate/queries/dml/insert_test.go index 1d09d3aab51..242b2813efc 100644 --- a/go/test/endtoend/vtgate/queries/dml/insert_test.go +++ b/go/test/endtoend/vtgate/queries/dml/insert_test.go @@ -54,6 +54,27 @@ func TestSimpleInsertSelect(t *testing.T) { utils.AssertMatches(t, mcmp.VtConn, `select num from num_vdx_tbl order by num`, `[[INT64(2)] [INT64(4)] [INT64(40)] [INT64(42)] [INT64(80)] [INT64(84)]]`) } +// TestInsertOnDup test the insert on duplicate key update feature with argument and list argument. +func TestInsertOnDup(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec("insert into order_tbl(oid, region_id, cust_no) values (1,2,3),(3,4,5)") + + for _, mode := range []string{"oltp", "olap"} { + mcmp.Run(mode, func(mcmp *utils.MySQLCompare) { + utils.Exec(t, mcmp.VtConn, fmt.Sprintf("set workload = %s", mode)) + + mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (2,2,3),(4,4,5) on duplicate key update cust_no = if(values(cust_no) in (1, 2, 3), region_id, values(cust_no))`) + mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`) + mcmp.Exec(`insert into order_tbl(oid, region_id, cust_no) values (7,2,2) on duplicate key update cust_no = 10 + values(cust_no)`) + mcmp.Exec(`select oid, region_id, cust_no from order_tbl order by oid, region_id`) + }) + } +} + func TestFailureInsertSelect(t *testing.T) { if clusterInstance.HasPartialKeyspaces { t.Skip("don't run on partial keyspaces") diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 58580a78701..009d138908e 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -379,6 +379,27 @@ func TestNormalize(t *testing.T) { "v1": sqltypes.HexValBindVariable([]byte("x'31'")), "v2": sqltypes.Int64BindVariable(31), }, +<<<<<<< HEAD +======= + }, { + // ORDER BY and GROUP BY variable + in: "select a, b from t group by 1, field(a,1,2,3) order by 1 asc, field(a,1,2,3)", + outstmt: "select a, b from t group by 1, field(a, :bv1 /* INT64 */, :bv2 /* INT64 */, :bv3 /* INT64 */) order by 1 asc, field(a, :bv1 /* INT64 */, :bv2 /* INT64 */, :bv3 /* INT64 */) asc", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + "bv2": sqltypes.Int64BindVariable(2), + "bv3": sqltypes.Int64BindVariable(3), + }, + }, { + // list in on duplicate key update + in: "insert into t(a, b) values (1, 2) on duplicate key update b = if(values(b) in (1, 2), b, values(b))", + outstmt: "insert into t(a, b) values (:bv1 /* INT64 */, :bv2 /* INT64 */) on duplicate key update b = if(values(b) in ::bv3, b, values(b))", + outbv: map[string]*querypb.BindVariable{ + "bv1": sqltypes.Int64BindVariable(1), + "bv2": sqltypes.Int64BindVariable(2), + "bv3": sqltypes.TestBindVariable([]any{1, 2}), + }, +>>>>>>> 3b1800db4b (fix: insert on duplicate update to add list argument in the bind variables map (#15961)) }} for _, tc := range testcases { t.Run(tc.in, func(t *testing.T) { diff --git a/go/vt/vtgate/engine/insert.go b/go/vt/vtgate/engine/insert.go index 83f07983f23..18dd6c1255a 100644 --- a/go/vt/vtgate/engine/insert.go +++ b/go/vt/vtgate/engine/insert.go @@ -765,13 +765,20 @@ func (ins *Insert) getInsertShardedRoute( index, _ := strconv.ParseInt(string(indexValue.Value), 0, 64) if keyspaceIDs[index] != nil { walkFunc := func(node sqlparser.SQLNode) (kontinue bool, err error) { - if arg, ok := node.(*sqlparser.Argument); ok { - bv, exists := bindVars[arg.Name] - if !exists { - return false, vterrors.VT03026(arg.Name) - } - shardBindVars[arg.Name] = bv + var arg string + switch argType := node.(type) { + case *sqlparser.Argument: + arg = argType.Name + case sqlparser.ListArg: + arg = string(argType) + default: + return true, nil } + bv, exists := bindVars[arg] + if !exists { + return false, vterrors.VT03026(arg) + } + shardBindVars[arg] = bv return true, nil } mids = append(mids, sqlparser.String(ins.Mid[index])) diff --git a/go/vt/vtgate/engine/insert_test.go b/go/vt/vtgate/engine/insert_test.go index 00b470f399b..79b4af35522 100644 --- a/go/vt/vtgate/engine/insert_test.go +++ b/go/vt/vtgate/engine/insert_test.go @@ -232,7 +232,167 @@ func TestInsertShardedSimple(t *testing.T) { }) // Multiple rows are not autocommitted by default +<<<<<<< HEAD ins = NewInsert( +======= + ins = newInsert( + InsertSharded, + false, + ks.Keyspace, + [][][]evalengine.Expr{{ + // colVindex columns: id + // 3 rows. + { + evalengine.NewLiteralInt(1), + evalengine.NewLiteralInt(2), + evalengine.NewLiteralInt(3), + }, + }}, + ks.Tables["t1"], + "prefix", + sqlparser.Values{ + {&sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, + {&sqlparser.Argument{Name: "_id_1", Type: sqltypes.Int64}}, + {&sqlparser.Argument{Name: "_id_2", Type: sqltypes.Int64}}, + }, + nil, + ) + vc = newDMLTestVCursor("-20", "20-") + vc.shardForKsid = []string{"20-", "-20", "20-"} + + _, err = ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + if err != nil { + t.Fatal(err) + } + vc.ExpectLog(t, []string{ + // Based on shardForKsid, values returned will be 20-, -20, 20-. + `ResolveDestinations sharded [value:"0" value:"1" value:"2"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6),DestinationKeyspaceID(06e7ea22ce92708f),DestinationKeyspaceID(4eb190c9a2fa169c)`, + // Row 2 will go to -20, rows 1 & 3 will go to 20- + `ExecuteMultiShard ` + + `sharded.20-: prefix(:_id_0 /* INT64 */),(:_id_2 /* INT64 */) {_id_0: type:INT64 value:"1" _id_2: type:INT64 value:"3"} ` + + `sharded.-20: prefix(:_id_1 /* INT64 */) {_id_1: type:INT64 value:"2"} ` + + `true false`, + }) + + // Optional flag overrides autocommit + ins = newInsert( + InsertSharded, + false, + ks.Keyspace, + [][][]evalengine.Expr{{ + // colVindex columns: id + // 3 rows. + { + evalengine.NewLiteralInt(1), + evalengine.NewLiteralInt(2), + evalengine.NewLiteralInt(3), + }, + }}, + + ks.Tables["t1"], + "prefix", + sqlparser.Values{ + {&sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, + {&sqlparser.Argument{Name: "_id_1", Type: sqltypes.Int64}}, + {&sqlparser.Argument{Name: "_id_2", Type: sqltypes.Int64}}, + }, + nil, + ) + ins.MultiShardAutocommit = true + + vc = newDMLTestVCursor("-20", "20-") + vc.shardForKsid = []string{"20-", "-20", "20-"} + + _, err = ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, false) + if err != nil { + t.Fatal(err) + } + vc.ExpectLog(t, []string{ + // Based on shardForKsid, values returned will be 20-, -20, 20-. + `ResolveDestinations sharded [value:"0" value:"1" value:"2"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6),DestinationKeyspaceID(06e7ea22ce92708f),DestinationKeyspaceID(4eb190c9a2fa169c)`, + // Row 2 will go to -20, rows 1 & 3 will go to 20- + `ExecuteMultiShard ` + + `sharded.20-: prefix(:_id_0 /* INT64 */),(:_id_2 /* INT64 */) {_id_0: type:INT64 value:"1" _id_2: type:INT64 value:"3"} ` + + `sharded.-20: prefix(:_id_1 /* INT64 */) {_id_1: type:INT64 value:"2"} ` + + `true true`, + }) +} + +func TestInsertShardWithONDuplicateKey(t *testing.T) { + invschema := &vschemapb.SrvVSchema{ + Keyspaces: map[string]*vschemapb.Keyspace{ + "sharded": { + Sharded: true, + Vindexes: map[string]*vschemapb.Vindex{ + "hash": { + Type: "hash", + }, + }, + Tables: map[string]*vschemapb.Table{ + "t1": { + ColumnVindexes: []*vschemapb.ColumnVindex{{ + Name: "hash", + Columns: []string{"id"}, + }}, + }, + }, + }, + }, + } + vs := vindexes.BuildVSchema(invschema, sqlparser.NewTestParser()) + ks := vs.Keyspaces["sharded"] + + // A single row insert should be autocommitted + ins := newInsert( + InsertSharded, + false, + ks.Keyspace, + [][][]evalengine.Expr{{ + // colVindex columns: id + { + evalengine.NewLiteralInt(1), + }, + }}, + ks.Tables["t1"], + "prefix", + sqlparser.Values{ + {&sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, + }, + sqlparser.OnDup{ + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix1"), Expr: &sqlparser.Argument{Name: "_id_0", Type: sqltypes.Int64}}, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("suffix2"), Expr: &sqlparser.FuncExpr{ + Name: sqlparser.NewIdentifierCI("if"), + Exprs: sqlparser.Exprs{ + sqlparser.NewComparisonExpr(sqlparser.InOp, &sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")}, sqlparser.ListArg("_id_1"), nil), + sqlparser.NewColName("col"), + &sqlparser.ValuesFuncExpr{Name: sqlparser.NewColName("col")}, + }, + }}}, + ) + vc := newDMLTestVCursor("-20", "20-") + vc.shardForKsid = []string{"20-", "-20", "20-"} + + _, err := ins.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{ + "_id_1": sqltypes.TestBindVariable([]int{1, 2}), + }, false) + if err != nil { + t.Fatal(err) + } + vc.ExpectLog(t, []string{ + // Based on shardForKsid, values returned will be 20-. + `ResolveDestinations sharded [value:"0"] Destinations:DestinationKeyspaceID(166b40b44aba4bd6)`, + // Row 2 will go to -20, rows 1 & 3 will go to 20- + `ExecuteMultiShard ` + + `sharded.20-: prefix(:_id_0 /* INT64 */) on duplicate key update ` + + `suffix1 = :_id_0 /* INT64 */, suffix2 = if(values(col) in ::_id_1, col, values(col)) ` + + `{_id_0: type:INT64 value:"1" ` + + `_id_1: type:TUPLE values:{type:INT64 value:"1"} values:{type:INT64 value:"2"}} ` + + `true true`, + }) + + // Multiple rows are not autocommitted by default + ins = newInsert( +>>>>>>> 3b1800db4b (fix: insert on duplicate update to add list argument in the bind variables map (#15961)) InsertSharded, false, ks.Keyspace,