Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing ClientRequestFilter.java: A New Plugin for Merging Additional Headers into Client Requests in the Authentication Filter #23380

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions presto-docs/src/main/sphinx/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ This guide is intended for Presto contributors and plugin developers.
develop/serialized-page
develop/presto-console
develop/presto-authenticator
develop/client-request-filter
26 changes: 26 additions & 0 deletions presto-docs/src/main/sphinx/develop/client-request-filter.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

======================
Client Request Filter
======================

Presto allows operators to customize the headers used to process queries. Some example use cases include customized authentication workflows, or enriching query attributes such as the query source. Use the Client Request Filter plugin to control header customization during query execution.

Implementation
--------------

The ``ClientRequestFilterFactory`` is responsible for creating instances of ``ClientRequestFilter``. It also defines
the name of the filter.

The ``ClientRequestFilter`` interface provides two methods: ``getExtraHeaders()``, which allows the runtime to quickly check if it needs to apply a more expensive call to enrich the headers, and ``getHeaderNames()``, which returns a list of header names used as the header names in client requests.

The implementation of ``ClientRequestFilterFactory`` must be wrapped as a plugin and installed on the Presto cluster.

After installing a plugin that implements ``ClientRequestFilterFactory`` on the coordinator, the ``AuthenticationFilter`` class passes the ``principal`` object to the request filter, which returns the header values as a map.

Presto uses the request filter to determine whether a header is present in the blocklist. The blocklist includes headers such as ``X-Presto-Transaction-Id``, ``X-Presto-Started-Transaction-Id``, ``X-Presto-Clear-Transaction-Id``, and ``X-Presto-Trace-Token``, which are not allowed to be overridden.

For a complete list of headers, see the `Java source`_.

Note: The `Java source`_ includes these blocklist headers that are not eligible for overriding. The other headers not mentioned here can be overridden.

.. _Java source: https://github.com/prestodb/presto/blob/master/presto-client/src/main/java/com/facebook/presto/client/PrestoHeaders.java
6 changes: 6 additions & 0 deletions presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,12 @@
<groupId>io.netty</groupId>
<artifactId>netty-transport</artifactId>
</dependency>

<dependency>
tdcmeehan marked this conversation as resolved.
Show resolved Hide resolved
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto;

import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.ClientRequestFilterFactory;
import com.facebook.presto.spi.PrestoException;
import com.google.common.collect.ImmutableList;

import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.facebook.presto.spi.StandardErrorCode.ALREADY_EXISTS;
import static com.facebook.presto.spi.StandardErrorCode.CONFIGURATION_INVALID;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static com.google.common.collect.ImmutableList.toImmutableList;

