Skip to content

Commit

Permalink
feat: Ability to split roles by Delimiter (#540)
Browse files Browse the repository at this point in the history
  • Loading branch information
akurnosau authored Oct 23, 2024
1 parent 082dce7 commit 71da307
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Priority order:
| identityProviders.*.jwksUrl | - | Optional |Url to jwks provider. **Required** if `disabledVerifyJwt` is set to `false`. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided.
| identityProviders.*.userInfoEndpoint | - | Optional |Url to user info endpoint. **Note**: Either `jwksUrl` or `userInfoEndpoint` must be provided or `disableJwtVerification` is unset. Refer to [Google example](sample/aidial.settings.json).
| identityProviders.*.rolePath | - | Yes |Path to the claim user roles in JWT token or user info response, e.g. `resource_access.chatbot-ui.roles` or just `roles`. Refer to [IDP Configuration](https://github.com/epam/ai-dial/blob/main/docs/Auth/2.%20Web/1.overview.md) to view guidelines for configuring supported providers.
| identityProviders.*.rolesDelimiter | - | No |Delimiter to split roles into array in case when list of roles presented as single String. e.g. `"rolesDelimiter": " "`
| identityProviders.*.loggingKey | - | No |User information to search in claims of JWT token. `email` or `sub` should be sufficient in most cases. **Note**: `email` might be unavailable for some IDPs. Please check your IDP documentation in this case.
| identityProviders.*.loggingSalt | - | No |Salt to hash user information for logging.
| identityProviders.*.positiveCacheExpirationMs | 600000 | No | How long to retain JWKS response in the cache in case of successfull response.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -40,6 +41,9 @@ public class IdentityProvider {
// path to the claim of user roles in JWT
private final String[] rolePath;

// Delimiter to split the roles if they are set as a single String
private final String rolesDelimiter;

private JwkProvider jwkProvider;

private URL userInfoUrl;
Expand Down Expand Up @@ -113,6 +117,7 @@ public IdentityProvider(JsonObject settings, Vertx vertx, HttpClient client,
String rolePathStr = Objects.requireNonNull(settings.getString("rolePath"), "rolePath is missed");
getUserRoleFn = factory.getUserRoleFn(rolePathStr);
rolePath = rolePathStr.split("\\.");
rolesDelimiter = settings.getString("rolesDelimiter");

loggingKey = settings.getString("loggingKey");
if (loggingKey != null) {
Expand Down Expand Up @@ -153,6 +158,12 @@ private List<String> extractUserRoles(Map<String, Object> map) {
if (next instanceof List) {
return (List<String>) next;
} else if (next instanceof String) {
if (rolesDelimiter != null) {
return Arrays.stream(((String) next)
.split(rolesDelimiter))
.filter(s -> !s.isBlank())
.toList();
}
return List.of((String) next);
}
} else {
Expand Down Expand Up @@ -324,4 +335,4 @@ boolean hasUserinfoUrl() {

private record JwkResult(Jwk jwk, Exception error, long expirationTime) {
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,162 @@ public void testExtractClaims_14() {
});
}

@Test
public void testExtractClaims_15() throws JwkException {
settings.put("rolesDelimiter", " ");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("roles", "r1 r2 r3").sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3"), claims.userRoles());
});
}

@Test
public void testExtractClaims_16() throws JwkException {
settings.put("rolesDelimiter", ":");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("roles", "r1 r2 r3").sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1 r2 r3"), claims.userRoles());
});
}

@Test
public void testExtractClaims_17() throws JwkException {
settings.put("rolesDelimiter", " ");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("roles", List.of("r1", "r2 r3")).sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2 r3"), claims.userRoles());
});
}

@Test
public void testExtractClaims_18() throws JwkException {
settings.put("rolesDelimiter", " ");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("roles", (String) null).sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(Collections.EMPTY_LIST, claims.userRoles());
});
}

@Test
public void testExtractClaims_19() throws JwkException {
settings.put("rolesDelimiter", " ");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("roles", "").sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(Collections.emptyList(), claims.userRoles());
});
}

@Test
public void testExtractClaims_20() throws JwkException {
settings.put("rolesDelimiter", " ");
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, client, url -> jwkProvider, factory);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1")).withClaim("roles", "r1 r2 r3 r4").sign(algorithm);
Jwk jwk = mock(Jwk.class);
when(jwk.getPublicKey()).thenReturn(keyPair.getPublic());
when(jwkProvider.get(eq("kid1"))).thenReturn(jwk);
when(vertx.executeBlocking(any(Callable.class), eq(false))).thenAnswer(invocation -> {
Callable<?> callable = invocation.getArgument(0);
return Future.succeededFuture(callable.call());
});

Future<ExtractedClaims> result = identityProvider.extractClaimsFromJwt(JWT.decode(token));

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("r1", "r2", "r3", "r4"), claims.userRoles());
});
}

@Test
public void testExtractClaims_FromUserInfo_01() {
settings.remove("jwksUrl");
Expand Down

0 comments on commit 71da307

Please sign in to comment.