Skip to content

Commit

Permalink
feat: move credential handling responsibility to the proxy
Browse files Browse the repository at this point in the history
stack-info: PR: #9690, branch: igorbernstein2/stack/3
  • Loading branch information
igorbernstein2 committed Nov 22, 2024
1 parent 8fecd00 commit 96d2449
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@

package com.google.cloud.bigtable.examples.proxy.commands;

import com.google.auth.Credentials;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc;
import com.google.bigtable.admin.v2.BigtableTableAdminGrpc;
import com.google.bigtable.v2.BigtableGrpc;
import com.google.cloud.bigtable.examples.proxy.core.ProxyHandler;
import com.google.cloud.bigtable.examples.proxy.core.Registry;
import com.google.common.collect.ImmutableMap;
import com.google.longrunning.OperationsGrpc;
import io.grpc.CallCredentials;
import io.grpc.InsecureServerCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Server;
import io.grpc.ServerCallHandler;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import java.io.IOException;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -67,6 +71,7 @@ public class Serve implements Callable<Void> {

ManagedChannel adminChannel = null;
ManagedChannel dataChannel = null;
Credentials credentials = null;
Server server;

@Override
Expand Down Expand Up @@ -94,17 +99,21 @@ void start() throws IOException {
.disableRetry()
.build();
}
if (credentials == null) {
credentials = GoogleCredentials.getApplicationDefault();
}
CallCredentials callCredentials = MoreCallCredentials.from(credentials);

Map<String, ServerCallHandler<byte[], byte[]>> serviceMap =
ImmutableMap.of(
BigtableGrpc.SERVICE_NAME,
new ProxyHandler<>(dataChannel),
new ProxyHandler<>(dataChannel, callCredentials),
BigtableInstanceAdminGrpc.SERVICE_NAME,
new ProxyHandler<>(adminChannel),
new ProxyHandler<>(adminChannel, callCredentials),
BigtableTableAdminGrpc.SERVICE_NAME,
new ProxyHandler<>(adminChannel),
new ProxyHandler<>(adminChannel, callCredentials),
OperationsGrpc.SERVICE_NAME,
new ProxyHandler<>(adminChannel));
new ProxyHandler<>(adminChannel, callCredentials));

server =
NettyServerBuilder.forAddress(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.google.cloud.bigtable.examples.proxy.core;

import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
Expand All @@ -25,15 +26,23 @@

/** A factory pairing of an incoming server call to an outgoing client call. */
public final class ProxyHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
private static final Metadata.Key<String> AUTHORIZATION_KEY =
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);

private final Channel channel;
private final CallCredentials callCredentials;