@ThreadSafe
public class ClientRequestFilterManager
{
private Map<String, ClientRequestFilterFactory> factories = new ConcurrentHashMap<>();

@GuardedBy("this")
private volatile List<ClientRequestFilter> filters = ImmutableList.of();
private final AtomicBoolean loaded = new AtomicBoolean();

public void registerClientRequestFilterFactory(ClientRequestFilterFactory factory)
{
if (loaded.get()) {
throw new PrestoException(INVALID_ARGUMENTS, "Cannot register factories after filters are loaded.");
}

String name = factory.getName();
if (factories.putIfAbsent(name, factory) != null) {
throw new PrestoException(ALREADY_EXISTS, "A factory with the name '" + name + "' is already registered.");
}
}

public void loadClientRequestFilters()
tdcmeehan marked this conversation as resolved.
Show resolved Hide resolved
{
if (!loaded.compareAndSet(false, true)) {
throw new PrestoException(CONFIGURATION_INVALID, "loadClientRequestFilters can only be called once.");
}

filters = factories.values().stream()
.map(factory -> factory.create())
.collect(toImmutableList());
factories = null;
}
tdcmeehan marked this conversation as resolved.
Show resolved Hide resolved

public List<ClientRequestFilter> getClientRequestFilters()
{
return filters;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto;

import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Scopes;

public class ClientRequestFilterModule
implements Module
{
@Override
public void configure(Binder binder)
{
binder.bind(ClientRequestFilterManager.class).in(Scopes.SINGLETON);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.airlift.log.Logger;
import com.facebook.airlift.node.NodeInfo;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.common.block.BlockEncoding;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.type.ParametricType;
Expand All @@ -28,6 +29,7 @@
import com.facebook.presto.security.AccessControlManager;
import com.facebook.presto.server.security.PasswordAuthenticatorManager;
import com.facebook.presto.server.security.PrestoAuthenticatorManager;
import com.facebook.presto.spi.ClientRequestFilterFactory;
import com.facebook.presto.spi.CoordinatorPlugin;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.analyzer.AnalyzerProvider;
Expand Down Expand Up @@ -137,6 +139,7 @@ public class PluginManager
private final AnalyzerProviderManager analyzerProviderManager;
private final QueryPreparerProviderManager queryPreparerProviderManager;
private final NodeStatusNotificationManager nodeStatusNotificationManager;
private final ClientRequestFilterManager clientRequestFilterManager;
private final PlanCheckerProviderManager planCheckerProviderManager;

@Inject
Expand All @@ -161,6 +164,7 @@ public PluginManager(
HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager,
TracerProviderManager tracerProviderManager,
NodeStatusNotificationManager nodeStatusNotificationManager,
ClientRequestFilterManager clientRequestFilterManager,
PlanCheckerProviderManager planCheckerProviderManager)
{
requireNonNull(nodeInfo, "nodeInfo is null");
Expand Down Expand Up @@ -194,6 +198,7 @@ public PluginManager(
this.analyzerProviderManager = requireNonNull(analyzerProviderManager, "analyzerProviderManager is null");
this.queryPreparerProviderManager = requireNonNull(queryPreparerProviderManager, "queryPreparerProviderManager is null");
this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null");
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
this.planCheckerProviderManager = requireNonNull(planCheckerProviderManager, "planCheckerProviderManager is null");
}

Expand Down Expand Up @@ -364,6 +369,11 @@ public void installPlugin(Plugin plugin)
log.info("Registering node status notification provider %s", nodeStatusNotificationProviderFactory.getName());
nodeStatusNotificationManager.addNodeStatusNotificationProviderFactory(nodeStatusNotificationProviderFactory);
}

for (ClientRequestFilterFactory clientRequestFilterFactory : plugin.getClientRequestFilterFactories()) {
log.info("Registering client request filter factory");
clientRequestFilterManager.registerClientRequestFilterFactory(clientRequestFilterFactory);
}
}

public void installCoordinatorPlugin(CoordinatorPlugin plugin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import com.facebook.airlift.tracetoken.TraceTokenModule;
import com.facebook.drift.server.DriftServer;
import com.facebook.drift.transport.netty.server.DriftNettyServerTransport;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.ClientRequestFilterModule;
import com.facebook.presto.dispatcher.QueryPrerequisitesManager;
import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule;
import com.facebook.presto.eventlistener.EventListenerManager;
Expand Down Expand Up @@ -117,6 +119,7 @@ public void run()
new NodeModule(),
new DiscoveryModule(),
new HttpServerModule(),
new ClientRequestFilterModule(),
new JsonModule(),
installModuleIf(
FeaturesConfig.class,
Expand Down Expand Up @@ -192,6 +195,7 @@ public void run()
PluginNodeManager pluginNodeManager = new PluginNodeManager(nodeManager, nodeInfo.getEnvironment());
planCheckerProviderManager.loadPlanCheckerProviders(pluginNodeManager);

injector.getInstance(ClientRequestFilterManager.class).loadClientRequestFilters();
startAssociatedProcesses(injector);

injector.getInstance(Announcer.class).start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,14 @@

import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.PrestoException;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.net.HttpHeaders;

import javax.inject.Inject;
Expand All @@ -35,14 +40,20 @@
import java.io.InputStream;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT;
import static com.google.common.io.ByteStreams.copy;
import static com.google.common.io.ByteStreams.nullOutputStream;
import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE;
import static com.google.common.net.MediaType.PLAIN_TEXT_UTF_8;
import static java.util.Collections.enumeration;
import static java.util.Collections.list;
import static java.util.Objects.requireNonNull;
import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED;

Expand All @@ -52,12 +63,15 @@ public class AuthenticationFilter
private static final String HTTPS_PROTOCOL = "https";
private final List<Authenticator> authenticators;
private final boolean allowForwardedHttps;
private final ClientRequestFilterManager clientRequestFilterManager;
private final List<String> headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token");

@Inject
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig)
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager)
{
this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null"));
this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps();
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
}

@Override
Expand Down Expand Up @@ -95,9 +109,9 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
continue;
}

// authentication succeeded
nextFilter.doFilter(withPrincipal(request, principal), response);
HttpServletRequest wrappedRequest = mergeExtraHeaders(request, principal);
nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response);
return;
}

