diff --git a/presto-main/src/main/java/com/facebook/presto/execution/UseTask.java b/presto-main/src/main/java/com/facebook/presto/execution/UseTask.java index 03a582bcec592..e7ba2b84ce731 100644 --- a/presto-main/src/main/java/com/facebook/presto/execution/UseTask.java +++ b/presto-main/src/main/java/com/facebook/presto/execution/UseTask.java @@ -15,18 +15,24 @@ import com.facebook.presto.Session; import com.facebook.presto.common.CatalogSchemaName; +import com.facebook.presto.common.transaction.TransactionId; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessControlContext; +import com.facebook.presto.spi.security.AccessDeniedException; +import com.facebook.presto.spi.security.Identity; import com.facebook.presto.sql.analyzer.SemanticException; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.Identifier; import com.facebook.presto.sql.tree.Use; import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListenableFuture; import java.util.List; import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow; +import static com.facebook.presto.spi.security.AccessDeniedException.denyCatalogAccess; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED; import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA; import static com.google.common.util.concurrent.Futures.immediateFuture; @@ -35,6 +41,8 @@ public class UseTask implements SessionTransactionControlTask { + String catalog; + String schema; @Override public String getName() { @@ -52,12 +60,28 @@ public ListenableFuture execute( { Session session = stateMachine.getSession(); + TransactionId transactionId = session.getTransactionId().get(); + + Identity identity = session.getIdentity(); + + AccessControlContext context = session.getAccessControlContext(); + checkCatalogAndSessionPresent(statement, session); checkAndSetCatalog(statement, metadata, stateMachine, session); checkAndSetSchema(statement, metadata, stateMachine, session); + if (!hasCatalogAccess(identity, context, catalog, accessControl)) { + denyCatalogAccess(catalog); + } + + CatalogSchemaName name = new CatalogSchemaName(catalog, schema); + + if (!hasSchemaAccess(transactionId, identity, context, catalog, schema, accessControl)) { + throw new AccessDeniedException("Cannot access schema: " + name); + } + return immediateFuture(null); } @@ -79,13 +103,23 @@ private void checkAndSetCatalog(Use statement, Metadata metadata, QueryStateMach private void checkAndSetSchema(Use statement, Metadata metadata, QueryStateMachine stateMachine, Session session) { - String catalog = statement.getCatalog() + catalog = statement.getCatalog() .map(Identifier::getValueLowerCase) .orElseGet(() -> session.getCatalog().map(String::toLowerCase).get()); - String schema = statement.getSchema().getValueLowerCase(); + schema = statement.getSchema().getValueLowerCase(); if (!metadata.getMetadataResolver(session).schemaExists(new CatalogSchemaName(catalog, schema))) { throw new SemanticException(MISSING_SCHEMA, format("Schema does not exist: %s.%s", catalog, schema)); } stateMachine.setSetSchema(schema); } + + private boolean hasCatalogAccess(Identity identity, AccessControlContext context, String catalog, AccessControl accessControl) + { + return !accessControl.filterCatalogs(identity, context, ImmutableSet.of(catalog)).isEmpty(); + } + + private boolean hasSchemaAccess(TransactionId transactionId, Identity identity, AccessControlContext context, String catalog, String schema, AccessControl accessControl) + { + return !accessControl.filterSchemas(transactionId, identity, context, catalog, ImmutableSet.of(schema)).isEmpty(); + } } diff --git a/presto-main/src/test/java/com/facebook/presto/execution/TestUseTask.java b/presto-main/src/test/java/com/facebook/presto/execution/TestUseTask.java new file mode 100644 index 0000000000000..99cbe23240d68 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/execution/TestUseTask.java @@ -0,0 +1,184 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.execution; + +import com.facebook.presto.Session; +import com.facebook.presto.connector.MockConnectorFactory; +import com.facebook.presto.metadata.Catalog; +import com.facebook.presto.metadata.CatalogManager; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.security.AccessControl; +import com.facebook.presto.spi.security.AccessDeniedException; +import com.facebook.presto.spi.security.AllowAllAccessControl; +import com.facebook.presto.spi.security.DenyAllAccessControl; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.sql.analyzer.SemanticException; +import com.facebook.presto.sql.tree.Identifier; +import com.facebook.presto.sql.tree.Use; +import com.facebook.presto.testing.TestingConnectorContext; +import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed; +import static com.facebook.presto.SessionTestUtils.TEST_SESSION; +import static com.facebook.presto.execution.TaskTestUtils.createQueryStateMachine; +import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId; +import static com.facebook.presto.spi.ConnectorId.createSystemTablesConnectorId; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_CATALOG; +import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager; +import static java.util.Collections.emptyList; +import static java.util.concurrent.Executors.newCachedThreadPool; +import static org.testng.Assert.assertEquals; + +@Test(singleThreaded = true) +public class TestUseTask +{ + private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-%s")); + private CatalogManager catalogManager; + private TransactionManager transactionManager; + private MetadataManager metadata = createTestMetadataManager(); + MockConnectorFactory.Builder builder = MockConnectorFactory.builder(); + MockConnectorFactory mockConnectorFactory = builder.withListSchemaNames(connectorSession -> ImmutableList.of("test_schema")) + .build(); + @AfterClass(alwaysRun = true) + public void tearDown() + { + executor.shutdownNow(); + } + + @Test + public void testUse() + { + Use use = new Use(Optional.of(identifier("test_catalog")), identifier("test_schema")); + String sqlString = "USE test_catalog.test_schema"; + executeUse(use, sqlString, TEST_SESSION); + } + + @Test + public void testUseNoCatalog() + { + Use use = new Use(Optional.empty(), identifier("test_schema")); + String sqlString = "USE test_schema"; + Session session = testSessionBuilder() + .setCatalog(null) + .setSchema(null) + .build(); + try { + executeUse(use, sqlString, session); + } + catch (SemanticException e) { + assertEquals(e.getMessage(), "Catalog must be specified when session catalog is not set"); + } + } + + @Test + public void testUseInvalidCatalog() + { + Use use = new Use(Optional.of(identifier("invalid_catalog")), identifier("test_schema")); + String sqlString = "USE invalid_catalog.test_schema"; + try { + executeUse(use, sqlString, TEST_SESSION); + } + catch (SemanticException e) { + assertEquals(e.getCode(), MISSING_CATALOG); + assertEquals(e.getMessage(), "Catalog does not exist: invalid_catalog"); + } + } + + @Test + public void testUseInvalidSchema() + { + Use use = new Use(Optional.of(identifier("test_catalog")), identifier("invalid_schema")); + String sqlString = "USE test_catalog.invalid_schema"; + Session session = testSessionBuilder() + .setSchema("invalid_schema") + .build(); + try { + executeUse(use, sqlString, session); + } + catch (SemanticException e) { + assertEquals(e.getCode(), MISSING_SCHEMA); + assertEquals(e.getMessage(), "Schema does not exist: test_catalog.invalid_schema"); + } + } + + @Test + public void testUseAccessDenied() + { + Use use = new Use(Optional.of(identifier("test_catalog")), identifier("test_schema")); + String sqlString = "USE test_catalog.test_schema"; + Session session = testSessionBuilder() + .setIdentity(new Identity("user", Optional.empty())) + .build(); + AccessControl accessControl = new DenyAllAccessControl(); + try { + executeUse(use, sqlString, session); + } + catch (AccessDeniedException e) { + assertEquals(e.getMessage(), "Cannot access schema: test_catalog.test_schema"); + } + } + + private void executeUse(Use use, String sqlString, Session session) + { + executeUse(use, sqlString, session, new AllowAllAccessControl()); + } + + @BeforeMethod + public void setUp() + { + catalogManager = new CatalogManager(); + transactionManager = createTestTransactionManager(catalogManager); + metadata = createTestMetadataManager(transactionManager); + } + + private void executeUse(Use use, String sqlString, Session session, AccessControl accessControl) + { + catalogManager = new CatalogManager(); + transactionManager = createTestTransactionManager(catalogManager); + metadata = createTestMetadataManager(transactionManager); + + Connector testConnector = mockConnectorFactory.create("test", ImmutableMap.of(), new TestingConnectorContext()); + String catalogName = "test_catalog"; + ConnectorId connectorId = new ConnectorId(catalogName); + catalogManager.registerCatalog(new Catalog( + catalogName, + connectorId, + testConnector, + createInformationSchemaConnectorId(connectorId), + testConnector, + createSystemTablesConnectorId(connectorId), + testConnector)); + + QueryStateMachine stateMachine = createQueryStateMachine(sqlString, session, false, transactionManager, executor, metadata); + UseTask useTask = new UseTask(); + useTask.execute(use, transactionManager, metadata, accessControl, stateMachine, emptyList()); + } + private Identifier identifier(String name) + { + return new Identifier(name); + } +}