diff --git a/go/vt/vttablet/tabletserver/planbuilder/permission.go b/go/vt/vttablet/tabletserver/planbuilder/permission.go index 79b2f9eb430..57a35666122 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/permission.go +++ b/go/vt/vttablet/tabletserver/planbuilder/permission.go @@ -39,17 +39,17 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission { case *sqlparser.Union, *sqlparser.Select: permissions = buildSubqueryPermissions(node, tableacl.READER, permissions) case *sqlparser.Insert: - permissions = buildTableExprPermissions(node.Table, tableacl.WRITER, permissions) + permissions = buildTableExprPermissions(node.Table, tableacl.WRITER, nil, permissions) permissions = buildSubqueryPermissions(node, tableacl.READER, permissions) case *sqlparser.Update: - permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, permissions) + permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, nil, permissions) permissions = buildSubqueryPermissions(node, tableacl.READER, permissions) case *sqlparser.Delete: - permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, permissions) + permissions = buildTableExprsPermissions(node.TableExprs, tableacl.WRITER, nil, permissions) permissions = buildSubqueryPermissions(node, tableacl.READER, permissions) case sqlparser.DDLStatement: for _, t := range node.AffectedTables() { - permissions = buildTableNamePermissions(t, tableacl.ADMIN, permissions) + permissions = buildTableNamePermissions(t, tableacl.ADMIN, nil, permissions) } case *sqlparser.AlterMigration, @@ -60,10 +60,10 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission { permissions = []Permission{} // TODO(shlomi) what are the correct permissions here? Table is unknown case *sqlparser.Flush: for _, t := range node.TableNames { - permissions = buildTableNamePermissions(t, tableacl.ADMIN, permissions) + permissions = buildTableNamePermissions(t, tableacl.ADMIN, nil, permissions) } case *sqlparser.Analyze: - permissions = buildTableNamePermissions(node.Table, tableacl.WRITER, permissions) + permissions = buildTableNamePermissions(node.Table, tableacl.WRITER, nil, permissions) case *sqlparser.OtherAdmin, *sqlparser.CallProc, *sqlparser.Begin, *sqlparser.Commit, *sqlparser.Rollback, *sqlparser.Load, *sqlparser.Savepoint, *sqlparser.Release, *sqlparser.SRollback, *sqlparser.Set, *sqlparser.Show, sqlparser.Explain, *sqlparser.UnlockTables: @@ -75,43 +75,92 @@ func BuildPermissions(stmt sqlparser.Statement) []Permission { } func buildSubqueryPermissions(stmt sqlparser.Statement, role tableacl.Role, permissions []Permission) []Permission { - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { - if sel, ok := node.(*sqlparser.Select); ok { - permissions = buildTableExprsPermissions(sel.From, role, permissions) + var cteScopes [][]sqlparser.IdentifierCS + sqlparser.Rewrite(stmt, func(cursor *sqlparser.Cursor) bool { + switch node := cursor.Node().(type) { + case *sqlparser.Select: + if node.With != nil { + cteScopes = append(cteScopes, gatherCTEs(node.With)) + } + var ctes []sqlparser.IdentifierCS + for _, cteScope := range cteScopes { + ctes = append(ctes, cteScope...) + } + permissions = buildTableExprsPermissions(node.From, role, ctes, permissions) + case *sqlparser.Delete: + if node.With != nil { + cteScopes = append(cteScopes, gatherCTEs(node.With)) + } + case *sqlparser.Update: + if node.With != nil { + cteScopes = append(cteScopes, gatherCTEs(node.With)) + } + case *sqlparser.Union: + if node.With != nil { + cteScopes = append(cteScopes, gatherCTEs(node.With)) + } } - return true, nil - }, stmt) + return true + }, func(cursor *sqlparser.Cursor) bool { + // When we encounter a With expression coming up, we should remove + // the last value from the cte scopes to ensure we none of the outer + // elements of the query see this table name. + _, isWith := cursor.Node().(*sqlparser.With) + if isWith { + cteScopes = cteScopes[:len(cteScopes)-1] + } + return true + }) return permissions } -func buildTableExprsPermissions(node []sqlparser.TableExpr, role tableacl.Role, permissions []Permission) []Permission { +// gatherCTEs gathers the CTEs from the WITH clause. +func gatherCTEs(with *sqlparser.With) []sqlparser.IdentifierCS { + var ctes []sqlparser.IdentifierCS + for _, cte := range with.CTEs { + ctes = append(ctes, cte.ID) + } + return ctes +} + +func buildTableExprsPermissions(node []sqlparser.TableExpr, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission { for _, node := range node { - permissions = buildTableExprPermissions(node, role, permissions) + permissions = buildTableExprPermissions(node, role, ctes, permissions) } return permissions } -func buildTableExprPermissions(node sqlparser.TableExpr, role tableacl.Role, permissions []Permission) []Permission { +func buildTableExprPermissions(node sqlparser.TableExpr, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission { switch node := node.(type) { case *sqlparser.AliasedTableExpr: // An AliasedTableExpr can also be a derived table, but we should skip them here // because the buildSubQueryPermissions walker will catch them and extract // the corresponding table names. if tblName, ok := node.Expr.(sqlparser.TableName); ok { - permissions = buildTableNamePermissions(tblName, role, permissions) + permissions = buildTableNamePermissions(tblName, role, ctes, permissions) } case *sqlparser.ParenTableExpr: - permissions = buildTableExprsPermissions(node.Exprs, role, permissions) + permissions = buildTableExprsPermissions(node.Exprs, role, ctes, permissions) case *sqlparser.JoinTableExpr: - permissions = buildTableExprPermissions(node.LeftExpr, role, permissions) - permissions = buildTableExprPermissions(node.RightExpr, role, permissions) + permissions = buildTableExprPermissions(node.LeftExpr, role, ctes, permissions) + permissions = buildTableExprPermissions(node.RightExpr, role, ctes, permissions) } return permissions } -func buildTableNamePermissions(node sqlparser.TableName, role tableacl.Role, permissions []Permission) []Permission { +func buildTableNamePermissions(node sqlparser.TableName, role tableacl.Role, ctes []sqlparser.IdentifierCS, permissions []Permission) []Permission { + tableName := node.Name.String() + // Check whether this table is a cte or not. + // If the table name is qualified, then it cannot be a cte. + if node.Qualifier.IsEmpty() { + for _, cte := range ctes { + if cte.String() == tableName { + return permissions + } + } + } permissions = append(permissions, Permission{ - TableName: node.Name.String(), + TableName: tableName, Role: role, }) return permissions diff --git a/go/vt/vttablet/tabletserver/planbuilder/permission_test.go b/go/vt/vttablet/tabletserver/planbuilder/permission_test.go index 6d42118cb0b..7a793dadbc3 100644 --- a/go/vt/vttablet/tabletserver/planbuilder/permission_test.go +++ b/go/vt/vttablet/tabletserver/planbuilder/permission_test.go @@ -174,6 +174,45 @@ func TestBuildPermissions(t *testing.T) { }, { TableName: "t1", // derived table in update or delete needs reader permission as they cannot be modified. }}, + }, { + input: "with t as (select count(*) as a from user) select a from t", + output: []Permission{{ + TableName: "user", + Role: tableacl.READER, + }}, + }, { + input: "with d as (select id, count(*) as a from user) select d.a from music join d on music.user_id = d.id group by 1", + output: []Permission{{ + TableName: "music", + Role: tableacl.READER, + }, { + TableName: "user", + Role: tableacl.READER, + }}, + }, { + input: "WITH t1 AS ( SELECT id FROM t2 ) SELECT * FROM t1 JOIN ks.t1 AS t3", + output: []Permission{{ + TableName: "t1", + Role: tableacl.READER, + }, { + TableName: "t2", + Role: tableacl.READER, + }}, + }, { + input: "WITH RECURSIVE t1 (n) AS ( SELECT id from t2 UNION ALL SELECT n + 1 FROM t1 WHERE n < 5 ) SELECT * FROM t1 JOIN t1 AS t3", + output: []Permission{{ + TableName: "t2", + Role: tableacl.READER, + }}, + }, { + input: "(with t1 as (select count(*) as a from user) select a from t1) union select * from t1", + output: []Permission{{ + TableName: "user", + Role: tableacl.READER, + }, { + TableName: "t1", + Role: tableacl.READER, + }}, }} for _, tcase := range tcases {