Skip to content

Commit

Permalink
feat: Implement Client-Side CAB token generation.
Browse files Browse the repository at this point in the history
Change-Id: I2c217656584cf5805297f02340cbbabca471f609
  • Loading branch information
huangjiahua committed Jan 9, 2025
1 parent 0d96dcf commit 1177072
Show file tree
Hide file tree
Showing 5 changed files with 379 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@

import com.google.api.client.util.Clock;
import com.google.auth.Credentials;
import com.google.auth.credentialaccessboundary.ClientSideCredentialAccessBoundaryFactory.RefreshTask;
import com.google.auth.credentialaccessboundary.protobuf.ClientSideAccessBoundaryProto.ClientSideAccessBoundary;
import com.google.auth.credentialaccessboundary.protobuf.ClientSideAccessBoundaryProto.ClientSideAccessBoundaryRule;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.CredentialAccessBoundary;
import com.google.auth.oauth2.CredentialAccessBoundary.AccessBoundaryRule;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.OAuth2Utils;
import com.google.auth.oauth2.StsRequestHandler;
Expand All @@ -53,12 +57,28 @@
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListenableFutureTask;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.crypto.tink.Aead;
import com.google.crypto.tink.InsecureSecretKeyAccess;
import com.google.crypto.tink.KeysetHandle;
import com.google.crypto.tink.RegistryConfiguration;
import com.google.crypto.tink.TinkProtoKeysetFormat;
import com.google.crypto.tink.aead.AeadConfig;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelOptions;
import dev.cel.common.CelProtoAbstractSyntaxTree;
import dev.cel.common.CelValidationException;
import dev.cel.compiler.CelCompiler;
import dev.cel.compiler.CelCompilerFactory;
import dev.cel.expr.Expr;
import java.io.IOException;
import java.time.Duration;
import java.util.Date;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nullable;
import java.util.Base64;
import java.util.List;
import java.security.GeneralSecurityException;

