Skip to content

Commit

Permalink
Add catalog and schema level access checks in USE statement
Browse files Browse the repository at this point in the history
  • Loading branch information
annmegha committed Dec 19, 2024
1 parent 004ee32 commit bcf13f1
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,24 @@

import com.facebook.presto.Session;
import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.common.transaction.TransactionId;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.spi.security.AccessControlContext;
import com.facebook.presto.spi.security.AccessDeniedException;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.Use;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.ListenableFuture;

import java.util.List;

import static com.facebook.presto.metadata.MetadataUtil.getConnectorIdOrThrow;
import static com.facebook.presto.spi.security.AccessDeniedException.denyCatalogAccess;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.CATALOG_NOT_SPECIFIED;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.MISSING_SCHEMA;
import static com.google.common.util.concurrent.Futures.immediateFuture;
Expand Down Expand Up @@ -54,9 +60,9 @@ public ListenableFuture<?> execute(

checkCatalogAndSessionPresent(statement, session);

checkAndSetCatalog(statement, metadata, stateMachine, session);
checkAndSetCatalog(statement, metadata, stateMachine, session, accessControl);

checkAndSetSchema(statement, metadata, stateMachine, session);
checkAndSetSchema(statement, metadata, stateMachine, session, accessControl);

return immediateFuture(null);
}
Expand All @@ -68,16 +74,16 @@ private void checkCatalogAndSessionPresent(Use statement, Session session)
}
}

private void checkAndSetCatalog(Use statement, Metadata metadata, QueryStateMachine stateMachine, Session session)
private void checkAndSetCatalog(Use statement, Metadata metadata, QueryStateMachine stateMachine, Session session, AccessControl accessControl)
{
if (statement.getCatalog().isPresent()) {
String catalog = statement.getCatalog().get().getValueLowerCase();
getConnectorIdOrThrow(session, metadata, catalog);
stateMachine.setSetCatalog(catalog);
}
String catalog = statement.getCatalog()
.map(Identifier::getValueLowerCase)
.orElseGet(() -> session.getCatalog().map(String::toLowerCase).get());
getConnectorIdOrThrow(session, metadata, catalog);
stateMachine.setSetCatalog(catalog);
}

