From 28c5aa4eb21c076d544687300a8b110e8aa48604 Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Fri, 18 Oct 2024 00:46:50 +0530 Subject: [PATCH] Add additional headers to the client request. Add ClientRequestFilter.java interface in Presto-spi. Improve Request Headers in the Authentication Filter Class. --- .../sphinx/develop/client-request-filter.rst | 22 ++ presto-main/pom.xml | 5 + .../presto/ClientRequestFilterManager.java | 72 +++++++ .../presto/ClientRequestFilterModule.java | 28 +++ .../facebook/presto/server/PluginManager.java | 10 + .../facebook/presto/server/PrestoServer.java | 4 + .../server/security/AuthenticationFilter.java | 100 ++++++++- .../server/testing/TestingPrestoServer.java | 13 ++ .../presto/testing/LocalQueryRunner.java | 2 + .../presto/TestClientRequestFilterPlugin.java | 196 ++++++++++++++++++ .../presto/server/MockHttpServletRequest.java | 8 + .../presto/spark/PrestoSparkModule.java | 3 +- .../presto/spi/ClientRequestFilter.java | 25 +++ .../spi/ClientRequestFilterFactory.java | 21 ++ .../java/com/facebook/presto/spi/Plugin.java | 5 + .../presto/spi/StandardErrorCode.java | 1 + 16 files changed, 511 insertions(+), 4 deletions(-) create mode 100644 presto-docs/src/main/sphinx/develop/client-request-filter.rst create mode 100644 presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java create mode 100644 presto-main/src/main/java/com/facebook/presto/ClientRequestFilterModule.java create mode 100644 presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilter.java create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilterFactory.java diff --git a/presto-docs/src/main/sphinx/develop/client-request-filter.rst b/presto-docs/src/main/sphinx/develop/client-request-filter.rst new file mode 100644 index 0000000000000..12a5a5ecb3088 --- /dev/null +++ b/presto-docs/src/main/sphinx/develop/client-request-filter.rst @@ -0,0 +1,22 @@ + +====================== +Client Request Filter +====================== + +Presto allows operators to customize the headers used by the Presto runtime to process queries. Some example use cases include customized authentication workflows, or enriching query attributes such as the query source. This can be achieved using the Client Request Filter plugin, which allows control over header customization during query execution. + +Implementation +-------------- + +The ``ClientRequestFilterFactory`` is responsible for creating instances of ``ClientRequestFilter``. It also defines +the name of the filter. + +The ``ClientRequestFilter`` interface provides two methods: ``getExtraHeaders()``, which allows the runtime to quickly check if it needs to apply a more expensive call to enrich the headers, and ``getHeaderNames()``, which returns a list of header names used as the header names in client requests. + +The implementation of ``ClientRequestFilterFactory`` must be wrapped as a plugin and installed on the Presto cluster. + +After installing a plugin that implements ``ClientRequestFilterFactory`` on the coordinator, the ``AuthenticationFilter`` class passes the ``principal`` object to the request filter, which returns the header values as a map. + +Presto uses the request filter to determine whether a header is present in the blocklist. The blocklist includes headers such as ``X-Presto-Transaction-Id``, ``X-Presto-Started-Transaction-Id``, ``X-Presto-Clear-Transaction-Id``, and ``X-Presto-Trace-Token``, which are not allowed to be overridden. For a complete list of headers that may be overridden, please refer to the `Java source`_ (note that the blocklist headers are also listed there, but they are not eligible for overriding). + +.. _Java source: https://github.com/prestodb/presto/blob/master/presto-client/src/main/java/com/facebook/presto/client/PrestoHeaders.java diff --git a/presto-main/pom.xml b/presto-main/pom.xml index f6ab94cfd3330..a8994923ca5fc 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -503,6 +503,11 @@ ratis-common true + + com.squareup.okhttp3 + mockwebserver + test + diff --git a/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java new file mode 100644 index 0000000000000..212193a2d08aa --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto; + +import com.facebook.presto.spi.ClientRequestFilter; +import com.facebook.presto.spi.ClientRequestFilterFactory; +import com.facebook.presto.spi.PrestoException; +import com.google.common.collect.ImmutableList; + +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS; +import static com.facebook.presto.spi.StandardErrorCode.CONFIGURATION_INVALID; +import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS; +import static com.google.common.collect.ImmutableList.toImmutableList; + +@ThreadSafe +public class ClientRequestFilterManager +{ + private Map factories = new ConcurrentHashMap<>(); + + @GuardedBy("this") + private volatile List filters = ImmutableList.of(); + private final AtomicBoolean loaded = new AtomicBoolean(); + + public void registerClientRequestFilterFactory(ClientRequestFilterFactory factory) + { + if (loaded.get()) { + throw new PrestoException(INVALID_ARGUMENTS, "Cannot register factories after filters are loaded."); + } + + String name = factory.getName(); + if (factories.putIfAbsent(name, factory) != null) { + throw new PrestoException(ALREADY_EXISTS, "A factory with the name '" + name + "' is already registered."); + } + } + + public void loadClientRequestFilters() + { + if (!loaded.compareAndSet(false, true)) { + throw new PrestoException(CONFIGURATION_INVALID, "loadClientRequestFilters can only be called once."); + } + + filters = factories.values().stream() + .map(factory -> factory.create(factory.getName())) + .collect(toImmutableList()); + factories = null; + } + + public List getClientRequestFilters() + { + return filters; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterModule.java b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterModule.java new file mode 100644 index 0000000000000..4beaa15db8e6f --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterModule.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +public class ClientRequestFilterModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(ClientRequestFilterManager.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java index f1cef40a1a3f6..97d86ee3322ea 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.common.block.BlockEncoding; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.type.ParametricType; @@ -27,6 +28,7 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControlManager; import com.facebook.presto.server.security.PasswordAuthenticatorManager; +import com.facebook.presto.spi.ClientRequestFilterFactory; import com.facebook.presto.spi.CoordinatorPlugin; import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.analyzer.AnalyzerProvider; @@ -134,6 +136,7 @@ public class PluginManager private final AnalyzerProviderManager analyzerProviderManager; private final QueryPreparerProviderManager queryPreparerProviderManager; private final NodeStatusNotificationManager nodeStatusNotificationManager; + private final ClientRequestFilterManager clientRequestFilterManager; private final PlanCheckerProviderManager planCheckerProviderManager; @Inject @@ -157,6 +160,7 @@ public PluginManager( HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, TracerProviderManager tracerProviderManager, NodeStatusNotificationManager nodeStatusNotificationManager, + ClientRequestFilterManager clientRequestFilterManager, PlanCheckerProviderManager planCheckerProviderManager) { requireNonNull(nodeInfo, "nodeInfo is null"); @@ -189,6 +193,7 @@ public PluginManager( this.analyzerProviderManager = requireNonNull(analyzerProviderManager, "analyzerProviderManager is null"); this.queryPreparerProviderManager = requireNonNull(queryPreparerProviderManager, "queryPreparerProviderManager is null"); this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null"); + this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null"); this.planCheckerProviderManager = requireNonNull(planCheckerProviderManager, "planCheckerProviderManager is null"); } @@ -354,6 +359,11 @@ public void installPlugin(Plugin plugin) log.info("Registering node status notification provider %s", nodeStatusNotificationProviderFactory.getName()); nodeStatusNotificationManager.addNodeStatusNotificationProviderFactory(nodeStatusNotificationProviderFactory); } + + for (ClientRequestFilterFactory clientRequestFilterFactory : plugin.getClientRequestFilterFactories()) { + log.info("Registering client request filter factory"); + clientRequestFilterManager.registerClientRequestFilterFactory(clientRequestFilterFactory); + } } public void installCoordinatorPlugin(CoordinatorPlugin plugin) diff --git a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java index 1c0ce0dff9bbc..56a4e0aa7e77b 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java @@ -31,6 +31,8 @@ import com.facebook.airlift.tracetoken.TraceTokenModule; import com.facebook.drift.server.DriftServer; import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; +import com.facebook.presto.ClientRequestFilterManager; +import com.facebook.presto.ClientRequestFilterModule; import com.facebook.presto.dispatcher.QueryPrerequisitesManager; import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule; import com.facebook.presto.eventlistener.EventListenerManager; @@ -115,6 +117,7 @@ public void run() new NodeModule(), new DiscoveryModule(), new HttpServerModule(), + new ClientRequestFilterModule(), new JsonModule(), installModuleIf( FeaturesConfig.class, @@ -188,6 +191,7 @@ public void run() PluginNodeManager pluginNodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment()); planCheckerProviderManager.loadPlanCheckerProviders(pluginNodeManager); + injector.getInstance(ClientRequestFilterManager.class).loadClientRequestFilters(); startAssociatedProcesses(injector); injector.getInstance(Announcer.class).start(); diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 96866a51594bf..6914f8d7b03f5 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -15,9 +15,14 @@ import com.facebook.airlift.http.server.AuthenticationException; import com.facebook.airlift.http.server.Authenticator; +import com.facebook.presto.ClientRequestFilterManager; +import com.facebook.presto.spi.ClientRequestFilter; +import com.facebook.presto.spi.PrestoException; import com.google.common.base.Joiner; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.net.HttpHeaders; import javax.inject.Inject; @@ -35,14 +40,20 @@ import java.io.InputStream; import java.io.PrintWriter; import java.security.Principal; +import java.util.Enumeration; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Set; +import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT; import static com.google.common.io.ByteStreams.copy; import static com.google.common.io.ByteStreams.nullOutputStream; import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8; +import static java.util.Collections.enumeration; +import static java.util.Collections.list; import static java.util.Objects.requireNonNull; import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; @@ -52,12 +63,15 @@ public class AuthenticationFilter private static final String HTTPS_PROTOCOL = "https"; private final List authenticators; private final boolean allowForwardedHttps; + private final ClientRequestFilterManager clientRequestFilterManager; + private final List headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token"); @Inject - public AuthenticationFilter(List authenticators, SecurityConfig securityConfig) + public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager) { this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null")); this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps(); + this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null"); } @Override @@ -95,9 +109,9 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo e.getAuthenticateHeader().ifPresent(authenticateHeaders::add); continue; } - // authentication succeeded - nextFilter.doFilter(withPrincipal(request, principal), response); + HttpServletRequest wrappedRequest = mergeExtraHeaders(request, principal); + nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response); return; } @@ -126,6 +140,47 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo } } + public HttpServletRequest mergeExtraHeaders(HttpServletRequest request, Principal principal) + { + List clientRequestFilters = clientRequestFilterManager.getClientRequestFilters(); + + if (clientRequestFilters.isEmpty()) { + return request; + } + + ImmutableMap.Builder extraHeadersMapBuilder = ImmutableMap.builder(); + Set addedHeaders = new HashSet<>(); + + for (ClientRequestFilter requestFilter : clientRequestFilters) { + boolean headersPresent = requestFilter.getExtraHeaderKeys().stream() + .allMatch(headerName -> request.getHeader(headerName) != null); + + if (!headersPresent) { + Map extraHeaderValueMap = requestFilter.getExtraHeaders(principal); + + if (!extraHeaderValueMap.isEmpty()) { + for (Map.Entry extraHeaderEntry : extraHeaderValueMap.entrySet()) { + String headerKey = extraHeaderEntry.getKey(); + if (headersBlockList.contains(headerKey)) { + throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, + "Modification attempt detected: The header " + headerKey + " is not allowed to be modified. The following headers cannot be modified: " + + String.join(", ", headersBlockList)); + } + if (addedHeaders.contains(headerKey)) { + throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Header conflict detected: " + headerKey + " already added by another filter."); + } + if (request.getHeader(headerKey) == null && requestFilter.getExtraHeaderKeys().contains(headerKey)) { + extraHeadersMapBuilder.put(headerKey, extraHeaderEntry.getValue()); + addedHeaders.add(headerKey); + } + } + } + } + } + + return new ModifiedHttpServletRequest(request, extraHeadersMapBuilder.build()); + } + private boolean doesRequestSupportAuthentication(HttpServletRequest request) { if (authenticators.isEmpty()) { @@ -166,4 +221,43 @@ private static void skipRequestBody(HttpServletRequest request) copy(inputStream, nullOutputStream()); } } + + public static class ModifiedHttpServletRequest + extends HttpServletRequestWrapper + { + private final Map customHeaders; + + public ModifiedHttpServletRequest(HttpServletRequest request, Map headers) + { + super(request); + this.customHeaders = ImmutableMap.copyOf(requireNonNull(headers, "headers is null")); + } + + @Override + public String getHeader(String name) + { + if (customHeaders.containsKey(name)) { + return customHeaders.get(name); + } + return super.getHeader(name); + } + + @Override + public Enumeration getHeaderNames() + { + return enumeration(ImmutableSet.builder() + .addAll(customHeaders.keySet()) + .addAll(list(super.getHeaderNames())) + .build()); + } + + @Override + public Enumeration getHeaders(String name) + { + if (customHeaders.containsKey(name)) { + return enumeration(ImmutableList.of(customHeaders.get(name))); + } + return super.getHeaders(name); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 034763c47212c..4360533f9db9f 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -32,6 +32,8 @@ import com.facebook.airlift.tracetoken.TraceTokenModule; import com.facebook.drift.server.DriftServer; import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; +import com.facebook.presto.ClientRequestFilterManager; +import com.facebook.presto.ClientRequestFilterModule; import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.dispatcher.DispatchManager; @@ -60,6 +62,7 @@ import com.facebook.presto.server.ServerMainModule; import com.facebook.presto.server.ShutdownAction; import com.facebook.presto.server.security.ServerSecurityModule; +import com.facebook.presto.spi.ClientRequestFilterFactory; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.CoordinatorPlugin; import com.facebook.presto.spi.NodeManager; @@ -178,6 +181,7 @@ public class TestingPrestoServer private final ResourceManagerClusterStateProvider clusterStateProvider; private final PlanCheckerProviderManager planCheckerProviderManager; private final NodeManager pluginNodeManager; + private final ClientRequestFilterManager clientRequestFilterManager; public static class TestShutdownAction implements ShutdownAction @@ -311,6 +315,7 @@ public TestingPrestoServer( .add(new QueryPrerequisitesManagerModule()) .add(new NodeTtlFetcherManagerModule()) .add(new ClusterTtlProviderManagerModule()) + .add(new ClientRequestFilterModule()) .add(binder -> { binder.bind(TestingAccessControlManager.class).in(Scopes.SINGLETON); binder.bind(TestingEventListenerManager.class).in(Scopes.SINGLETON); @@ -444,6 +449,7 @@ else if (catalogServer) { requestBlocker = injector.getInstance(RequestBlocker.class); serverInfoResource = injector.getInstance(ServerInfoResource.class); pluginNodeManager = injector.getInstance(PluginNodeManager.class); + clientRequestFilterManager = injector.getInstance(ClientRequestFilterManager.class); // Announce Thrift server address DriftServer driftServer = injector.getInstance(DriftServer.class); @@ -861,4 +867,11 @@ private static int driftServerPort(DriftServer server) { return ((DriftNettyServerTransport) server.getServerTransport()).getPort(); } + + public ClientRequestFilterManager getClientRequestFilterManager(List requestFilterFactory) + { + requestFilterFactory.forEach(clientRequestFilterManager::registerClientRequestFilterFactory); + clientRequestFilterManager.loadClientRequestFilters(); + return clientRequestFilterManager; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index edf83a139c3a2..cc60682a5fabe 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -14,6 +14,7 @@ package com.facebook.presto.testing; import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.GroupByHashPageIndexerFactory; import com.facebook.presto.PagesIndexPageSorter; import com.facebook.presto.Session; @@ -529,6 +530,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, historyBasedPlanStatisticsManager, new TracerProviderManager(new TracingConfig()), new NodeStatusNotificationManager(), + new ClientRequestFilterManager(), planCheckerProviderManager); connectorManager.addConnectorFactory(globalSystemConnectorFactory); diff --git a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java new file mode 100644 index 0000000000000..dc1daebdee092 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto; + +import com.facebook.airlift.http.server.Authenticator; +import com.facebook.presto.server.MockHttpServletRequest; +import com.facebook.presto.server.security.AuthenticationFilter; +import com.facebook.presto.server.security.SecurityConfig; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.ClientRequestFilter; +import com.facebook.presto.spi.ClientRequestFilterFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.testng.annotations.Test; + +import javax.servlet.http.HttpServletRequest; + +import java.security.Principal; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.testng.Assert.assertEquals; + +public class TestClientRequestFilterPlugin +{ + @Test + public void testCustomRequestFilterWithHeaders() throws Exception + { + MockHttpServletRequest request = new MockHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue")); + List requestFilterFactory = getClientRequestFilterFactory(); + AuthenticationFilter filter = setupAuthenticationFilter(requestFilterFactory); + PrincipalStub testPrincipal = new PrincipalStub(); + + HttpServletRequest wrappedRequest = filter.mergeExtraHeaders(request, testPrincipal); + + assertEquals("CustomValue", wrappedRequest.getHeader("X-Custom-Header")); + assertEquals("ExpectedExtraValue", wrappedRequest.getHeader("ExpectedExtraHeader")); + } + + @Test( + expectedExceptions = RuntimeException.class, + expectedExceptionsMessageRegExp = "Modification attempt detected: The header X-Presto-Transaction-Id is not allowed to be modified. The following headers cannot be modified: " + + "X-Presto-Transaction-Id, X-Presto-Started-Transaction-Id, X-Presto-Clear-Transaction-Id, X-Presto-Trace-Token") + public void testCustomRequestFilterWithHeadersInBlockList() throws Exception + { + MockHttpServletRequest request = new MockHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue")); + List requestFilterFactory = getClientRequestFilterInBlockList(); + AuthenticationFilter filter = setupAuthenticationFilter(requestFilterFactory); + PrincipalStub testPrincipal = new PrincipalStub(); + + filter.mergeExtraHeaders(request, testPrincipal); + } + + @Test( + expectedExceptions = RuntimeException.class, + expectedExceptionsMessageRegExp = "Header conflict detected: ExpectedExtraValue already added by another filter.") + public void testCustomRequestFilterHandlesConflict() throws Exception + { + MockHttpServletRequest request = new MockHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue")); + List requestFilterFactory = getClientRequestFilterFactoryHandlesConflict(); + AuthenticationFilter filter = setupAuthenticationFilter(requestFilterFactory); + PrincipalStub testPrincipal = new PrincipalStub(); + + filter.mergeExtraHeaders(request, testPrincipal); + } + + private List getClientRequestFilterFactory() + { + return createFilterFactories( + new String[][] { + {"CustomModifier", "ExpectedExtraHeader", "ExpectedExtraValue"} + }); + } + + private List getClientRequestFilterInBlockList() + { + return createFilterFactories( + new String[][] { + {"BlockListModifier", "X-Presto-Transaction-Id", "CustomValue"} + }); + } + + private List getClientRequestFilterFactoryHandlesConflict() + { + return createFilterFactories( + new String[][] { + {"Filter1", "ExpectedExtraValue", "ExpectedExtraHeader_1"}, + {"Filter2", "ExpectedExtraValue", "ExpectedExtraHeader_2"} + }); + } + + private AuthenticationFilter setupAuthenticationFilter(List requestFilterFactory) throws Exception + { + try (TestingPrestoServer testingPrestoServer = new TestingPrestoServer()) { + ClientRequestFilterManager clientRequestFilterManager = testingPrestoServer.getClientRequestFilterManager(requestFilterFactory); + + List authenticators = createAuthenticators(); + SecurityConfig securityConfig = createSecurityConfig(); + + return new AuthenticationFilter(authenticators, securityConfig, clientRequestFilterManager); + } + } + + private List createFilterFactories(String[][] filterConfigs) + { + ImmutableList.Builder factories = ImmutableList.builder(); + for (String[] config : filterConfigs) { + factories.add(new GenericClientRequestFilterFactory(config[0], config[1], config[2])); + } + return factories.build(); + } + + private List createAuthenticators() + { + return Collections.emptyList(); + } + + private SecurityConfig createSecurityConfig() + { + return new SecurityConfig() { + @Override + public boolean getAllowForwardedHttps() + { + return true; + } + }; + } + + static class GenericClientRequestFilterFactory + implements ClientRequestFilterFactory + { + private final String name; + private final String headerName; + private final String headerValue; + + public GenericClientRequestFilterFactory(String name, String headerName, String headerValue) + { + this.name = name; + this.headerName = headerName; + this.headerValue = headerValue; + } + + @Override + public String getName() + { + return name; + } + + @Override + public ClientRequestFilter create(String filterName) + { + return new CustomClientRequestFilter(); + } + + private class CustomClientRequestFilter + implements ClientRequestFilter + { + @Override + public Set getExtraHeaderKeys() + { + return ImmutableSet.of(headerName); + } + + @Override + public Map getExtraHeaders(Principal principal) + { + return ImmutableMap.of(headerName, headerValue); + } + } + } + + static class PrincipalStub + implements Principal + { + @Override + public String getName() + { + return "TestPrincipal"; + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java b/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java index d43b34dd4eef2..000eda9b853e0 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java +++ b/presto-main/src/test/java/com/facebook/presto/server/MockHttpServletRequest.java @@ -14,6 +14,7 @@ package com.facebook.presto.server; import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ListMultimap; import javax.servlet.AsyncContext; @@ -45,6 +46,7 @@ public class MockHttpServletRequest implements HttpServletRequest { + private static final String DEFAULT_ADDRESS = "127.0.0.1"; private final ListMultimap headers; private final String remoteAddress; private final Map attributes; @@ -56,6 +58,12 @@ public MockHttpServletRequest(ListMultimap headers, String remot this.attributes = new HashMap<>(requireNonNull(attributes, "attributes is null")); } + public MockHttpServletRequest(ListMultimap headers) + { + // Default remoteAddress and empty attributes + this(headers, DEFAULT_ADDRESS, ImmutableMap.of()); + } + @Override public String getAuthType() { diff --git a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java index ed017e654cceb..9f69fc7c19256 100644 --- a/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java +++ b/presto-spark-base/src/main/java/com/facebook/presto/spark/PrestoSparkModule.java @@ -19,6 +19,7 @@ import com.facebook.airlift.json.smile.SmileCodec; import com.facebook.airlift.node.NodeConfig; import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.GroupByHashPageIndexerFactory; import com.facebook.presto.PagesIndexPageSorter; import com.facebook.presto.SystemSessionProperties; @@ -552,7 +553,7 @@ protected void setup(Binder binder) // extra credentials and authenticator for Presto-on-Spark newSetBinder(binder, PrestoSparkCredentialsProvider.class); newSetBinder(binder, PrestoSparkAuthenticatorProvider.class); - + binder.bind(ClientRequestFilterManager.class).in(Scopes.SINGLETON); binder.bind(PlanCheckerProviderManager.class).in(Scopes.SINGLETON); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilter.java b/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilter.java new file mode 100644 index 0000000000000..dc7d8f0b146cf --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilter.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +import java.security.Principal; +import java.util.Map; +import java.util.Set; + +public interface ClientRequestFilter +{ + Set getExtraHeaderKeys(); + + Map getExtraHeaders(Principal principal); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilterFactory.java b/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilterFactory.java new file mode 100644 index 0000000000000..ca88dd63975cd --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilterFactory.java @@ -0,0 +1,21 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +public interface ClientRequestFilterFactory +{ + String getName(); + + ClientRequestFilter create(String filterName); +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java index 81e8f55b0a665..8def2d14573c2 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java @@ -142,4 +142,9 @@ default Iterable getNodeStatusNotificatio { return emptyList(); } + + default Iterable getClientRequestFilterFactories() + { + return emptyList(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java index 900a72526e542..894a8abea5120 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java @@ -140,6 +140,7 @@ public enum StandardErrorCode EXCEEDED_WRITTEN_INTERMEDIATE_BYTES_LIMIT(0x0002_0012, INSUFFICIENT_RESOURCES), TOO_MANY_SIDECARS(0x0002_0013, INTERNAL_ERROR), NO_CPP_SIDECARS(0x0002_0014, INTERNAL_ERROR), + HEADER_MODIFICATION_ATTEMPT(0x0002_0015, INTERNAL_ERROR), /**/; // Error code range 0x0003 is reserved for Presto-on-Spark