public class ClientSideCredentialAccessBoundaryFactory {
static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(30);
Expand All @@ -72,6 +92,7 @@ public class ClientSideCredentialAccessBoundaryFactory {
private final Object refreshLock = new byte[0];
private volatile IntermediateCredentials intermediateCredentials = null;
private final Clock clock;
private final CelCompiler celCompiler;

enum RefreshType {
NONE,
Expand All @@ -83,6 +104,19 @@ private ClientSideCredentialAccessBoundaryFactory(Builder builder) {
this.transportFactory = builder.transportFactory;
this.sourceCredential = builder.sourceCredential;
this.tokenExchangeEndpoint = builder.tokenExchangeEndpoint;

try {
AeadConfig.register();
} catch (GeneralSecurityException e) {
throw new IllegalStateException("Error occurred when registering Tink");
}

CelOptions options = CelOptions.current().build();
this.celCompiler = CelCompilerFactory
.standardCelCompilerBuilder()
.setOptions(options)
.build();

this.refreshMargin =
builder.refreshMargin != null ? builder.refreshMargin : DEFAULT_REFRESH_MARGIN;
this.minimumTokenLifetime =
Expand All @@ -92,11 +126,6 @@ private ClientSideCredentialAccessBoundaryFactory(Builder builder) {
this.clock = builder.clock;
}

public AccessToken generateToken(CredentialAccessBoundary accessBoundary) {
// TODO(negarb/jiahuah): Implement generateToken
// Note: This method will call refreshCredentialsIfRequired().
throw new UnsupportedOperationException("generateToken is not yet implemented.");
}

/**
* Refreshes the intermediate access token and access boundary session key if required.
Expand Down Expand Up @@ -403,6 +432,90 @@ public void run() {
}
}

public AccessToken generateToken(CredentialAccessBoundary accessBoundary) throws IOException {
this.refreshCredentialsIfRequired();

String intermediaryToken, sessionKey;
Date intermediaryTokenExpirationTime;

synchronized (this) {
intermediaryToken =
this.intermediateCredentials.intermediateAccessToken.getTokenValue();
intermediaryTokenExpirationTime =
this.intermediateCredentials.intermediateAccessToken
.getExpirationTime();
sessionKey = this.intermediateCredentials.accessBoundarySessionKey;
}

byte[] rawRestrictions =
this.serializeCredentialAccessBoundary(accessBoundary);

byte[] encryptedRestrictions =
this.encryptRestrictions(rawRestrictions, sessionKey);

String tokenValue =
intermediaryToken + "." +
Base64.getUrlEncoder().encodeToString(encryptedRestrictions);

return new AccessToken(tokenValue, intermediaryTokenExpirationTime);
}

private byte[] serializeCredentialAccessBoundary(
CredentialAccessBoundary credentialAccessBoundary) throws IOException {
List<AccessBoundaryRule> rules =
credentialAccessBoundary.getAccessBoundaryRules();
ClientSideAccessBoundary.Builder accessBoundaryBuilder =
ClientSideAccessBoundary.newBuilder();

for (AccessBoundaryRule rule : rules) {
ClientSideAccessBoundaryRule.Builder ruleBuilder =
accessBoundaryBuilder.addAccessBoundaryRulesBuilder()
.addAllAvailablePermissions(rule.getAvailablePermissions())
.setAvailableResource(rule.getAvailableResource());

if (rule.getAvailabilityCondition() != null) {
String availabilityCondition =
rule.getAvailabilityCondition().getExpression();

Expr availabilityConditionExpr = this.compileCel(availabilityCondition);
ruleBuilder.setCompiledAvailabilityCondition(availabilityConditionExpr);
}
}

return accessBoundaryBuilder.build().toByteArray();
}

private Expr compileCel(String expr) throws IOException {
try {
CelAbstractSyntaxTree ast = celCompiler.parse(expr).getAst();

CelProtoAbstractSyntaxTree astProto =
CelProtoAbstractSyntaxTree.fromCelAst(ast);

return astProto.getExpr();

} catch (CelValidationException exception) {
throw new IOException("Failed to parse CEL expression: " +
exception.getMessage());
}
}

private byte[] encryptRestrictions(byte[] restriction, String sessionKey) throws InternalError {
try {
byte[] rawKey = Base64.getDecoder().decode(sessionKey);

KeysetHandle keysetHandle = TinkProtoKeysetFormat.parseKeyset(
rawKey, InsecureSecretKeyAccess.get());

Aead aead =
keysetHandle.getPrimitive(RegistryConfiguration.get(), Aead.class);

return aead.encrypt(restriction, /*associatedData=*/new byte[0]);
} catch (GeneralSecurityException exception) {
throw new InternalError("Failed to parse keyset: " + exception.getMessage());
}
}

public static Builder newBuilder() {
return new Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -44,15 +45,28 @@
import com.google.auth.TestUtils;
import com.google.auth.credentialaccessboundary.ClientSideCredentialAccessBoundaryFactory.IntermediateCredentials;
import com.google.auth.credentialaccessboundary.ClientSideCredentialAccessBoundaryFactory.RefreshType;
import com.google.auth.credentialaccessboundary.protobuf.ClientSideAccessBoundaryProto.ClientSideAccessBoundary;
import com.google.auth.credentialaccessboundary.protobuf.ClientSideAccessBoundaryProto.ClientSideAccessBoundaryRule;
import com.google.auth.http.HttpTransportFactory;
import com.google.auth.oauth2.AccessToken;
import com.google.auth.oauth2.CredentialAccessBoundary;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.auth.oauth2.MockStsTransport;
import com.google.auth.oauth2.MockTokenServerTransportFactory;
import com.google.auth.oauth2.OAuth2Utils;
import com.google.auth.oauth2.ServiceAccountCredentials;
import com.google.common.collect.ImmutableList;
import com.google.crypto.tink.Aead;
import com.google.crypto.tink.InsecureSecretKeyAccess;
import com.google.crypto.tink.KeysetHandle;
import com.google.crypto.tink.RegistryConfiguration;
import com.google.crypto.tink.TinkProtoKeysetFormat;

import dev.cel.expr.Expr;

import java.io.IOException;
import java.time.Duration;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import org.junit.Before;
Expand Down Expand Up @@ -572,4 +586,173 @@ private static void triggerConcurrentRefresh(
}
}
}

@Test
public void generateToken() throws Exception {
MockStsTransportFactory transportFactory = new MockStsTransportFactory();
transportFactory.transport.setReturnAccessBoundarySessionKey(true);

ClientSideCredentialAccessBoundaryFactory.Builder builder =
ClientSideCredentialAccessBoundaryFactory.newBuilder();

ClientSideCredentialAccessBoundaryFactory factory =
builder
.setSourceCredential(getServiceAccountSourceCredentials(
mockTokenServerTransportFactory))
.setHttpTransportFactory(transportFactory)
.build();

CredentialAccessBoundary.Builder cabBuilder =
CredentialAccessBoundary.newBuilder();
CredentialAccessBoundary accessBoundary =
cabBuilder
.addRule(
CredentialAccessBoundary.AccessBoundaryRule.newBuilder()
.setAvailableResource("//storage.googleapis.com/projects/"
+ "_/buckets/example-bucket")
.setAvailablePermissions(
ImmutableList.of("inRole:roles/storage.objectViewer"))
.setAvailabilityCondition(
CredentialAccessBoundary.AccessBoundaryRule
.AvailabilityCondition.newBuilder()
.setExpression(
"resource.name.startsWith('projects/_/"
+ "buckets/example-bucket/objects/customer-a')")
.build())
.build())
.build();

AccessToken token = factory.generateToken(accessBoundary);

String[] parts = token.getTokenValue().split("\\.");
assertEquals(parts.length, 2);
assertEquals(parts[0], "accessToken");

byte[] rawKey = Base64.getDecoder().decode(
transportFactory.transport.getAccessBoundarySessionKey());

KeysetHandle keysetHandle = TinkProtoKeysetFormat.parseKeyset(
rawKey, InsecureSecretKeyAccess.get());

Aead aead =
keysetHandle.getPrimitive(RegistryConfiguration.get(), Aead.class);
byte[] rawRestrictions =
aead.decrypt(Base64.getUrlDecoder().decode(parts[1]), new byte[0]);
ClientSideAccessBoundary clientSideAccessBoundary =
ClientSideAccessBoundary.parseFrom(rawRestrictions);
assertEquals(clientSideAccessBoundary.getAccessBoundaryRulesCount(), 1);
ClientSideAccessBoundaryRule rule =
clientSideAccessBoundary.getAccessBoundaryRules(0);
assertEquals(rule.getAvailableResource(),
"//storage.googleapis.com/projects/_/buckets/example-bucket");
assertEquals(rule.getAvailablePermissions(0),
"inRole:roles/storage.objectViewer");
Expr expr = rule.getCompiledAvailabilityCondition();
assertEquals(expr.getCallExpr()
.getTarget()
.getSelectExpr()
.getOperand()
.getIdentExpr()
.getName(),
"resource");
assertEquals(expr.getCallExpr().getFunction(), "startsWith");
assertEquals(expr.getCallExpr().getArgs(0).getConstExpr().getStringValue(),
"projects/_/buckets/example-bucket/objects/customer-a");
}