private void checkAndSetSchema(Use statement, Metadata metadata, QueryStateMachine stateMachine, Session session)
private void checkAndSetSchema(Use statement, Metadata metadata, QueryStateMachine stateMachine, Session session, AccessControl accessControl)
{
String catalog = statement.getCatalog()
.map(Identifier::getValueLowerCase)
Expand All @@ -86,6 +92,22 @@ private void checkAndSetSchema(Use statement, Metadata metadata, QueryStateMachi
if (!metadata.getMetadataResolver(session).schemaExists(new CatalogSchemaName(catalog, schema))) {
throw new SemanticException(MISSING_SCHEMA, format("Schema does not exist: %s.%s", catalog, schema));
}
if (!hasCatalogAccess(session.getIdentity(), session.getAccessControlContext(), catalog, accessControl)) {
denyCatalogAccess(catalog);
}
if (!hasSchemaAccess(session.getTransactionId().get(), session.getIdentity(), session.getAccessControlContext(), catalog, schema, accessControl)) {
throw new AccessDeniedException("Cannot access schema: " + new CatalogSchemaName(catalog, schema));
}
stateMachine.setSetSchema(schema);
}

private boolean hasCatalogAccess(Identity identity, AccessControlContext context, String catalog, AccessControl accessControl)
{
return !accessControl.filterCatalogs(identity, context, ImmutableSet.of(catalog)).isEmpty();
}

private boolean hasSchemaAccess(TransactionId transactionId, Identity identity, AccessControlContext context, String catalog, String schema, AccessControl accessControl)
{
return !accessControl.filterSchemas(transactionId, identity, context, catalog, ImmutableSet.of(schema)).isEmpty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.execution;

import com.facebook.presto.Session;
import com.facebook.presto.connector.MockConnectorFactory;
import com.facebook.presto.metadata.Catalog;
import com.facebook.presto.metadata.CatalogManager;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.connector.Connector;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.spi.security.AccessDeniedException;
import com.facebook.presto.spi.security.AllowAllAccessControl;
import com.facebook.presto.spi.security.DenyAllAccessControl;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.Use;
import com.facebook.presto.testing.TestingConnectorContext;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.concurrent.ExecutorService;

import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.execution.TaskTestUtils.createQueryStateMachine;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId;
import static com.facebook.presto.spi.ConnectorId.createSystemTablesConnectorId;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.facebook.presto.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static java.util.Collections.emptyList;
import static java.util.concurrent.Executors.newCachedThreadPool;

@Test(singleThreaded = true)
public class TestUseTask
{
private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-%s"));
private CatalogManager catalogManager;
private TransactionManager transactionManager;
private MetadataManager metadata = createTestMetadataManager();
MockConnectorFactory.Builder builder = MockConnectorFactory.builder();
MockConnectorFactory mockConnectorFactory = builder.withListSchemaNames(connectorSession -> ImmutableList.of("test_schema"))
.build();

@BeforeClass
public void setUp()
{
catalogManager = new CatalogManager();
transactionManager = createTestTransactionManager(catalogManager);
metadata = createTestMetadataManager(transactionManager);
}

@AfterClass(alwaysRun = true)
public void tearDown()
{
executor.shutdownNow();
}

@Test
public void testUse()
{
Use use = new Use(Optional.of(identifier("test_catalog")), identifier("test_schema"));
String sqlString = "USE test_catalog.test_schema";
executeUse(use, sqlString, TEST_SESSION);
}

@Test(
expectedExceptions = SemanticException.class,
expectedExceptionsMessageRegExp = "Catalog must be specified when session catalog is not set")
public void testUseNoCatalog()
{
Use use = new Use(Optional.empty(), identifier("test_schema"));
String sqlString = "USE test_schema";
Session session = testSessionBuilder()
.setCatalog(null)
.setSchema(null)
.build();
executeUse(use, sqlString, session);
}

@Test(
expectedExceptions = SemanticException.class,
expectedExceptionsMessageRegExp = "Catalog does not exist: invalid_catalog")
public void testUseInvalidCatalog()
{
Use use = new Use(Optional.of(identifier("invalid_catalog")), identifier("test_schema"));
String sqlString = "USE invalid_catalog.test_schema";
executeUse(use, sqlString, TEST_SESSION);
}

@Test(
expectedExceptions = SemanticException.class,
expectedExceptionsMessageRegExp = "Schema does not exist: test_catalog.invalid_schema")
public void testUseInvalidSchema()
{
Use use = new Use(Optional.of(identifier("test_catalog")), identifier("invalid_schema"));
String sqlString = "USE test_catalog.invalid_schema";
Session session = testSessionBuilder()
.setSchema("invalid_schema")
.build();
executeUse(use, sqlString, session);
}

@Test(
expectedExceptions = AccessDeniedException.class,
expectedExceptionsMessageRegExp = "Access Denied: Cannot access catalog test_catalog")
public void testUseAccessDenied()
{
Use use = new Use(Optional.of(identifier("test_catalog")), identifier("test_schema"));
String sqlString = "USE test_catalog.test_schema";
Session session = testSessionBuilder()
.setIdentity(new Identity("user", Optional.empty()))
.build();
AccessControl accessControl = new DenyAllAccessControl();
executeUse(use, sqlString, session, accessControl);
}

private void executeUse(Use use, String sqlString, Session session)
{
executeUse(use, sqlString, session, new AllowAllAccessControl());
}

private void executeUse(Use use, String sqlString, Session session, AccessControl accessControl)
{
catalogManager = new CatalogManager();
transactionManager = createTestTransactionManager(catalogManager);
metadata = createTestMetadataManager(transactionManager);

Connector testConnector = mockConnectorFactory.create("test", ImmutableMap.of(), new TestingConnectorContext());
String catalogName = "test_catalog";
ConnectorId connectorId = new ConnectorId(catalogName);
catalogManager.registerCatalog(new Catalog(
catalogName,
connectorId,
testConnector,
createInformationSchemaConnectorId(connectorId),
testConnector,
createSystemTablesConnectorId(connectorId),
testConnector));

QueryStateMachine stateMachine = createQueryStateMachine(sqlString, session, false, transactionManager, executor, metadata);
UseTask useTask = new UseTask();
useTask.execute(use, transactionManager, metadata, accessControl, stateMachine, emptyList());
}
private Identifier identifier(String name)
{
return new Identifier(name);
}
}

0 comments on commit bcf13f1

Please sign in to comment.