public ProxyHandler(Channel channel) {
public ProxyHandler(Channel channel, CallCredentials callCredentials) {
this.channel = channel;
this.callCredentials = callCredentials;
}

@Override
public ServerCall.Listener<ReqT> startCall(ServerCall<ReqT, RespT> serverCall, Metadata headers) {
CallOptions callOptions = CallOptions.DEFAULT;
// Strip incoming credentials
headers.removeAll(AUTHORIZATION_KEY);
// Inject proxy credentials
CallOptions callOptions = CallOptions.DEFAULT.withCallCredentials(callCredentials);

ClientCall<ReqT, RespT> clientCall =
channel.newCall(serverCall.getMethodDescriptor(), callOptions);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;

import com.google.auth.Credentials;
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc;
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc.BigtableInstanceAdminFutureStub;
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc.BigtableInstanceAdminImplBase;
Expand All @@ -35,6 +36,7 @@
import com.google.bigtable.v2.BigtableGrpc.BigtableImplBase;
import com.google.bigtable.v2.CheckAndMutateRowRequest;
import com.google.bigtable.v2.CheckAndMutateRowResponse;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.longrunning.GetOperationRequest;
import com.google.longrunning.Operation;
Expand Down Expand Up @@ -63,6 +65,9 @@
import io.grpc.testing.GrpcCleanupRule;
import java.io.IOException;
import java.net.ServerSocket;
import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.BlockingQueue;
Expand Down Expand Up @@ -94,6 +99,7 @@ public class ServeTest {
private FakeTableAdminService tableAdminService;
private OperationService operationService;
private ManagedChannel fakeServiceChannel;
private FakeCredentials fakeCredentials;

// Proxy
private Serve serve;
Expand All @@ -108,6 +114,8 @@ public void setUp() throws IOException {
tableAdminService = new FakeTableAdminService();
operationService = new OperationService();

fakeCredentials = new FakeCredentials();

grpcCleanup.register(
InProcessServerBuilder.forName(targetServerName)
.intercept(metadataInterceptor)
Expand All @@ -123,7 +131,9 @@ public void setUp() throws IOException {
InProcessChannelBuilder.forName(targetServerName).usePlaintext().build());

// Create the proxy
serve = createAndStartCommand(fakeServiceChannel);
// Inject fakes for upstream calls. For unit tests we want to shim communications to the
// bigtable service.
serve = createAndStartCommand(fakeServiceChannel, fakeCredentials);

proxyChannel =
grpcCleanup.register(
Expand Down Expand Up @@ -321,11 +331,86 @@ public void onClose(Status status, Metadata trailers) {
assertThat(clientRecvTrailer.get()).hasValue("trailer", "trailer-value");
}

private static Serve createAndStartCommand(ManagedChannel targetChannel) throws IOException {
@Test
public void testCredentials() throws InterruptedException, ExecutionException, TimeoutException {
BigtableFutureStub proxyStub = BigtableGrpc.newFutureStub(proxyChannel);

CheckAndMutateRowRequest request =
CheckAndMutateRowRequest.newBuilder().setTableName("some-table").build();
final ListenableFuture<CheckAndMutateRowResponse> proxyFuture =
proxyStub.checkAndMutateRow(request);
StreamObserver<CheckAndMutateRowResponse> serverObserver =
dataService
.calls
.computeIfAbsent(request, (ignored) -> new LinkedBlockingDeque<>())
.poll(1, TimeUnit.SECONDS);

assertWithMessage("Timed out waiting for the proxied RPC on the fake server")
.that(serverObserver)
.isNotNull();

serverObserver.onNext(CheckAndMutateRowResponse.newBuilder().setPredicateMatched(true).build());
serverObserver.onCompleted();
proxyFuture.get(1, TimeUnit.SECONDS);

assertThat(metadataInterceptor.requestHeaders.poll(1, TimeUnit.SECONDS))
.hasValue("authorization", "fake-token");
}

@Test
public void testCredentialsClobber()
throws InterruptedException, ExecutionException, TimeoutException {
BigtableFutureStub proxyStub =
BigtableGrpc.newFutureStub(proxyChannel)
.withInterceptors(
new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> methodDescriptor,
CallOptions callOptions,
Channel channel) {
return new SimpleForwardingClientCall<ReqT, RespT>(
channel.newCall(methodDescriptor, callOptions)) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
headers.put(
Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER),
"pre-proxied-value");
super.start(responseListener, headers);
}
};
}
});

CheckAndMutateRowRequest request =
CheckAndMutateRowRequest.newBuilder().setTableName("some-table").build();
final ListenableFuture<CheckAndMutateRowResponse> proxyFuture =
proxyStub.checkAndMutateRow(request);
StreamObserver<CheckAndMutateRowResponse> serverObserver =
dataService
.calls
.computeIfAbsent(request, (ignored) -> new LinkedBlockingDeque<>())
.poll(1, TimeUnit.SECONDS);

assertWithMessage("Timed out waiting for the proxied RPC on the fake server")
.that(serverObserver)
.isNotNull();

serverObserver.onNext(CheckAndMutateRowResponse.newBuilder().setPredicateMatched(true).build());
serverObserver.onCompleted();
proxyFuture.get(1, TimeUnit.SECONDS);

Metadata serverRequestHeaders = metadataInterceptor.requestHeaders.poll(1, TimeUnit.SECONDS);
assertThat(serverRequestHeaders).hasValue("authorization", "fake-token");
}

private static Serve createAndStartCommand(
ManagedChannel targetChannel, FakeCredentials targetCredentials) throws IOException {
for (int i = 10; i >= 0; i--) {
Serve s = new Serve();
s.dataChannel = targetChannel;
s.adminChannel = targetChannel;
s.credentials = targetCredentials;

try (ServerSocket serverSocket = new ServerSocket(0)) {
s.listenPort = serverSocket.getLocalPort();
Expand Down Expand Up @@ -423,4 +508,34 @@ public void getOperation(
.add(responseObserver);
}
}

private static class FakeCredentials extends Credentials {
private static final String HEADER_NAME = "authorization";
private String fakeValue = "fake-token";

@Override
public String getAuthenticationType() {
return "fake";
}

@Override
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
return Map.of(HEADER_NAME, Lists.newArrayList(fakeValue));
}

@Override
public boolean hasRequestMetadata() {
return true;
}

@Override
public boolean hasRequestMetadataOnly() {
return true;
}

@Override
public void refresh() throws IOException {
// noop
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import com.google.common.truth.FailureMetadata;
import com.google.common.truth.Subject;
import io.grpc.Metadata;
import java.util.ArrayList;
import java.util.Optional;
import org.jspecify.annotations.Nullable;

public class MetadataSubject extends Subject {
Expand Down Expand Up @@ -52,6 +54,17 @@ public void hasValue(String key, String value) {
}

public <T> void hasValue(Metadata.Key<T> key, T value) {
check("get(" + key + ")").that(metadata.get(key)).isEqualTo(value);
Iterable<T> actualValues = Optional.ofNullable(metadata.getAll(key)).orElse(new ArrayList<>());
check("get(" + key + ")").that(actualValues).containsExactly(value);
}

public void containsValue(String key, String value) {
check("get(" + key + ")")
.that(metadata.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)))
.contains(value);
}

public <T> void containsValue(Metadata.Key<T> key, T value) {
check("get(" + key + ")").that(metadata.getAll(key)).contains(value);
}
}

0 comments on commit 96d2449

Please sign in to comment.