@Test
public void generateToken_withoutAvailabilityCondition() throws Exception {
MockStsTransportFactory transportFactory = new MockStsTransportFactory();
transportFactory.transport.setReturnAccessBoundarySessionKey(true);

ClientSideCredentialAccessBoundaryFactory.Builder builder =
ClientSideCredentialAccessBoundaryFactory.newBuilder();

ClientSideCredentialAccessBoundaryFactory factory =
builder
.setSourceCredential(getServiceAccountSourceCredentials(
mockTokenServerTransportFactory))
.setHttpTransportFactory(transportFactory)
.build();

CredentialAccessBoundary.Builder cabBuilder =
CredentialAccessBoundary.newBuilder();
CredentialAccessBoundary accessBoundary =
cabBuilder
.addRule(
CredentialAccessBoundary.AccessBoundaryRule.newBuilder()
.setAvailableResource("//storage.googleapis.com/projects/"
+ "_/buckets/example-bucket")
.setAvailablePermissions(
ImmutableList.of("inRole:roles/storage.objectViewer"))
.build())
.build();

AccessToken token = factory.generateToken(accessBoundary);

String[] parts = token.getTokenValue().split("\\.");
assertEquals(parts.length, 2);
assertEquals(parts[0], "accessToken");

byte[] rawKey = Base64.getDecoder().decode(
transportFactory.transport.getAccessBoundarySessionKey());

KeysetHandle keysetHandle = TinkProtoKeysetFormat.parseKeyset(
rawKey, InsecureSecretKeyAccess.get());

Aead aead =
keysetHandle.getPrimitive(RegistryConfiguration.get(), Aead.class);
byte[] rawRestrictions =
aead.decrypt(Base64.getUrlDecoder().decode(parts[1]), new byte[0]);
ClientSideAccessBoundary clientSideAccessBoundary =
ClientSideAccessBoundary.parseFrom(rawRestrictions);
assertEquals(clientSideAccessBoundary.getAccessBoundaryRulesCount(), 1);
ClientSideAccessBoundaryRule rule =
clientSideAccessBoundary.getAccessBoundaryRules(0);
assertEquals(rule.getAvailableResource(),
"//storage.googleapis.com/projects/_/buckets/example-bucket");
assertEquals(rule.getAvailablePermissions(0),
"inRole:roles/storage.objectViewer");
assertTrue(rule.getCompiledAvailabilityCondition().equals(
Expr.getDefaultInstance()));
}

@Test
public void generateToken_withInvalidCelExpression() throws Exception {
MockStsTransportFactory transportFactory = new MockStsTransportFactory();
transportFactory.transport.setReturnAccessBoundarySessionKey(true);

ClientSideCredentialAccessBoundaryFactory.Builder builder =
ClientSideCredentialAccessBoundaryFactory.newBuilder();

ClientSideCredentialAccessBoundaryFactory factory =
builder
.setSourceCredential(getServiceAccountSourceCredentials(
mockTokenServerTransportFactory))
.setHttpTransportFactory(transportFactory)
.build();

CredentialAccessBoundary.Builder cabBuilder =
CredentialAccessBoundary.newBuilder();
CredentialAccessBoundary accessBoundary =
cabBuilder
.addRule(
CredentialAccessBoundary.AccessBoundaryRule.newBuilder()
.setAvailableResource("//storage.googleapis.com/projects/"
+ "_/buckets/example-bucket")
.setAvailablePermissions(
ImmutableList.of("inRole:roles/storage.objectViewer"))
.setAvailabilityCondition(
CredentialAccessBoundary.AccessBoundaryRule
.AvailabilityCondition.newBuilder()
.setExpression(
"resource.name.startsWith('projects/_/"
+ "buckets/example-bucket/objects/customer-a'")
.build())
.build())
.build();

assertThrows(IOException.class,
() -> { factory.generateToken(accessBoundary); });
}
}
Loading

0 comments on commit 1177072

Please sign in to comment.