diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index da2e14218fe..05a45abac69 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -40,7 +40,21 @@ func start(t *testing.T) (utils.MySQLCompare, func()) { deleteAll := func() { _, _ = utils.ExecAllowError(t, mcmp.VtConn, "set workload = oltp") - tables := []string{"t9", "aggr_test", "t3", "t7_xxhash", "aggr_test_dates", "t7_xxhash_idx", "t1", "t2", "t10"} + tables := []string{ + "t3", + "t3_id7_idx", + "t9", + "aggr_test", + "aggr_test_dates", + "t7_xxhash", + "t7_xxhash_idx", + "t1", + "t2", + "t10", + "emp", + "dept", + "bet_logs", + } for _, table := range tables { _, _ = mcmp.ExecAndIgnore("delete from " + table) } @@ -673,3 +687,84 @@ func TestDistinctAggregation(t *testing.T) { }) } } + +func TestHavingQueries(t *testing.T) { + mcmp, closer := start(t) + defer closer() + + inserts := []string{ + `INSERT INTO emp (empno, ename, job, mgr, hiredate, sal, comm, deptno) VALUES + (1, 'John', 'Manager', NULL, '2022-01-01', 5000, 500, 1), + (2, 'Doe', 'Analyst', 1, '2023-01-01', 4500, NULL, 1), + (3, 'Jane', 'Clerk', 1, '2023-02-01', 3000, 200, 2), + (4, 'Mary', 'Analyst', 2, '2022-03-01', 4700, NULL, 1), + (5, 'Smith', 'Salesman', 3, '2023-01-15', 3200, 300, 3)`, + "INSERT INTO dept (deptno, dname, loc) VALUES (1, 'IT', 'New York'), (2, 'HR', 'London'), (3, 'Sales', 'San Francisco')", + "INSERT INTO t1 (t1_id, name, value, shardKey) VALUES (1, 'Name1', 'Value1', 100), (2, 'Name2', 'Value2', 100), (3, 'Name1', 'Value3', 200)", + "INSERT INTO aggr_test_dates (id, val1, val2) VALUES (1, '2023-01-01', '2023-01-02'), (2, '2023-02-01', '2023-02-02'), (3, '2023-03-01', '2023-03-02')", + "INSERT INTO t10 (k, a, b) VALUES (1, 10, 20), (2, 30, 40), (3, 50, 60)", + "INSERT INTO t3 (id5, id6, id7) VALUES (1, 10, 100), (2, 20, 200), (3, 30, 300)", + "INSERT INTO t9 (id1, id2, id3) VALUES (1, 'A1', 'B1'), (2, 'A2', 'B2'), (3, 'A1', 'B3')", + "INSERT INTO aggr_test (id, val1, val2) VALUES (1, 'Test1', 100), (2, 'Test2', 200), (3, 'Test1', 300), (4, 'Test3', 400)", + "INSERT INTO t2 (id, shardKey) VALUES (1, 100), (2, 200), (3, 300)", + `INSERT INTO bet_logs (id, merchant_game_id, bet_amount, game_id) VALUES + (1, 1, 100.0, 10), + (2, 1, 200.0, 11), + (3, 2, 300.0, 10), + (4, 3, 400.0, 12)`, + } + + for _, insert := range inserts { + mcmp.Exec(insert) + } + + queries := []string{ + // The following queries are not allowed by MySQL but Vitess allows them + // SELECT ename FROM emp GROUP BY ename HAVING sal > 5000 + // SELECT val1, COUNT(val2) FROM aggr_test_dates GROUP BY val1 HAVING val2 > 5 + // SELECT k, a FROM t10 GROUP BY k HAVING b > 2 + // SELECT loc FROM dept GROUP BY loc HAVING COUNT(deptno) AND dname = 'Sales' + // SELECT AVG(val2) AS average_val2 FROM aggr_test HAVING val1 = 'Test' + + // these first queries are all failing in different ways. let's check that Vitess also fails + + "SELECT deptno, AVG(sal) AS average_salary HAVING average_salary > 5000 FROM emp", + "SELECT job, COUNT(empno) AS num_employees FROM emp HAVING num_employees > 2", + "SELECT dname, SUM(sal) FROM dept JOIN emp ON dept.deptno = emp.deptno HAVING AVG(sal) > 6000", + "SELECT COUNT(*) AS count FROM emp WHERE count > 5", + "SELECT `name`, AVG(`value`) FROM t1 GROUP BY `name` HAVING `name`", + "SELECT empno, MAX(sal) FROM emp HAVING COUNT(*) > 3", + "SELECT id, SUM(bet_amount) AS total_bets FROM bet_logs HAVING total_bets > 1000", + "SELECT merchant_game_id FROM bet_logs GROUP BY merchant_game_id HAVING SUM(bet_amount)", + "SELECT shardKey, COUNT(id) FROM t2 HAVING shardKey > 100", + "SELECT deptno FROM emp GROUP BY deptno HAVING MAX(hiredate) > '2020-01-01'", + + // These queries should not fail + "SELECT deptno, COUNT(*) AS num_employees FROM emp GROUP BY deptno HAVING num_employees > 5", + "SELECT ename, SUM(sal) FROM emp GROUP BY ename HAVING SUM(sal) > 10000", + "SELECT dname, AVG(sal) AS average_salary FROM emp JOIN dept ON emp.deptno = dept.deptno GROUP BY dname HAVING average_salary > 5000", + "SELECT dname, MAX(sal) AS max_salary FROM emp JOIN dept ON emp.deptno = dept.deptno GROUP BY dname HAVING max_salary < 10000", + "SELECT YEAR(hiredate) AS year, COUNT(*) FROM emp GROUP BY year HAVING COUNT(*) > 2", + "SELECT mgr, COUNT(empno) AS managed_employees FROM emp WHERE mgr IS NOT NULL GROUP BY mgr HAVING managed_employees >= 3", + "SELECT deptno, SUM(comm) AS total_comm FROM emp GROUP BY deptno HAVING total_comm > AVG(total_comm)", + "SELECT id2, COUNT(*) AS count FROM t9 GROUP BY id2 HAVING count > 1", + "SELECT val1, COUNT(*) FROM aggr_test GROUP BY val1 HAVING COUNT(*) > 1", + "SELECT DATE(val1) AS date, SUM(val2) FROM aggr_test_dates GROUP BY date HAVING SUM(val2) > 100", + "SELECT shardKey, AVG(`value`) FROM t1 WHERE `value` IS NOT NULL GROUP BY shardKey HAVING AVG(`value`) > 10", + "SELECT job, COUNT(*) AS job_count FROM emp GROUP BY job HAVING job_count > 3", + "SELECT b, AVG(a) AS avg_a FROM t10 GROUP BY b HAVING AVG(a) > 5", + "SELECT merchant_game_id, SUM(bet_amount) AS total_bets FROM bet_logs GROUP BY merchant_game_id HAVING total_bets > 1000", + "SELECT loc, COUNT(deptno) AS num_depts FROM dept GROUP BY loc HAVING num_depts > 1", + "SELECT `name`, COUNT(*) AS name_count FROM t1 GROUP BY `name` HAVING name_count > 2", + "SELECT COUNT(*) AS num_jobs FROM emp GROUP BY empno HAVING num_jobs > 1", + "SELECT id, COUNT(*) AS count FROM t2 GROUP BY id HAVING count > 1", + "SELECT val2, SUM(id) FROM aggr_test GROUP BY val2 HAVING SUM(id) > 10", + "SELECT game_id, COUNT(*) AS num_logs FROM bet_logs GROUP BY game_id HAVING num_logs > 5", + } + + for _, query := range queries { + mcmp.Run(query, func(mcmp *utils.MySQLCompare) { + mcmp.ExecAllowAndCompareError(query) + }) + } +} diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index d73445e4160..8f3e436deb8 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -1932,7 +1932,7 @@ func TestSelectScatterOrderBy(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, col2, weight_string(col2) from `user` order by col2 desc", + Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -2005,7 +2005,7 @@ func TestSelectScatterOrderByVarChar(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, textcol, weight_string(textcol) from `user` order by textcol desc", + Sql: "select col1, textcol, weight_string(textcol) from `user` order by `user`.textcol desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -2071,7 +2071,7 @@ func TestStreamSelectScatterOrderBy(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select id, col, weight_string(col) from `user` order by col desc", + Sql: "select id, col, weight_string(col) from `user` order by `user`.col desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -2133,7 +2133,7 @@ func TestStreamSelectScatterOrderByVarChar(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select id, textcol, weight_string(textcol) from `user` order by textcol desc", + Sql: "select id, textcol, weight_string(textcol) from `user` order by `user`.textcol desc", BindVariables: map[string]*querypb.BindVariable{}, }} for _, conn := range conns { @@ -2329,7 +2329,7 @@ func TestSelectScatterLimit(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, col2, weight_string(col2) from `user` order by col2 desc limit :__upper_limit", + Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit :__upper_limit", BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}, }} for _, conn := range conns { @@ -2401,7 +2401,7 @@ func TestStreamSelectScatterLimit(t *testing.T) { require.NoError(t, err) wantQueries := []*querypb.BoundQuery{{ - Sql: "select col1, col2, weight_string(col2) from `user` order by col2 desc limit :__upper_limit", + Sql: "select col1, col2, weight_string(col2) from `user` order by `user`.col2 desc limit :__upper_limit", BindVariables: map[string]*querypb.BindVariable{"__upper_limit": sqltypes.Int64BindVariable(3)}, }} for _, conn := range conns { diff --git a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json index 5acbfbe61bc..e98f53bb4cf 100644 --- a/go/vt/vtgate/planbuilder/testdata/aggr_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/aggr_cases.json @@ -675,10 +675,15 @@ } }, { - "comment": "scatter aggregate group by aggregate function", + "comment": "scatter aggregate group by aggregate function - since we don't have authoratative columns for user, we can't be sure that the user isn't referring a column named b", "query": "select count(*) b from user group by b", "plan": "VT03005: cannot group on 'count(*)'" }, + { + "comment": "scatter aggregate group by aggregate function with column information", + "query": "select count(*) b from authoritative group by b", + "plan": "VT03005: cannot group on 'b'" + }, { "comment": "scatter aggregate multiple group by (columns)", "query": "select a, b, count(*) from user group by a, b", @@ -893,7 +898,7 @@ }, "FieldQuery": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` where 1 != 1 group by a, b, c, weight_string(a), weight_string(b), weight_string(c)", "OrderBy": "(0|5) ASC, (1|6) ASC, (2|7) ASC", - "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, b, c, weight_string(a), weight_string(b), weight_string(c) order by a asc, b asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(a), weight_string(b), weight_string(c) from `user` group by a, b, c, weight_string(a), weight_string(b), weight_string(c) order by `user`.a asc, `user`.b asc, `user`.c asc", "Table": "`user`" } ] @@ -925,7 +930,7 @@ }, "FieldQuery": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` where 1 != 1 group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c)", "OrderBy": "(3|5) ASC, (1|6) ASC, (0|7) ASC, (2|8) ASC", - "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by d asc, b asc, a asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by `user`.d asc, `user`.b asc, `user`.a asc, `user`.c asc", "Table": "`user`" } ] @@ -957,7 +962,7 @@ }, "FieldQuery": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` where 1 != 1 group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c)", "OrderBy": "(3|5) ASC, (1|6) ASC, (0|7) ASC, (2|8) ASC", - "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by d asc, b asc, a asc, c asc", + "Query": "select a, b, c, d, count(*), weight_string(d), weight_string(b), weight_string(a), weight_string(c) from `user` group by d, b, a, c, weight_string(d), weight_string(b), weight_string(a), weight_string(c) order by `user`.d asc, `user`.b asc, `user`.a asc, `user`.c asc", "Table": "`user`" } ] @@ -989,7 +994,7 @@ }, "FieldQuery": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` where 1 != 1 group by a, c, b, weight_string(a), weight_string(c), weight_string(b)", "OrderBy": "(0|4) DESC, (2|5) DESC, (1|6) ASC", - "Query": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` group by a, c, b, weight_string(a), weight_string(c), weight_string(b) order by a desc, c desc, b asc", + "Query": "select a, b, c, count(*), weight_string(a), weight_string(c), weight_string(b) from `user` group by a, c, b, weight_string(a), weight_string(c), weight_string(b) order by a desc, c desc, `user`.b asc", "Table": "`user`" } ] @@ -1041,32 +1046,6 @@ ] } }, - { - "comment": "Group by with collate operator", - "query": "select user.col1 as a from user where user.id = 5 group by a collate utf8_general_ci", - "plan": { - "QueryType": "SELECT", - "Original": "select user.col1 as a from user where user.id = 5 group by a collate utf8_general_ci", - "Instructions": { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col1 as a from `user` where 1 != 1 group by `user`.col1 collate utf8_general_ci", - "Query": "select `user`.col1 as a from `user` where `user`.id = 5 group by `user`.col1 collate utf8_general_ci", - "Table": "`user`", - "Values": [ - "5" - ], - "Vindex": "user_index" - }, - "TablesUsed": [ - "user.user" - ] - } - }, { "comment": "routing rules for aggregates", "query": "select id, count(*) from route2 group by id", @@ -1103,7 +1082,7 @@ "Sharded": true }, "FieldQuery": "select col from ref where 1 != 1", - "Query": "select col from ref order by col asc", + "Query": "select col from ref order by ref.col asc", "Table": "ref" }, "TablesUsed": [ @@ -1584,10 +1563,10 @@ }, { "comment": "weight_string addition to group by", - "query": "select lower(textcol1) as v, count(*) from user group by v", + "query": "select lower(col1) as v, count(*) from authoritative group by v", "plan": { "QueryType": "SELECT", - "Original": "select lower(textcol1) as v, count(*) from user group by v", + "Original": "select lower(col1) as v, count(*) from authoritative group by v", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -1602,24 +1581,24 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select lower(textcol1) as v, count(*), weight_string(lower(textcol1)) from `user` where 1 != 1 group by lower(textcol1), weight_string(lower(textcol1))", + "FieldQuery": "select lower(col1) as v, count(*), weight_string(lower(col1)) from authoritative where 1 != 1 group by lower(col1), weight_string(lower(col1))", "OrderBy": "(0|2) ASC", - "Query": "select lower(textcol1) as v, count(*), weight_string(lower(textcol1)) from `user` group by lower(textcol1), weight_string(lower(textcol1)) order by lower(textcol1) asc", - "Table": "`user`" + "Query": "select lower(col1) as v, count(*), weight_string(lower(col1)) from authoritative group by lower(col1), weight_string(lower(col1)) order by lower(col1) asc", + "Table": "authoritative" } ] }, "TablesUsed": [ - "user.user" + "user.authoritative" ] } }, { "comment": "weight_string addition to group by when also there in order by", - "query": "select char_length(texcol1) as a, count(*) from user group by a order by a", + "query": "select char_length(col1) as a, count(*) from authoritative group by a order by a", "plan": { "QueryType": "SELECT", - "Original": "select char_length(texcol1) as a, count(*) from user group by a order by a", + "Original": "select char_length(col1) as a, count(*) from authoritative group by a order by a", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -1634,15 +1613,15 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select char_length(texcol1) as a, count(*), weight_string(char_length(texcol1)) from `user` where 1 != 1 group by char_length(texcol1), weight_string(char_length(texcol1))", + "FieldQuery": "select char_length(col1) as a, count(*), weight_string(char_length(col1)) from authoritative where 1 != 1 group by char_length(col1), weight_string(char_length(col1))", "OrderBy": "(0|2) ASC", - "Query": "select char_length(texcol1) as a, count(*), weight_string(char_length(texcol1)) from `user` group by char_length(texcol1), weight_string(char_length(texcol1)) order by char_length(texcol1) asc", - "Table": "`user`" + "Query": "select char_length(col1) as a, count(*), weight_string(char_length(col1)) from authoritative group by char_length(col1), weight_string(char_length(col1)) order by char_length(authoritative.col1) asc", + "Table": "authoritative" } ] }, "TablesUsed": [ - "user.user" + "user.authoritative" ] } }, @@ -1699,7 +1678,7 @@ }, "FieldQuery": "select col, id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(1|2) ASC", - "Query": "select col, id, weight_string(id) from `user` order by id asc", + "Query": "select col, id, weight_string(id) from `user` order by `user`.id asc", "ResultColumns": 2, "Table": "`user`" }, @@ -2009,19 +1988,20 @@ }, { "comment": "Less Equal filter on scatter with grouping", - "query": "select col, count(*) a from user group by col having a <= 10", + "query": "select col1, count(*) a from user group by col1 having a <= 10", "plan": { "QueryType": "SELECT", - "Original": "select col, count(*) a from user group by col having a <= 10", + "Original": "select col1, count(*) a from user group by col1 having a <= 10", "Instructions": { "OperatorType": "Filter", "Predicate": "count(*) <= 10", + "ResultColumns": 2, "Inputs": [ { "OperatorType": "Aggregate", "Variant": "Ordered", "Aggregates": "sum_count_star(1) AS a", - "GroupBy": "0", + "GroupBy": "(0|2)", "Inputs": [ { "OperatorType": "Route", @@ -2030,9 +2010,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select col, count(*) as a from `user` where 1 != 1 group by col", - "OrderBy": "0 ASC", - "Query": "select col, count(*) as a from `user` group by col order by col asc", + "FieldQuery": "select col1, count(*) as a, weight_string(col1) from `user` where 1 != 1 group by col1, weight_string(col1)", + "OrderBy": "(0|2) ASC", + "Query": "select col1, count(*) as a, weight_string(col1) from `user` group by col1, weight_string(col1) order by col1 asc", "Table": "`user`" } ] @@ -2046,10 +2026,10 @@ }, { "comment": "We should be able to find grouping keys on ordered aggregates", - "query": "select count(*) as a, val1 from user group by val1 having a = 1.00", + "query": "select count(*) as a, col2 from user group by col2 having a = 1.00", "plan": { "QueryType": "SELECT", - "Original": "select count(*) as a, val1 from user group by val1 having a = 1.00", + "Original": "select count(*) as a, col2 from user group by col2 having a = 1.00", "Instructions": { "OperatorType": "Filter", "Predicate": "count(*) = 1.00", @@ -2068,9 +2048,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select count(*) as a, val1, weight_string(val1) from `user` where 1 != 1 group by val1, weight_string(val1)", + "FieldQuery": "select count(*) as a, col2, weight_string(col2) from `user` where 1 != 1 group by col2, weight_string(col2)", "OrderBy": "(1|2) ASC", - "Query": "select count(*) as a, val1, weight_string(val1) from `user` group by val1, weight_string(val1) order by val1 asc", + "Query": "select count(*) as a, col2, weight_string(col2) from `user` group by col2, weight_string(col2) order by col2 asc", "Table": "`user`" } ] @@ -2620,10 +2600,10 @@ }, { "comment": "group by column alias", - "query": "select ascii(val1) as a, count(*) from user group by a", + "query": "select ascii(col2) as a, count(*) from user group by a", "plan": { "QueryType": "SELECT", - "Original": "select ascii(val1) as a, count(*) from user group by a", + "Original": "select ascii(col2) as a, count(*) from user group by a", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -2638,9 +2618,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select ascii(val1) as a, count(*), weight_string(ascii(val1)) from `user` where 1 != 1 group by ascii(val1), weight_string(ascii(val1))", + "FieldQuery": "select ascii(col2) as a, count(*), weight_string(ascii(col2)) from `user` where 1 != 1 group by ascii(col2), weight_string(ascii(col2))", "OrderBy": "(0|2) ASC", - "Query": "select ascii(val1) as a, count(*), weight_string(ascii(val1)) from `user` group by ascii(val1), weight_string(ascii(val1)) order by ascii(val1) asc", + "Query": "select ascii(col2) as a, count(*), weight_string(ascii(col2)) from `user` group by ascii(col2), weight_string(ascii(col2)) order by ascii(col2) asc", "Table": "`user`" } ] @@ -2984,7 +2964,7 @@ "Original": "select foo, sum(foo) as fooSum, sum(bar) as barSum from user group by foo having fooSum+sum(bar) = 42", "Instructions": { "OperatorType": "Filter", - "Predicate": "sum(foo) + sum(bar) = 42", + "Predicate": "sum(`user`.foo) + sum(bar) = 42", "ResultColumns": 3, "Inputs": [ { @@ -3328,10 +3308,10 @@ }, { "comment": "group by and ',' joins", - "query": "select user.id from user, user_extra group by id", + "query": "select user.id from user, user_extra group by user.id", "plan": { "QueryType": "SELECT", - "Original": "select user.id from user, user_extra group by id", + "Original": "select user.id from user, user_extra group by user.id", "Instructions": { "OperatorType": "Aggregate", "Variant": "Ordered", @@ -3555,7 +3535,7 @@ }, "FieldQuery": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", "OrderBy": "(1|3) ASC", - "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", + "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by `user`.val1 asc limit :__upper_limit", "Table": "`user`" } ] @@ -6938,5 +6918,45 @@ "user.user" ] } + }, + { + "comment": "col is a column on user, but the HAVING is referring to an alias", + "query": "select sum(x) col from user where x > 0 having col = 2", + "plan": { + "QueryType": "SELECT", + "Original": "select sum(x) col from user where x > 0 having col = 2", + "Instructions": { + "OperatorType": "Filter", + "Predicate": "sum(`user`.x) = 2", + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Scalar", + "Aggregates": "sum(0) AS col", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(x) as col from `user` where 1 != 1", + "Query": "select sum(x) as col from `user` where x > 0", + "Table": "`user`" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, + { + "comment": "baz in the HAVING clause can't be accessed because of the GROUP BY", + "query": "select foo, count(bar) as x from user group by foo having baz > avg(baz) order by x", + "plan": "Unknown column 'baz' in 'having clause'" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index a7027f80348..0d7d9020ac2 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -348,7 +348,7 @@ }, "FieldQuery": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where 1 != 1) as x where 1 != 1", "OrderBy": "(1|3) ASC", - "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by val1 asc limit :__upper_limit", + "Query": "select x.id, x.val1, 1, weight_string(x.val1) from (select id, val1 from `user` where val2 < 4) as x order by `user`.val1 asc limit :__upper_limit", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index 4353f31fd48..d144da2441d 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -4018,7 +4018,7 @@ "Sharded": true }, "FieldQuery": "select a + 2 as a from `user` where 1 != 1", - "Query": "select a + 2 as a from `user` where a + 2 = 42", + "Query": "select a + 2 as a from `user` where `user`.a + 2 = 42", "Table": "`user`" }, "TablesUsed": [ @@ -4041,7 +4041,7 @@ }, "FieldQuery": "select a + 2 as a, weight_string(a + 2) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select a + 2 as a, weight_string(a + 2) from `user` order by a + 2 asc", + "Query": "select a + 2 as a, weight_string(a + 2) from `user` order by `user`.a + 2 asc", "ResultColumns": 1, "Table": "`user`" }, diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json index 3ca7e1059e4..4a879997925 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json @@ -24,9 +24,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, weight_string(a)", + "FieldQuery": "select a, b, count(*), weight_string(a), weight_string(`user`.b) from `user` where 1 != 1 group by a, weight_string(a)", "OrderBy": "(0|3) ASC", - "Query": "select a, b, count(*), weight_string(a), weight_string(b) from `user` group by a, weight_string(a) order by a asc", + "Query": "select a, b, count(*), weight_string(a), weight_string(`user`.b) from `user` group by a, weight_string(a) order by a asc", "Table": "`user`" } ] @@ -102,9 +102,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, b, count(*) as k, weight_string(a), weight_string(b) from `user` where 1 != 1 group by a, weight_string(a)", + "FieldQuery": "select a, b, count(*) as k, weight_string(a), weight_string(`user`.b) from `user` where 1 != 1 group by a, weight_string(a)", "OrderBy": "(0|3) ASC", - "Query": "select a, b, count(*) as k, weight_string(a), weight_string(b) from `user` group by a, weight_string(a) order by a asc", + "Query": "select a, b, count(*) as k, weight_string(a), weight_string(`user`.b) from `user` group by a, weight_string(a) order by a asc", "Table": "`user`" } ] @@ -259,7 +259,7 @@ }, "FieldQuery": "select id, weight_string(id) from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t order by id asc", + "Query": "select id, weight_string(id) from (select `user`.id, `user`.col from `user`) as t order by t.id asc", "Table": "`user`" }, { @@ -552,9 +552,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select a, convert(a, binary), weight_string(convert(a, binary)) from `user` where 1 != 1", + "FieldQuery": "select a, convert(`user`.a, binary), weight_string(convert(`user`.a, binary)) from `user` where 1 != 1", "OrderBy": "(1|2) DESC", - "Query": "select a, convert(a, binary), weight_string(convert(a, binary)) from `user` order by convert(a, binary) desc", + "Query": "select a, convert(`user`.a, binary), weight_string(convert(`user`.a, binary)) from `user` order by convert(`user`.a, binary) desc", "ResultColumns": 1, "Table": "`user`" }, @@ -624,7 +624,7 @@ }, "FieldQuery": "select id, intcol from `user` where 1 != 1", "OrderBy": "1 ASC", - "Query": "select id, intcol from `user` order by intcol asc", + "Query": "select id, intcol from `user` order by `user`.intcol asc", "Table": "`user`" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/oltp_cases.json b/go/vt/vtgate/planbuilder/testdata/oltp_cases.json index 3af909415f9..45f1ac8c618 100644 --- a/go/vt/vtgate/planbuilder/testdata/oltp_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/oltp_cases.json @@ -91,7 +91,7 @@ }, "FieldQuery": "select c from sbtest1 where 1 != 1", "OrderBy": "0 ASC COLLATE latin1_swedish_ci", - "Query": "select c from sbtest1 where id between 50 and 235 order by c asc", + "Query": "select c from sbtest1 where id between 50 and 235 order by sbtest1.c asc", "Table": "sbtest1" }, "TablesUsed": [ @@ -119,7 +119,7 @@ }, "FieldQuery": "select c from sbtest30 where 1 != 1 group by c", "OrderBy": "0 ASC COLLATE latin1_swedish_ci", - "Query": "select c from sbtest30 where id between 1 and 10 group by c order by c asc", + "Query": "select c from sbtest30 where id between 1 and 10 group by c order by sbtest30.c asc", "Table": "sbtest30" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json index 0b0c0658175..53d9a136b23 100644 --- a/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/postprocess_cases.json @@ -145,7 +145,7 @@ "Sharded": true }, "FieldQuery": "select id from `user` where 1 != 1", - "Query": "select id from `user` where :__sq_has_values and id in ::__vals", + "Query": "select id from `user` where :__sq_has_values and `user`.id in ::__vals", "Table": "`user`", "Values": [ "::__sq1" @@ -226,7 +226,7 @@ }, "FieldQuery": "select col from `user` where 1 != 1", "OrderBy": "0 ASC", - "Query": "select col from `user` order by col asc", + "Query": "select col from `user` order by `user`.col asc", "Table": "`user`" }, "TablesUsed": [ @@ -249,7 +249,7 @@ }, "FieldQuery": "select user_id, col1, col2, weight_string(user_id) from authoritative where 1 != 1", "OrderBy": "(0|3) ASC", - "Query": "select user_id, col1, col2, weight_string(user_id) from authoritative order by user_id asc", + "Query": "select user_id, col1, col2, weight_string(user_id) from authoritative order by authoritative.user_id asc", "ResultColumns": 3, "Table": "authoritative" }, @@ -273,7 +273,7 @@ }, "FieldQuery": "select user_id, col1, col2 from authoritative where 1 != 1", "OrderBy": "1 ASC COLLATE latin1_swedish_ci", - "Query": "select user_id, col1, col2 from authoritative order by col1 asc", + "Query": "select user_id, col1, col2 from authoritative order by authoritative.col1 asc", "Table": "authoritative" }, "TablesUsed": [ @@ -296,7 +296,7 @@ }, "FieldQuery": "select a, textcol1, b, weight_string(a), weight_string(b) from `user` where 1 != 1", "OrderBy": "(0|3) ASC, 1 ASC COLLATE latin1_swedish_ci, (2|4) ASC", - "Query": "select a, textcol1, b, weight_string(a), weight_string(b) from `user` order by a asc, textcol1 asc, b asc", + "Query": "select a, textcol1, b, weight_string(a), weight_string(b) from `user` order by `user`.a asc, `user`.textcol1 asc, `user`.b asc", "ResultColumns": 3, "Table": "`user`" }, @@ -320,7 +320,7 @@ }, "FieldQuery": "select a, `user`.textcol1, b, weight_string(a), weight_string(b) from `user` where 1 != 1", "OrderBy": "(0|3) ASC, 1 ASC COLLATE latin1_swedish_ci, (2|4) ASC", - "Query": "select a, `user`.textcol1, b, weight_string(a), weight_string(b) from `user` order by a asc, `user`.textcol1 asc, b asc", + "Query": "select a, `user`.textcol1, b, weight_string(a), weight_string(b) from `user` order by `user`.a asc, `user`.textcol1 asc, `user`.b asc", "ResultColumns": 3, "Table": "`user`" }, @@ -344,7 +344,7 @@ }, "FieldQuery": "select a, textcol1, b, textcol2, weight_string(a), weight_string(b), weight_string(textcol2) from `user` where 1 != 1", "OrderBy": "(0|4) ASC, 1 ASC COLLATE latin1_swedish_ci, (2|5) ASC, (3|6) ASC COLLATE ", - "Query": "select a, textcol1, b, textcol2, weight_string(a), weight_string(b), weight_string(textcol2) from `user` order by a asc, textcol1 asc, b asc, textcol2 asc", + "Query": "select a, textcol1, b, textcol2, weight_string(a), weight_string(b), weight_string(textcol2) from `user` order by `user`.a asc, `user`.textcol1 asc, `user`.b asc, `user`.textcol2 asc", "ResultColumns": 4, "Table": "`user`" }, @@ -440,7 +440,7 @@ }, "FieldQuery": "select col from `user` where 1 != 1", "OrderBy": "0 ASC", - "Query": "select col from `user` where :__sq_has_values and col in ::__sq1 order by col asc", + "Query": "select col from `user` where :__sq_has_values and col in ::__sq1 order by `user`.col asc", "Table": "`user`" } ] @@ -1079,7 +1079,7 @@ "Sharded": true }, "FieldQuery": "select col from `user` as route1 where 1 != 1", - "Query": "select col from `user` as route1 where id = 1 order by col asc", + "Query": "select col from `user` as route1 where id = 1 order by route1.col asc", "Table": "`user`", "Values": [ "1" @@ -1365,7 +1365,7 @@ }, "FieldQuery": "select id as foo, weight_string(id) from music where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id as foo, weight_string(id) from music order by id asc", + "Query": "select id as foo, weight_string(id) from music order by music.id asc", "ResultColumns": 1, "Table": "music" }, @@ -1389,7 +1389,7 @@ }, "FieldQuery": "select id as foo, id2 as id, weight_string(id2) from music where 1 != 1", "OrderBy": "(1|2) ASC", - "Query": "select id as foo, id2 as id, weight_string(id2) from music order by id2 asc", + "Query": "select id as foo, id2 as id, weight_string(id2) from music order by music.id2 asc", "ResultColumns": 2, "Table": "music" }, @@ -1419,7 +1419,7 @@ }, "FieldQuery": "select `name`, weight_string(`name`) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select `name`, weight_string(`name`) from `user` order by `name` asc", + "Query": "select `name`, weight_string(`name`) from `user` order by `user`.`name` asc", "Table": "`user`" }, { @@ -1606,7 +1606,7 @@ }, "FieldQuery": "select `name`, weight_string(`name`) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select `name`, weight_string(`name`) from `user` order by `name` asc", + "Query": "select `name`, weight_string(`name`) from `user` order by `user`.`name` asc", "Table": "`user`" }, { @@ -1645,7 +1645,7 @@ }, "FieldQuery": "select id, id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|2) ASC", - "Query": "select id, id, weight_string(id) from `user` order by id asc", + "Query": "select id, id, weight_string(id) from `user` order by `user`.id asc", "ResultColumns": 2, "Table": "`user`" }, @@ -2095,7 +2095,7 @@ }, "FieldQuery": "select col from `user` where 1 != 1 group by col", "OrderBy": "0 ASC", - "Query": "select col from `user` where id between :vtg1 and :vtg2 group by col order by col asc", + "Query": "select col from `user` where id between :vtg1 and :vtg2 group by col order by `user`.col asc", "Table": "`user`" } ] @@ -2126,7 +2126,7 @@ }, "FieldQuery": "select foo, col, weight_string(foo) from `user` where 1 != 1 group by col, foo, weight_string(foo)", "OrderBy": "1 ASC, (0|2) ASC", - "Query": "select foo, col, weight_string(foo) from `user` where id between :vtg1 and :vtg2 group by col, foo, weight_string(foo) order by col asc, foo asc", + "Query": "select foo, col, weight_string(foo) from `user` where id between :vtg1 and :vtg2 group by col, foo, weight_string(foo) order by `user`.col asc, foo asc", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 0ef10b5247f..4a5f85b249d 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1078,7 +1078,7 @@ }, "FieldQuery": "select user_id, weight_string(user_id) from music where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select user_id, weight_string(user_id) from music order by user_id asc limit :__upper_limit", + "Query": "select user_id, weight_string(user_id) from music order by music.user_id asc limit :__upper_limit", "ResultColumns": 1, "Table": "music" } @@ -1884,7 +1884,7 @@ }, "FieldQuery": "select user_id, count(id), weight_string(user_id) from music where 1 != 1 group by user_id", "OrderBy": "(0|2) ASC", - "Query": "select user_id, count(id), weight_string(user_id) from music group by user_id having count(user_id) = 1 order by user_id asc limit :__upper_limit", + "Query": "select user_id, count(id), weight_string(user_id) from music group by user_id having count(user_id) = 1 order by music.user_id asc limit :__upper_limit", "ResultColumns": 2, "Table": "music" } @@ -2414,7 +2414,7 @@ }, "FieldQuery": "select col, `user`.id from `user` where 1 != 1", "OrderBy": "0 ASC", - "Query": "select col, `user`.id from `user` order by col asc", + "Query": "select col, `user`.id from `user` order by `user`.col asc", "Table": "`user`" }, { @@ -2840,7 +2840,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from `user` order by id asc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id asc limit :__upper_limit", "Table": "`user`" } ] @@ -2853,8 +2853,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :__sq1 as `(select id from ``user`` order by id asc limit 1)` from user_extra where 1 != 1", - "Query": "select :__sq1 as `(select id from ``user`` order by id asc limit 1)` from user_extra", + "FieldQuery": "select :__sq1 as `(select id from ``user`` order by ``user``.id asc limit 1)` from user_extra where 1 != 1", + "Query": "select :__sq1 as `(select id from ``user`` order by ``user``.id asc limit 1)` from user_extra", "Table": "user_extra" } ] @@ -3330,7 +3330,7 @@ }, "FieldQuery": "select id, `name`, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|2) ASC", - "Query": "select id, `name`, weight_string(id) from `user` where `name` = 'aa' order by id asc limit :__upper_limit", + "Query": "select id, `name`, weight_string(id) from `user` where `name` = 'aa' order by `user`.id asc limit :__upper_limit", "ResultColumns": 2, "Table": "`user`" } diff --git a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json index ee38e7d0538..f6072bcd9a5 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpcc_cases.json @@ -556,7 +556,7 @@ "Sharded": true }, "FieldQuery": "select c_balance, c_first, c_middle, c_id from customer1 where 1 != 1", - "Query": "select c_balance, c_first, c_middle, c_id from customer1 where c_w_id = 840 and c_d_id = 1 and c_last = 'test' order by c_first asc", + "Query": "select c_balance, c_first, c_middle, c_id from customer1 where c_w_id = 840 and c_d_id = 1 and c_last = 'test' order by customer1.c_first asc", "Table": "customer1", "Values": [ "840" @@ -608,7 +608,7 @@ "Sharded": true }, "FieldQuery": "select o_id, o_carrier_id, o_entry_d from orders1 where 1 != 1", - "Query": "select o_id, o_carrier_id, o_entry_d from orders1 where o_w_id = 9894 and o_d_id = 3 and o_c_id = 159 order by o_id desc", + "Query": "select o_id, o_carrier_id, o_entry_d from orders1 where o_w_id = 9894 and o_d_id = 3 and o_c_id = 159 order by orders1.o_id desc", "Table": "orders1", "Values": [ "9894" @@ -660,7 +660,7 @@ "Sharded": true }, "FieldQuery": "select no_o_id from new_orders1 where 1 != 1", - "Query": "select no_o_id from new_orders1 where no_d_id = 689 and no_w_id = 15 order by no_o_id asc limit 1 for update", + "Query": "select no_o_id from new_orders1 where no_d_id = 689 and no_w_id = 15 order by new_orders1.no_o_id asc limit 1 for update", "Table": "new_orders1", "Values": [ "15" diff --git a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json index 2d225808992..609285c4bfe 100644 --- a/go/vt/vtgate/planbuilder/testdata/tpch_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/tpch_cases.json @@ -35,7 +35,7 @@ }, "FieldQuery": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where 1 != 1 group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus)", "OrderBy": "(0|13) ASC, (1|14) ASC", - "Query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus) order by l_returnflag asc, l_linestatus asc", + "Query": "select l_returnflag, l_linestatus, sum(l_quantity) as sum_qty, sum(l_extendedprice) as sum_base_price, sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, sum(l_quantity) as avg_qty, sum(l_extendedprice) as avg_price, sum(l_discount) as avg_disc, count(*) as count_order, count(l_quantity), count(l_extendedprice), count(l_discount), weight_string(l_returnflag), weight_string(l_linestatus) from lineitem where l_shipdate <= '1998-12-01' - interval '108' day group by l_returnflag, l_linestatus, weight_string(l_returnflag), weight_string(l_linestatus) order by lineitem.l_returnflag asc, lineitem.l_linestatus asc", "Table": "lineitem" } ] @@ -213,7 +213,7 @@ }, "FieldQuery": "select o_orderpriority, count(*) as order_count, o_orderkey, weight_string(o_orderpriority) from orders where 1 != 1 group by o_orderpriority, o_orderkey, weight_string(o_orderpriority)", "OrderBy": "(0|3) ASC", - "Query": "select o_orderpriority, count(*) as order_count, o_orderkey, weight_string(o_orderpriority) from orders where o_orderdate >= date('1993-07-01') and o_orderdate < date('1993-07-01') + interval '3' month group by o_orderpriority, o_orderkey, weight_string(o_orderpriority) order by o_orderpriority asc", + "Query": "select o_orderpriority, count(*) as order_count, o_orderkey, weight_string(o_orderpriority) from orders where o_orderdate >= date('1993-07-01') and o_orderdate < date('1993-07-01') + interval '3' month group by o_orderpriority, o_orderkey, weight_string(o_orderpriority) order by orders.o_orderpriority asc", "Table": "orders" }, { @@ -631,9 +631,9 @@ "Name": "main", "Sharded": true }, - "FieldQuery": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year)", + "FieldQuery": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where 1 != 1) as shipping where 1 != 1 group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year)", "OrderBy": "(5|6) ASC, (7|8) ASC, (1|4) ASC", - "Query": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), supp_nation, weight_string(supp_nation), cust_nation, weight_string(cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year) order by supp_nation asc, cust_nation asc, l_year asc", + "Query": "select sum(volume) as revenue, l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year), shipping.supp_nation, weight_string(shipping.supp_nation), shipping.cust_nation, weight_string(shipping.cust_nation) from (select extract(year from l_shipdate) as l_year, l_extendedprice * (1 - l_discount) as volume, l_suppkey as l_suppkey, l_orderkey as l_orderkey from lineitem where l_shipdate between date('1995-01-01') and date('1996-12-31')) as shipping group by l_year, shipping.l_suppkey, shipping.l_orderkey, weight_string(l_year) order by shipping.supp_nation asc, shipping.cust_nation asc, shipping.l_year asc", "Table": "lineitem" }, { @@ -1518,7 +1518,7 @@ }, "FieldQuery": "select s_suppkey, s_name, s_address, s_phone, total_revenue, weight_string(s_suppkey) from supplier, revenue0 where 1 != 1", "OrderBy": "(0|5) ASC", - "Query": "select s_suppkey, s_name, s_address, s_phone, total_revenue, weight_string(s_suppkey) from supplier, revenue0 where s_suppkey = supplier_no and total_revenue = :__sq1 order by s_suppkey asc", + "Query": "select s_suppkey, s_name, s_address, s_phone, total_revenue, weight_string(s_suppkey) from supplier, revenue0 where s_suppkey = supplier_no and total_revenue = :__sq1 order by supplier.s_suppkey asc", "ResultColumns": 5, "Table": "revenue0, supplier" } diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index 9ac8db73be7..7c225862235 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -128,7 +128,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from `user` order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id desc limit :__upper_limit", "Table": "`user`" } ] @@ -146,7 +146,7 @@ }, "FieldQuery": "select id, weight_string(id) from music where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from music order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from music order by music.id desc limit :__upper_limit", "Table": "music" } ] @@ -258,7 +258,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from `user` order by id asc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id asc limit :__upper_limit", "Table": "`user`" } ] @@ -276,7 +276,7 @@ }, "FieldQuery": "select id, weight_string(id) from music where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from music order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from music order by music.id desc limit :__upper_limit", "Table": "music" } ] @@ -962,7 +962,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) ASC", - "Query": "select id, weight_string(id) from `user` order by id asc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id asc limit :__upper_limit", "Table": "`user`" } ] @@ -980,7 +980,7 @@ }, "FieldQuery": "select id, weight_string(id) from `user` where 1 != 1", "OrderBy": "(0|1) DESC", - "Query": "select id, weight_string(id) from `user` order by id desc limit :__upper_limit", + "Query": "select id, weight_string(id) from `user` order by `user`.id desc limit :__upper_limit", "Table": "`user`" } ] diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 0fe9f4a934e..f289438a1c9 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -37,6 +37,7 @@ type analyzer struct { sig QuerySignature si SchemaInformation currentDb string + recheck bool err error inProjection int @@ -74,7 +75,8 @@ func (a *analyzer) lateInit() { expandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, env: a.si.Environment(), aliasMapCache: map[*sqlparser.Select]map[string]exprContainer{}, - reAnalyze: a.lateAnalyze, + reAnalyze: a.reAnalyze, + tables: a.tables, } } @@ -249,9 +251,12 @@ func (a *analyzer) analyzeUp(cursor *sqlparser.Cursor) bool { return false } - if err := a.rewriter.up(cursor); err != nil { - a.setError(err) - return true + if !a.recheck { + // no need to run the rewriter on rechecking + if err := a.rewriter.up(cursor); err != nil { + a.setError(err) + return true + } } if err := a.scoper.up(cursor); err != nil { @@ -359,6 +364,14 @@ func (a *analyzer) lateAnalyze(statement sqlparser.SQLNode) error { return a.err } +func (a *analyzer) reAnalyze(statement sqlparser.SQLNode) error { + a.recheck = true + defer func() { + a.recheck = false + }() + return a.lateAnalyze(statement) +} + // canShortCut checks if we are dealing with a single unsharded keyspace and no tables that have managed foreign keys // if so, we can stop the analyzer early func (a *analyzer) canShortCut(statement sqlparser.Statement) (canShortCut bool) { @@ -639,6 +652,10 @@ type ShardedError struct { Inner error } +func (p ShardedError) Unwrap() error { + return p.Inner +} + func (p ShardedError) Error() string { return p.Inner.Error() } diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index a7c173ccc96..27a34a427f1 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -722,7 +722,10 @@ func TestGroupByBinding(t *testing.T) { TS1, }, { "select a.id from t as a, t1 group by id", - TS0, + // since we have authoritative info on t1, we know that it does have an `id` column, + // and we are missing column info for `t`, we just assume this is coming from t1. + // we really need schema tracking here + TS1, }, { "select a.id from t, t1 as a group by id", TS1, @@ -740,44 +743,47 @@ func TestGroupByBinding(t *testing.T) { func TestHavingBinding(t *testing.T) { tcases := []struct { - sql string - deps TableSet + sql, err string + deps TableSet }{{ - "select col from tabl having col = 1", - TS0, + sql: "select col from tabl having col = 1", + deps: TS0, }, { - "select col from tabl having tabl.col = 1", - TS0, + sql: "select col from tabl having tabl.col = 1", + deps: TS0, }, { - "select col from tabl having d.tabl.col = 1", - TS0, + sql: "select col from tabl having d.tabl.col = 1", + deps: TS0, }, { - "select tabl.col as x from tabl having x = 1", - TS0, + sql: "select tabl.col as x from tabl having col = 1", + deps: TS0, }, { - "select tabl.col as x from tabl having col", - TS0, + sql: "select tabl.col as x from tabl having x = 1", + deps: TS0, }, { - "select col from tabl having 1 = 1", - NoTables, + sql: "select tabl.col as x from tabl having col", + deps: TS0, }, { - "select col as c from tabl having c = 1", - TS0, + sql: "select col from tabl having 1 = 1", + deps: NoTables, }, { - "select 1 as c from tabl having c = 1", - NoTables, + sql: "select col as c from tabl having c = 1", + deps: TS0, }, { - "select t1.id from t1, t2 having id = 1", - TS0, + sql: "select 1 as c from tabl having c = 1", + deps: NoTables, }, { - "select t.id from t, t1 having id = 1", - TS0, + sql: "select t1.id from t1, t2 having id = 1", + deps: TS0, }, { - "select t.id, count(*) as a from t, t1 group by t.id having a = 1", - MergeTableSets(TS0, TS1), + sql: "select t.id from t, t1 having id = 1", + deps: TS0, }, { - "select t.id, sum(t2.name) as a from t, t2 group by t.id having a = 1", - TS1, + sql: "select t.id, count(*) as a from t, t1 group by t.id having a = 1", + deps: MergeTableSets(TS0, TS1), + }, { + sql: "select t.id, sum(t2.name) as a from t, t2 group by t.id having a = 1", + deps: TS1, }, { sql: "select u2.a, u1.a from u1, u2 having u2.a = 2", deps: TS1, diff --git a/go/vt/vtgate/semantics/binder.go b/go/vt/vtgate/semantics/binder.go index b010649e067..9d91f6523cf 100644 --- a/go/vt/vtgate/semantics/binder.go +++ b/go/vt/vtgate/semantics/binder.go @@ -228,28 +228,32 @@ func (b *binder) setSubQueryDependencies(subq *sqlparser.Subquery, currScope *sc } func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allowMulti, singleTableFallBack bool) (dependency, error) { + if !current.stmtScope && current.inGroupBy { + return b.resolveColInGroupBy(colName, current, allowMulti, singleTableFallBack) + } + if !current.stmtScope && current.inHaving && !current.inHavingAggr { + return b.resolveColumnInHaving(colName, current, allowMulti) + } + var thisDeps dependencies first := true var tableName *sqlparser.TableName + for current != nil { var err error thisDeps, err = b.resolveColumnInScope(current, colName, allowMulti) if err != nil { - err = makeAmbiguousError(colName, err) - if thisDeps == nil { - return dependency{}, err - } + return dependency{}, makeAmbiguousError(colName, err) } if !thisDeps.empty() { - deps, thisErr := thisDeps.get() - if thisErr != nil { - err = makeAmbiguousError(colName, thisErr) - } - return deps, err - } else if err != nil { - return dependency{}, err + deps, err := thisDeps.get() + return deps, makeAmbiguousError(colName, err) } - if current.parent == nil && len(current.tables) == 1 && first && colName.Qualifier.IsEmpty() && singleTableFallBack { + if current.parent == nil && + len(current.tables) == 1 && + first && + colName.Qualifier.IsEmpty() && + singleTableFallBack { // if this is the top scope, and we still haven't been able to find a match, we know we are about to fail // we can check this last scope and see if there is a single table. if there is just one table in the scope // we assume that the column is meant to come from this table. @@ -267,6 +271,146 @@ func (b *binder) resolveColumn(colName *sqlparser.ColName, current *scope, allow return dependency{}, ShardedError{ColumnNotFoundError{Column: colName, Table: tableName}} } +func isColumnNotFound(err error) bool { + switch err := err.(type) { + case ColumnNotFoundError: + return true + case ShardedError: + return isColumnNotFound(err.Inner) + default: + return false + } +} + +func (b *binder) resolveColumnInHaving(colName *sqlparser.ColName, current *scope, allowMulti bool) (dependency, error) { + if current.inHavingAggr { + // when inside an aggregation, we'll search the FROM clause before the SELECT expressions + deps, err := b.resolveColumn(colName, current.parent, allowMulti, true) + if deps.direct.NotEmpty() || (err != nil && !isColumnNotFound(err)) { + return deps, err + } + } + + // Here we are searching among the SELECT expressions for a match + thisDeps, err := b.resolveColumnInScope(current, colName, allowMulti) + if err != nil { + return dependency{}, makeAmbiguousError(colName, err) + } + + if !thisDeps.empty() { + // we found something! let's return it + deps, err := thisDeps.get() + if err != nil { + err = makeAmbiguousError(colName, err) + } + return deps, err + } + + notFoundErr := &ColumnNotFoundClauseError{Column: colName.Name.String(), Clause: "having clause"} + if current.inHavingAggr { + // if we are inside an aggregation, we've already looked everywhere. now it's time to give up + return dependency{}, notFoundErr + } + + // Now we'll search the FROM clause, but with a twist. If we find it in the FROM clause, the column must also + // exist as a standalone expression in the SELECT list + deps, err := b.resolveColumn(colName, current.parent, allowMulti, true) + if deps.direct.IsEmpty() { + return dependency{}, notFoundErr + } + + sel := current.stmt.(*sqlparser.Select) // we can be sure of this, since HAVING doesn't exist on UNION + if selDeps := b.searchInSelectExpressions(colName, deps, sel); selDeps.direct.NotEmpty() { + return selDeps, nil + } + + if !current.inHavingAggr && len(sel.GroupBy) == 0 { + // if we are not inside an aggregation, and there is no GROUP BY, we consider the FROM clause before failing + if deps.direct.NotEmpty() || (err != nil && !isColumnNotFound(err)) { + return deps, err + } + } + + return dependency{}, notFoundErr +} + +// searchInSelectExpressions searches for the ColName among the SELECT and GROUP BY expressions +// It used dependency information to match the columns +func (b *binder) searchInSelectExpressions(colName *sqlparser.ColName, deps dependency, stmt *sqlparser.Select) dependency { + for _, selectExpr := range stmt.SelectExprs { + ae, ok := selectExpr.(*sqlparser.AliasedExpr) + if !ok { + continue + } + selectCol, ok := ae.Expr.(*sqlparser.ColName) + if !ok || !selectCol.Name.Equal(colName.Name) { + continue + } + + _, direct, _ := b.org.depsForExpr(selectCol) + if deps.direct == direct { + // we have found the ColName in the SELECT expressions, so it's safe to use here + direct, recursive, typ := b.org.depsForExpr(ae.Expr) + return dependency{certain: true, direct: direct, recursive: recursive, typ: typ} + } + } + + for _, gb := range stmt.GroupBy { + selectCol, ok := gb.(*sqlparser.ColName) + if !ok || !selectCol.Name.Equal(colName.Name) { + continue + } + + _, direct, _ := b.org.depsForExpr(selectCol) + if deps.direct == direct { + // we have found the ColName in the GROUP BY expressions, so it's safe to use here + direct, recursive, typ := b.org.depsForExpr(gb) + return dependency{certain: true, direct: direct, recursive: recursive, typ: typ} + } + } + return dependency{} +} + +// resolveColInGroupBy handles the special rules we have when binding on the GROUP BY column +func (b *binder) resolveColInGroupBy( + colName *sqlparser.ColName, + current *scope, + allowMulti bool, + singleTableFallBack bool, +) (dependency, error) { + if current.parent == nil { + return dependency{}, vterrors.VT13001("did not expect this to be the last scope") + } + // if we are in GROUP BY, we have to search the FROM clause before we search the SELECT expressions + deps, firstErr := b.resolveColumn(colName, current.parent, allowMulti, false) + if firstErr == nil { + return deps, nil + } + + // either we didn't find the column on a table, or it was ambiguous. + // in either case, next step is to search the SELECT expressions + if colName.Qualifier.NonEmpty() { + // if the col name has a qualifier, none of the SELECT expressions are going to match + return dependency{}, nil + } + vtbl, ok := current.tables[0].(*vTableInfo) + if !ok { + return dependency{}, vterrors.VT13001("expected the table info to be a *vTableInfo") + } + + dependencies, err := vtbl.dependenciesInGroupBy(colName.Name.String(), b.org) + if err != nil { + return dependency{}, err + } + if dependencies.empty() { + if isColumnNotFound(firstErr) { + return dependency{}, &ColumnNotFoundClauseError{Column: colName.Name.String(), Clause: "group statement"} + } + return deps, firstErr + } + return dependencies.get() +} + func (b *binder) resolveColumnInScope(current *scope, expr *sqlparser.ColName, allowMulti bool) (dependencies, error) { var deps dependencies = ¬hing{} for _, table := range current.tables { diff --git a/go/vt/vtgate/semantics/dependencies.go b/go/vt/vtgate/semantics/dependencies.go index e68c5100ef5..714fa97c2c4 100644 --- a/go/vt/vtgate/semantics/dependencies.go +++ b/go/vt/vtgate/semantics/dependencies.go @@ -32,6 +32,7 @@ type ( merge(other dependencies, allowMulti bool) dependencies } dependency struct { + certain bool direct TableSet recursive TableSet typ evalengine.Type @@ -52,6 +53,7 @@ var ambigousErr = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "ambiguous") func createCertain(direct TableSet, recursive TableSet, qt evalengine.Type) *certain { c := &certain{ dependency: dependency{ + certain: true, direct: direct, recursive: recursive, }, @@ -65,6 +67,7 @@ func createCertain(direct TableSet, recursive TableSet, qt evalengine.Type) *cer func createUncertain(direct TableSet, recursive TableSet) *uncertain { return &uncertain{ dependency: dependency{ + certain: false, direct: direct, recursive: recursive, }, @@ -131,7 +134,7 @@ func (n *nothing) empty() bool { } func (n *nothing) get() (dependency, error) { - return dependency{}, nil + return dependency{certain: true}, nil } func (n *nothing) merge(d dependencies, _ bool) dependencies { diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 646b5b71e41..db3c8cff396 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -35,6 +35,7 @@ type earlyRewriter struct { expandedColumns map[sqlparser.TableName][]*sqlparser.ColName env *vtenv.Environment aliasMapCache map[*sqlparser.Select]map[string]exprContainer + tables *tableCollector // reAnalyze is used when we are running in the late stage, after the other parts of semantic analysis // have happened, and we are introducing or changing the AST. We invoke it so all parts of the query have been @@ -44,8 +45,6 @@ type earlyRewriter struct { func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { switch node := cursor.Node().(type) { - case *sqlparser.Where: - return r.handleWhereClause(node, cursor.Parent()) case sqlparser.SelectExprs: return r.handleSelectExprs(cursor, node) case *sqlparser.JoinTableExpr: @@ -56,13 +55,6 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { rewriteAndExpr(r.env, cursor, node) case *sqlparser.NotExpr: rewriteNotExpr(cursor, node) - case sqlparser.GroupBy: - r.clause = "group clause" - iter := &exprIterator{ - node: node, - idx: -1, - } - return r.handleGroupBy(cursor.Parent(), iter) case *sqlparser.ComparisonExpr: return handleComparisonExpr(cursor, node) case *sqlparser.With: @@ -83,6 +75,13 @@ func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { // this rewriting is done in the `up` phase, because we need the vindex hints to have been // processed while collecting the tables. return removeVindexHints(node) + case sqlparser.GroupBy: + r.clause = "group clause" + iter := &exprIterator{ + node: node, + idx: -1, + } + return r.handleGroupBy(cursor.Parent(), iter) case sqlparser.OrderBy: r.clause = "order clause" iter := &orderByIterator{ @@ -91,6 +90,11 @@ func (r *earlyRewriter) up(cursor *sqlparser.Cursor) error { r: r, } return r.handleOrderBy(cursor.Parent(), iter) + case *sqlparser.Where: + if node.Type == sqlparser.HavingClause { + return r.handleHavingClause(node, cursor.Parent()) + } + } return nil } @@ -196,22 +200,18 @@ func removeVindexHints(node *sqlparser.AliasedTableExpr) error { return nil } -// handleWhereClause processes WHERE clauses, specifically the HAVING clause. -func (r *earlyRewriter) handleWhereClause(node *sqlparser.Where, parent sqlparser.SQLNode) error { +// handleHavingClause processes the HAVING clause +func (r *earlyRewriter) handleHavingClause(node *sqlparser.Where, parent sqlparser.SQLNode) error { sel, ok := parent.(*sqlparser.Select) if !ok { return nil } - if node.Type != sqlparser.HavingClause { - return nil - } - expr, err := r.rewriteAliasesInHavingAndGroupBy(node.Expr, sel) + expr, err := r.rewriteAliasesInHaving(node.Expr, sel) if err != nil { return err } - node.Expr = expr - return nil + return r.reAnalyze(expr) } // handleSelectExprs expands * in SELECT expressions. @@ -290,7 +290,7 @@ func (r *earlyRewriter) replaceLiteralsInOrderBy(e sqlparser.Expr, iter iterator return false, nil } - newExpr, recheck, err := r.rewriteOrderByExpr(lit) + newExpr, recheck, err := r.rewriteOrderByLiteral(lit) if err != nil { return false, err } @@ -328,15 +328,15 @@ func (r *earlyRewriter) replaceLiteralsInOrderBy(e sqlparser.Expr, iter iterator return true, nil } -func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr, iter iterator) (bool, error) { +func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr) (sqlparser.Expr, error) { lit := getIntLiteral(e) if lit == nil { - return false, nil + return nil, nil } newExpr, err := r.rewriteGroupByExpr(lit) if err != nil { - return false, err + return nil, err } if getIntLiteral(newExpr) == nil { @@ -359,8 +359,7 @@ func (r *earlyRewriter) replaceLiteralsInGroupBy(e sqlparser.Expr, iter iterator newExpr = sqlparser.NewStrLiteral("") } - err = iter.replace(newExpr) - return true, err + return newExpr, nil } func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal { @@ -426,19 +425,23 @@ func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) e sel := sqlparser.GetFirstSelect(stmt) for e := iter.next(); e != nil; e = iter.next() { - lit, err := r.replaceLiteralsInGroupBy(e, iter) + expr, err := r.replaceLiteralsInGroupBy(e) if err != nil { return err } - if lit { - continue + if expr == nil { + expr, err = r.rewriteAliasesInGroupBy(e, sel) + if err != nil { + return err + } + } - expr, err := r.rewriteAliasesInHavingAndGroupBy(e, sel) + err = iter.replace(expr) if err != nil { return err } - err = iter.replace(expr) - if err != nil { + + if err = r.reAnalyze(expr); err != nil { return err } } @@ -452,35 +455,14 @@ func (r *earlyRewriter) handleGroupBy(parent sqlparser.SQLNode, iter iterator) e // in SELECT points to that expression, not any table column. // - However, if the aliased expression is an aggregation and the column identifier in // the HAVING/ORDER BY clause is inside an aggregation function, the rule does not apply. -func (r *earlyRewriter) rewriteAliasesInHavingAndGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { +func (r *earlyRewriter) rewriteAliasesInGroupBy(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { type ExprContainer struct { expr sqlparser.Expr ambiguous bool } - aliases := map[string]ExprContainer{} - for _, e := range sel.SelectExprs { - ae, ok := e.(*sqlparser.AliasedExpr) - if !ok { - continue - } - - var alias string - - item := ExprContainer{expr: ae.Expr} - if ae.As.NotEmpty() { - alias = ae.As.Lowered() - } else if col, ok := ae.Expr.(*sqlparser.ColName); ok { - alias = col.Name.Lowered() - } - - if old, alreadyExists := aliases[alias]; alreadyExists && !sqlparser.Equals.Expr(old.expr, item.expr) { - item.ambiguous = true - } - - aliases[alias] = item - } - + currentScope := r.scoper.currentScope() + aliases := r.getAliasMap(sel) insideAggr := false downF := func(node, _ sqlparser.SQLNode) bool { switch node.(type) { @@ -498,7 +480,7 @@ func (r *earlyRewriter) rewriteAliasesInHavingAndGroupBy(node sqlparser.Expr, se case sqlparser.AggrFunc: insideAggr = false case *sqlparser.ColName: - if !col.Qualifier.IsEmpty() { + if col.Qualifier.NonEmpty() { // we are only interested in columns not qualified by table names break } @@ -508,43 +490,111 @@ func (r *earlyRewriter) rewriteAliasesInHavingAndGroupBy(node sqlparser.Expr, se break } + isColumnOnTable, sure := r.isColumnOnTable(col, currentScope) + if found && isColumnOnTable { + r.warning = fmt.Sprintf("Column '%s' in group statement is ambiguous", sqlparser.String(col)) + } + + if isColumnOnTable && sure { + break + } + + if !sure { + r.warning = "Missing table info, so not binding to anything on the FROM clause" + } + if item.ambiguous { err = &AmbiguousColumnError{Column: sqlparser.String(col)} + } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = &InvalidUseOfGroupFunction{} + } + if err != nil { cursor.StopTreeWalk() return } - if insideAggr && sqlparser.ContainsAggregation(item.expr) { - // I'm not sure about this, but my experiments point to this being the behaviour mysql has - // mysql> select min(name) as name from user order by min(name); - // 1 row in set (0.00 sec) - // - // mysql> select id % 2, min(name) as name from user group by id % 2 order by min(name); - // 2 rows in set (0.00 sec) - // - // mysql> select id % 2, 'foobar' as name from user group by id % 2 order by min(name); - // 2 rows in set (0.00 sec) - // - // mysql> select id % 2 from user group by id % 2 order by min(min(name)); - // ERROR 1111 (HY000): Invalid use of group function - // - // mysql> select id % 2, min(name) as k from user group by id % 2 order by min(k); - // ERROR 1111 (HY000): Invalid use of group function - // - // mysql> select id % 2, -id as name from user group by id % 2, -id order by min(name); - // 6 rows in set (0.01 sec) - break + cursor.Replace(sqlparser.CloneExpr(item.expr)) + } + }, nil) + + expr = output.(sqlparser.Expr) + return +} + +func (r *earlyRewriter) rewriteAliasesInHaving(node sqlparser.Expr, sel *sqlparser.Select) (expr sqlparser.Expr, err error) { + currentScope := r.scoper.currentScope() + if currentScope.isUnion { + // It is not safe to rewrite order by clauses in unions. + return node, nil + } + + aliases := r.getAliasMap(sel) + insideAggr := false + dontEnterSubquery := func(node, _ sqlparser.SQLNode) bool { + switch node.(type) { + case *sqlparser.Subquery: + return false + case sqlparser.AggrFunc: + insideAggr = true + } + + return true + } + output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { + var col *sqlparser.ColName + + switch node := cursor.Node().(type) { + case sqlparser.AggrFunc: + insideAggr = false + return + case *sqlparser.ColName: + col = node + default: + return + } + + if col.Qualifier.NonEmpty() { + // we are only interested in columns not qualified by table names + return + } + + item, found := aliases[col.Name.Lowered()] + if insideAggr { + // inside aggregations, we want to first look for columns in the FROM clause + isColumnOnTable, sure := r.isColumnOnTable(col, currentScope) + if isColumnOnTable { + if found && sure { + r.warning = fmt.Sprintf("Column '%s' in having clause is ambiguous", sqlparser.String(col)) + } + return } + } else if !found { + // if outside aggregations, we don't care about FROM columns + // if there is no matching alias, there is no rewriting needed + return + } - cursor.Replace(sqlparser.CloneExpr(item.expr)) + // If we get here, it means we have found an alias and want to use it + if item.ambiguous { + err = &AmbiguousColumnError{Column: sqlparser.String(col)} + } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { + err = &InvalidUseOfGroupFunction{} } + if err != nil { + cursor.StopTreeWalk() + return + } + + newColName := sqlparser.CopyOnRewrite(item.expr, nil, r.fillInQualifiers, nil) + + cursor.Replace(newColName) }, nil) expr = output.(sqlparser.Expr) return } -// rewriteAliasesInOrderBy rewrites columns in the ORDER BY and HAVING clauses to use aliases +// rewriteAliasesInOrderBy rewrites columns in the ORDER BY to use aliases // from the SELECT expressions when applicable, following MySQL scoping rules: // - A column identifier without a table qualifier that matches an alias introduced // in SELECT points to that expression, not any table column. @@ -567,8 +617,7 @@ func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlpar insideAggr = true } - _, isSubq := node.(*sqlparser.Subquery) - return !isSubq + return true } output := sqlparser.CopyOnRewrite(node, dontEnterSubquery, func(cursor *sqlparser.CopyOnWriteCursor) { var col *sqlparser.ColName @@ -583,46 +632,80 @@ func (r *earlyRewriter) rewriteAliasesInOrderBy(node sqlparser.Expr, sel *sqlpar return } - if !col.Qualifier.IsEmpty() { + if col.Qualifier.NonEmpty() { // we are only interested in columns not qualified by table names return } - item, found := aliases[col.Name.Lowered()] + var item exprContainer + var found bool + + item, found = aliases[col.Name.Lowered()] if !found { // if there is no matching alias, there is no rewriting needed return } + isColumnOnTable, sure := r.isColumnOnTable(col, currentScope) + if found && isColumnOnTable && sure { + r.warning = fmt.Sprintf("Column '%s' in order by statement is ambiguous", sqlparser.String(col)) + } topLevel := col == node - if !topLevel && r.isColumnOnTable(col, currentScope) { + if isColumnOnTable && sure && !topLevel { // we only want to replace columns that are not coming from the table return } + if !sure { + r.warning = "Missing table info, so not binding to anything on the FROM clause" + } + if item.ambiguous { err = &AmbiguousColumnError{Column: sqlparser.String(col)} } else if insideAggr && sqlparser.ContainsAggregation(item.expr) { - err = &InvalidUserOfGroupFunction{} + err = &InvalidUseOfGroupFunction{} } if err != nil { cursor.StopTreeWalk() return } - cursor.Replace(sqlparser.CloneExpr(item.expr)) + newColName := sqlparser.CopyOnRewrite(item.expr, nil, r.fillInQualifiers, nil) + + cursor.Replace(newColName) }, nil) expr = output.(sqlparser.Expr) return } -func (r *earlyRewriter) isColumnOnTable(col *sqlparser.ColName, currentScope *scope) bool { +// fillInQualifiers adds qualifiers to any columns we have rewritten +func (r *earlyRewriter) fillInQualifiers(cursor *sqlparser.CopyOnWriteCursor) { + col, ok := cursor.Node().(*sqlparser.ColName) + if !ok || col.Qualifier.NonEmpty() { + return + } + ts, found := r.binder.direct[col] + if !found { + panic("uh oh") + } + tbl := r.tables.Tables[ts.TableOffset()] + tblName, err := tbl.Name() + if err != nil { + panic(err) + } + cursor.Replace(sqlparser.NewColNameWithQualifier(col.Name.String(), tblName)) +} + +func (r *earlyRewriter) isColumnOnTable(col *sqlparser.ColName, currentScope *scope) (isColumn bool, isCertain bool) { if !currentScope.stmtScope && currentScope.parent != nil { currentScope = currentScope.parent } - _, err := r.binder.resolveColumn(col, currentScope, false, false) - return err == nil + deps, err := r.binder.resolveColumn(col, currentScope, false, false) + if err != nil { + return false, true + } + return true, deps.certain } func (r *earlyRewriter) getAliasMap(sel *sqlparser.Select) (aliases map[string]exprContainer) { @@ -661,7 +744,7 @@ type exprContainer struct { ambiguous bool } -func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (expr sqlparser.Expr, needReAnalysis bool, err error) { +func (r *earlyRewriter) rewriteOrderByLiteral(node *sqlparser.Literal) (expr sqlparser.Expr, needReAnalysis bool, err error) { scope, found := r.scoper.specialExprScopes[node] if !found { return node, false, nil diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index cf93a52447c..a5c16ba6b78 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -172,8 +172,8 @@ func TestExpandStar(t *testing.T) { sql: "select * from t1 join t5 using (b) having b = 12", expSQL: "select t1.b as b, t1.a as a, t1.c as c, t5.a as a from t1 join t5 on t1.b = t5.b having t1.b = 12", }, { - sql: "select 1 from t1 join t5 using (b) having b = 12", - expSQL: "select 1 from t1 join t5 on t1.b = t5.b having t1.b = 12", + sql: "select 1 from t1 join t5 using (b) where b = 12", + expSQL: "select 1 from t1 join t5 on t1.b = t5.b where t1.b = 12", }, { sql: "select * from (select 12) as t", expSQL: "select `12` from (select 12 from dual) as t", @@ -304,6 +304,90 @@ func TestRewriteJoinUsingColumns(t *testing.T) { } +func TestGroupByColumnName(t *testing.T) { + schemaInfo := &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t1": { + Name: sqlparser.NewIdentifierCS("t1"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("id"), + Type: sqltypes.Int32, + }, { + Name: sqlparser.NewIdentifierCI("col1"), + Type: sqltypes.Int32, + }}, + ColumnListAuthoritative: true, + }, + "t2": { + Name: sqlparser.NewIdentifierCS("t2"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("id"), + Type: sqltypes.Int32, + }, { + Name: sqlparser.NewIdentifierCI("col2"), + Type: sqltypes.Int32, + }}, + ColumnListAuthoritative: true, + }, + }, + } + cDB := "db" + tcases := []struct { + sql string + expSQL string + expDeps TableSet + expErr string + warning string + }{{ + sql: "select t3.col from t3 group by kj", + expSQL: "select t3.col from t3 group by kj", + expDeps: TS0, + }, { + sql: "select t2.col2 as xyz from t2 group by xyz", + expSQL: "select t2.col2 as xyz from t2 group by t2.col2", + expDeps: TS0, + }, { + sql: "select id from t1 group by unknown", + expErr: "Unknown column 'unknown' in 'group statement'", + }, { + sql: "select t1.c as x, sum(t2.id) as x from t1 join t2 group by x", + expErr: "VT03005: cannot group on 'x'", + }, { + sql: "select t1.col1, sum(t2.id) as col1 from t1 join t2 group by col1", + expSQL: "select t1.col1, sum(t2.id) as col1 from t1 join t2 group by col1", + expDeps: TS0, + warning: "Column 'col1' in group statement is ambiguous", + }, { + sql: "select t2.col2 as id, sum(t2.id) as x from t1 join t2 group by id", + expSQL: "select t2.col2 as id, sum(t2.id) as x from t1 join t2 group by t2.col2", + expDeps: TS1, + }, { + sql: "select sum(t2.col2) as id, sum(t2.id) as x from t1 join t2 group by id", + expErr: "VT03005: cannot group on 'id'", + }, { + sql: "select count(*) as x from t1 group by x", + expErr: "VT03005: cannot group on 'x'", + }} + for _, tcase := range tcases { + t.Run(tcase.sql, func(t *testing.T) { + ast, err := sqlparser.NewTestParser().Parse(tcase.sql) + require.NoError(t, err) + selectStatement := ast.(*sqlparser.Select) + st, err := AnalyzeStrict(selectStatement, cDB, schemaInfo) + if tcase.expErr == "" { + require.NoError(t, err) + assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + gb := selectStatement.GroupBy + deps := st.RecursiveDeps(gb[0]) + assert.Equal(t, tcase.expDeps, deps) + assert.Equal(t, tcase.warning, st.Warning) + } else { + require.EqualError(t, err, tcase.expErr) + } + }) + } +} + func TestGroupByLiteral(t *testing.T) { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{}, @@ -432,33 +516,89 @@ func TestOrderByLiteral(t *testing.T) { } func TestHavingColumnName(t *testing.T) { - schemaInfo := &FakeSI{ - Tables: map[string]*vindexes.Table{}, - } + schemaInfo := getSchemaWithKnownColumns() cDB := "db" tcases := []struct { - sql string - expSQL string - expErr string + sql string + expSQL string + expDeps TableSet + expErr string + warning string }{{ - sql: "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1", - expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(foo) > 1", + sql: "select id, sum(foo) as sumOfFoo from t1 having sumOfFoo > 1", + expSQL: "select id, sum(foo) as sumOfFoo from t1 having sum(t1.foo) > 1", + expDeps: TS0, + }, { + sql: "select id as X, sum(foo) as X from t1 having X > 1", + expErr: "Column 'X' in field list is ambiguous", + }, { + sql: "select id, sum(t1.foo) as foo from t1 having sum(foo) > 1", + expSQL: "select id, sum(t1.foo) as foo from t1 having sum(foo) > 1", + expDeps: TS0, + warning: "Column 'foo' in having clause is ambiguous", + }, { + sql: "select id, sum(t1.foo) as XYZ from t1 having sum(XYZ) > 1", + expErr: "Invalid use of group function", + }, { + sql: "select foo + 2 as foo from t1 having foo = 42", + expSQL: "select foo + 2 as foo from t1 having t1.foo + 2 = 42", + expDeps: TS0, + }, { + sql: "select count(*), ename from emp group by ename having comm > 1000", + expErr: "Unknown column 'comm' in 'having clause'", + }, { + sql: "select sal, ename from emp having empno > 1000", + expSQL: "select sal, ename from emp having empno > 1000", + expDeps: TS0, + }, { + sql: "select foo, count(*) foo from t1 group by foo having foo > 1000", + expErr: "Column 'foo' in field list is ambiguous", + }, { + sql: "select foo, count(*) foo from t1, emp group by foo having sum(sal) > 1000", + expSQL: "select foo, count(*) as foo from t1, emp group by foo having sum(sal) > 1000", + expDeps: TS1, + warning: "Column 'foo' in group statement is ambiguous", + }, { + sql: "select foo as X, sal as foo from t1, emp having sum(X) > 1000", + expSQL: "select foo as X, sal as foo from t1, emp having sum(t1.foo) > 1000", + expDeps: TS0, + }, { + sql: "select count(*) a from someTable having a = 10", + expSQL: "select count(*) as a from someTable having count(*) = 10", + expDeps: TS0, + }, { + sql: "select count(*) from emp having ename = 10", + expSQL: "select count(*) from emp having ename = 10", + expDeps: TS0, + }, { + sql: "select sum(sal) empno from emp where ename > 0 having empno = 2", + expSQL: "select sum(sal) as empno from emp where ename > 0 having sum(emp.sal) = 2", + expDeps: TS0, }, { - sql: "select id, sum(foo) as foo from t1 having sum(foo) > 1", - expSQL: "select id, sum(foo) as foo from t1 having sum(foo) > 1", + // test with missing schema info + sql: "select foo, count(bar) as x from someTable group by foo having id > avg(baz)", + expErr: "Unknown column 'id' in 'having clause'", }, { - sql: "select foo + 2 as foo from t1 having foo = 42", - expSQL: "select foo + 2 as foo from t1 having foo + 2 = 42", + sql: "select t1.foo as alias, count(bar) as x from t1 group by foo having foo+54 = 56", + expSQL: "select t1.foo as alias, count(bar) as x from t1 group by foo having foo + 54 = 56", + expDeps: TS0, + }, { + sql: "select 1 from t1 group by foo having foo = 1 and count(*) > 1", + expSQL: "select 1 from t1 group by foo having foo = 1 and count(*) > 1", + expDeps: TS0, }} + for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.NewTestParser().Parse(tcase.sql) require.NoError(t, err) - selectStatement := ast.(sqlparser.SelectStatement) - _, err = Analyze(selectStatement, cDB, schemaInfo) + selectStatement := ast.(*sqlparser.Select) + semTbl, err := AnalyzeStrict(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + assert.Equal(t, tcase.expDeps, semTbl.RecursiveDeps(selectStatement.Having.Expr)) + assert.Equal(t, tcase.warning, semTbl.Warning, "warning") } else { require.EqualError(t, err, tcase.expErr) } @@ -466,7 +606,7 @@ func TestHavingColumnName(t *testing.T) { } } -func TestOrderByColumnName(t *testing.T) { +func getSchemaWithKnownColumns() *FakeSI { schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ "t1": { @@ -484,66 +624,114 @@ func TestOrderByColumnName(t *testing.T) { }}, ColumnListAuthoritative: true, }, + "emp": { + Keyspace: &vindexes.Keyspace{Name: "ks", Sharded: true}, + Name: sqlparser.NewIdentifierCS("emp"), + Columns: []vindexes.Column{{ + Name: sqlparser.NewIdentifierCI("empno"), + Type: sqltypes.Int64, + }, { + Name: sqlparser.NewIdentifierCI("ename"), + Type: sqltypes.VarChar, + }, { + Name: sqlparser.NewIdentifierCI("sal"), + Type: sqltypes.Int64, + }}, + ColumnListAuthoritative: true, + }, }, } + return schemaInfo +} + +func TestOrderByColumnName(t *testing.T) { + schemaInfo := getSchemaWithKnownColumns() cDB := "db" tcases := []struct { - sql string - expSQL string - expErr string + sql string + expSQL string + expErr string + warning string + deps TableSet }{{ sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo", - expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) asc", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(t1.foo) asc", + deps: TS0, }, { sql: "select id, sum(foo) as sumOfFoo from t1 order by sumOfFoo + 1", - expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(foo) + 1 asc", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by sum(t1.foo) + 1 asc", + deps: TS0, }, { sql: "select id, sum(foo) as sumOfFoo from t1 order by abs(sumOfFoo)", - expSQL: "select id, sum(foo) as sumOfFoo from t1 order by abs(sum(foo)) asc", + expSQL: "select id, sum(foo) as sumOfFoo from t1 order by abs(sum(t1.foo)) asc", + deps: TS0, }, { sql: "select id, sum(foo) as sumOfFoo from t1 order by max(sumOfFoo)", expErr: "Invalid use of group function", }, { - sql: "select id, sum(foo) as foo from t1 order by foo + 1", - expSQL: "select id, sum(foo) as foo from t1 order by foo + 1 asc", + sql: "select id, sum(foo) as foo from t1 order by foo + 1", + expSQL: "select id, sum(foo) as foo from t1 order by foo + 1 asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, sum(foo) as foo from t1 order by foo", - expSQL: "select id, sum(foo) as foo from t1 order by sum(foo) asc", + sql: "select id, sum(foo) as foo from t1 order by foo", + expSQL: "select id, sum(foo) as foo from t1 order by sum(t1.foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, lower(min(foo)) as foo from t1 order by min(foo)", - expSQL: "select id, lower(min(foo)) as foo from t1 order by min(foo) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by min(foo)", + expSQL: "select id, lower(min(foo)) as foo from t1 order by min(foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, lower(min(foo)) as foo from t1 order by foo", - expSQL: "select id, lower(min(foo)) as foo from t1 order by lower(min(foo)) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by foo", + expSQL: "select id, lower(min(foo)) as foo from t1 order by lower(min(t1.foo)) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, lower(min(foo)) as foo from t1 order by abs(foo)", - expSQL: "select id, lower(min(foo)) as foo from t1 order by abs(foo) asc", + sql: "select id, lower(min(foo)) as foo from t1 order by abs(foo)", + expSQL: "select id, lower(min(foo)) as foo from t1 order by abs(foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { - sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", - expSQL: "select id, t1.bar as foo from t1 group by id order by min(foo) asc", + sql: "select id, t1.bar as foo from t1 group by id order by min(foo)", + expSQL: "select id, t1.bar as foo from t1 group by id order by min(foo) asc", + deps: TS0, + warning: "Column 'foo' in order by statement is ambiguous", }, { sql: "select id, bar as id, count(*) from t1 order by id", expErr: "Column 'id' in field list is ambiguous", }, { - sql: "select id, id, count(*) from t1 order by id", - expSQL: "select id, id, count(*) from t1 order by id asc", + sql: "select id, id, count(*) from t1 order by id", + expSQL: "select id, id, count(*) from t1 order by t1.id asc", + deps: TS0, + warning: "Column 'id' in order by statement is ambiguous", }, { - sql: "select id, count(distinct foo) k from t1 group by id order by k", - expSQL: "select id, count(distinct foo) as k from t1 group by id order by count(distinct foo) asc", + sql: "select id, count(distinct foo) k from t1 group by id order by k", + expSQL: "select id, count(distinct foo) as k from t1 group by id order by count(distinct t1.foo) asc", + deps: TS0, + warning: "Column 'id' in group statement is ambiguous", }, { sql: "select user.id as foo from user union select col from user_extra order by foo", expSQL: "select `user`.id as foo from `user` union select col from user_extra order by foo asc", - }, - } + deps: MergeTableSets(TS0, TS1), + }, { + sql: "select foo as X, sal as foo from t1, emp order by sum(X)", + expSQL: "select foo as X, sal as foo from t1, emp order by sum(t1.foo) asc", + deps: TS0, + }} for _, tcase := range tcases { t.Run(tcase.sql, func(t *testing.T) { ast, err := sqlparser.NewTestParser().Parse(tcase.sql) require.NoError(t, err) selectStatement := ast.(sqlparser.SelectStatement) - _, err = Analyze(selectStatement, cDB, schemaInfo) + semTable, err := AnalyzeStrict(selectStatement, cDB, schemaInfo) if tcase.expErr == "" { require.NoError(t, err) assert.Equal(t, tcase.expSQL, sqlparser.String(selectStatement)) + orderByExpr := selectStatement.GetOrderBy()[0].Expr + assert.Equal(t, tcase.deps, semTable.RecursiveDeps(orderByExpr)) + assert.Equal(t, tcase.warning, semTable.Warning) } else { require.EqualError(t, err, tcase.expErr) } diff --git a/go/vt/vtgate/semantics/errors.go b/go/vt/vtgate/semantics/errors.go index a903408ba9d..297f2b9613e 100644 --- a/go/vt/vtgate/semantics/errors.go +++ b/go/vt/vtgate/semantics/errors.go @@ -52,7 +52,8 @@ type ( SubqueryColumnCountError struct{ Expected int } ColumnsMissingInSchemaError struct{} CantUseMultipleVindexHints struct{ Table string } - InvalidUserOfGroupFunction struct{} + InvalidUseOfGroupFunction struct{} + CantGroupOn struct{ Column string } NoSuchVindexFound struct { Table string @@ -71,6 +72,10 @@ type ( Column *sqlparser.ColName Table *sqlparser.TableName } + ColumnNotFoundClauseError struct { + Column string + Clause string + } ) func eprintf(e error, format string, args ...any) string { @@ -289,15 +294,41 @@ func (c *NoSuchVindexFound) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_FAILED_PRECONDITION } -// InvalidUserOfGroupFunction -func (*InvalidUserOfGroupFunction) Error() string { +// InvalidUseOfGroupFunction +func (*InvalidUseOfGroupFunction) Error() string { return "Invalid use of group function" } -func (*InvalidUserOfGroupFunction) ErrorCode() vtrpcpb.Code { +func (*InvalidUseOfGroupFunction) ErrorCode() vtrpcpb.Code { return vtrpcpb.Code_INVALID_ARGUMENT } -func (*InvalidUserOfGroupFunction) ErrorState() vterrors.State { +func (*InvalidUseOfGroupFunction) ErrorState() vterrors.State { return vterrors.InvalidGroupFuncUse } + +// CantGroupOn +func (e *CantGroupOn) Error() string { + return vterrors.VT03005(e.Column).Error() +} + +func (*CantGroupOn) ErrorCode() vtrpcpb.Code { + return vtrpcpb.Code_INVALID_ARGUMENT +} + +func (e *CantGroupOn) ErrorState() vterrors.State { + return vterrors.VT03005(e.Column).State +} + +// ColumnNotFoundInGroupByError +func (e *ColumnNotFoundClauseError) Error() string { + return fmt.Sprintf("Unknown column '%s' in '%s'", e.Column, e.Clause) +} + +func (*ColumnNotFoundClauseError) ErrorCode() vtrpcpb.Code { + return vtrpcpb.Code_INVALID_ARGUMENT +} + +func (e *ColumnNotFoundClauseError) ErrorState() vterrors.State { + return vterrors.BadFieldError +} diff --git a/go/vt/vtgate/semantics/scoper.go b/go/vt/vtgate/semantics/scoper.go index faef930b488..3a6fbe4c35c 100644 --- a/go/vt/vtgate/semantics/scoper.go +++ b/go/vt/vtgate/semantics/scoper.go @@ -40,13 +40,16 @@ type ( } scope struct { - parent *scope - stmt sqlparser.Statement - tables []TableInfo - isUnion bool - joinUsing map[string]TableSet - stmtScope bool - ctes map[string]*sqlparser.CommonTableExpr + parent *scope + stmt sqlparser.Statement + tables []TableInfo + isUnion bool + joinUsing map[string]TableSet + stmtScope bool + ctes map[string]*sqlparser.CommonTableExpr + inGroupBy bool + inHaving bool + inHavingAggr bool } ) @@ -76,9 +79,19 @@ func (s *scoper) down(cursor *sqlparser.Cursor) error { return s.addColumnInfoForOrderBy(cursor, node) case sqlparser.GroupBy: return s.addColumnInfoForGroupBy(cursor, node) + case sqlparser.AggrFunc: + if !s.currentScope().inHaving { + break + } + s.currentScope().inHavingAggr = true case *sqlparser.Where: if node.Type == sqlparser.HavingClause { - return s.createSpecialScopePostProjection(cursor.Parent()) + err := s.createSpecialScopePostProjection(cursor.Parent()) + if err != nil { + return err + } + s.currentScope().inHaving = true + return nil } } return nil @@ -97,10 +110,12 @@ func (s *scoper) addColumnInfoForGroupBy(cursor *sqlparser.Cursor, node sqlparse if err != nil { return err } + currentScope := s.currentScope() + currentScope.inGroupBy = true for _, expr := range node { lit := keepIntLiteral(expr) if lit != nil { - s.specialExprScopes[lit] = s.currentScope() + s.specialExprScopes[lit] = currentScope } } return nil @@ -210,6 +225,8 @@ func (s *scoper) up(cursor *sqlparser.Cursor) error { break } s.popScope() + case sqlparser.AggrFunc: + s.currentScope().inHavingAggr = false case sqlparser.TableExpr: if isParentSelect(cursor) { curScope := s.currentScope() diff --git a/go/vt/vtgate/semantics/vtable.go b/go/vt/vtgate/semantics/vtable.go index 133e38ff505..81f81de3813 100644 --- a/go/vt/vtgate/semantics/vtable.go +++ b/go/vt/vtgate/semantics/vtable.go @@ -42,10 +42,25 @@ func (v *vTableInfo) dependencies(colName string, org originable) (dependencies, if name != colName { continue } - directDeps, recursiveDeps, qt := org.depsForExpr(v.cols[i]) + deps = deps.merge(v.createCertainForCol(org, i), false) + } + if deps.empty() && v.hasStar() { + return createUncertain(v.tables, v.tables), nil + } + return deps, nil +} - newDeps := createCertain(directDeps, recursiveDeps, qt) - deps = deps.merge(newDeps, false) +func (v *vTableInfo) dependenciesInGroupBy(colName string, org originable) (dependencies, error) { + // this method is consciously very similar to vTableInfo.dependencies and should remain so + var deps dependencies = ¬hing{} + for i, name := range v.columnNames { + if name != colName { + continue + } + if sqlparser.ContainsAggregation(v.cols[i]) { + return nil, &CantGroupOn{name} + } + deps = deps.merge(v.createCertainForCol(org, i), false) } if deps.empty() && v.hasStar() { return createUncertain(v.tables, v.tables), nil @@ -53,6 +68,12 @@ func (v *vTableInfo) dependencies(colName string, org originable) (dependencies, return deps, nil } +func (v *vTableInfo) createCertainForCol(org originable, i int) *certain { + directDeps, recursiveDeps, qt := org.depsForExpr(v.cols[i]) + newDeps := createCertain(directDeps, recursiveDeps, qt) + return newDeps +} + // IsInfSchema implements the TableInfo interface func (v *vTableInfo) IsInfSchema() bool { return false