Expand Down Expand Up @@ -126,6 +140,47 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
}
}

public HttpServletRequest mergeExtraHeaders(HttpServletRequest request, Principal principal)
{
List<ClientRequestFilter> clientRequestFilters = clientRequestFilterManager.getClientRequestFilters();

if (clientRequestFilters.isEmpty()) {
return request;
}

ImmutableMap.Builder<String, String> extraHeadersMapBuilder = ImmutableMap.builder();
Set<String> addedHeaders = new HashSet<>();

for (ClientRequestFilter requestFilter : clientRequestFilters) {
boolean headersPresent = requestFilter.getExtraHeaderKeys().stream()
.allMatch(headerName -> request.getHeader(headerName) != null);

if (!headersPresent) {
Map<String, String> extraHeaderValueMap = requestFilter.getExtraHeaders(principal);

if (!extraHeaderValueMap.isEmpty()) {
for (Map.Entry<String, String> extraHeaderEntry : extraHeaderValueMap.entrySet()) {
String headerKey = extraHeaderEntry.getKey();
if (headersBlockList.contains(headerKey)) {
throw new PrestoException(HEADER_MODIFICATION_ATTEMPT,
"Modification attempt detected: The header " + headerKey + " is not allowed to be modified. The following headers cannot be modified: " +
String.join(", ", headersBlockList));
}
if (addedHeaders.contains(headerKey)) {
throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Header conflict detected: " + headerKey + " already added by another filter.");
}
if (request.getHeader(headerKey) == null && requestFilter.getExtraHeaderKeys().contains(headerKey)) {
extraHeadersMapBuilder.put(headerKey, extraHeaderEntry.getValue());
addedHeaders.add(headerKey);
}
}
}
}
}

return new ModifiedHttpServletRequest(request, extraHeadersMapBuilder.build());
}

private boolean doesRequestSupportAuthentication(HttpServletRequest request)
{
if (authenticators.isEmpty()) {
Expand Down Expand Up @@ -166,4 +221,43 @@ private static void skipRequestBody(HttpServletRequest request)
copy(inputStream, nullOutputStream());
}
}

public static class ModifiedHttpServletRequest
extends HttpServletRequestWrapper
{
private final Map<String, String> customHeaders;
tdcmeehan marked this conversation as resolved.
Show resolved Hide resolved

public ModifiedHttpServletRequest(HttpServletRequest request, Map<String, String> headers)
{
super(request);
this.customHeaders = ImmutableMap.copyOf(requireNonNull(headers, "headers is null"));
}

@Override
public String getHeader(String name)
{
if (customHeaders.containsKey(name)) {
return customHeaders.get(name);
}
return super.getHeader(name);
}

@Override
public Enumeration<String> getHeaderNames()
{
return enumeration(ImmutableSet.<String>builder()
.addAll(customHeaders.keySet())
.addAll(list(super.getHeaderNames()))
.build());
}

@Override
public Enumeration<String> getHeaders(String name)
{
if (customHeaders.containsKey(name)) {
return enumeration(ImmutableList.of(customHeaders.get(name)));
}
return super.getHeaders(name);
}
}
}
Loading
Loading