Skip to content

Commit

Permalink
feat: support custom claims check for publications (#309)(#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxim-Gadalov authored Apr 24, 2024
1 parent 248d363 commit 7a3937f
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.epam.aidial.core.security;

import java.util.List;
import java.util.Map;

public record ExtractedClaims(String sub, List<String> userRoles, String userHash) {
public record ExtractedClaims(String sub, List<String> userRoles, String userHash, Map<String, List<String>> userClaims) {
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.auth0.jwk.JwkProvider;
import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import io.vertx.core.Future;
import io.vertx.core.Vertx;
Expand All @@ -15,6 +16,7 @@
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -201,6 +203,31 @@ private String extractUserHash(DecodedJWT decodedJwt) {
return keyClaim;
}

/**
* Extracts user claims from decoded JWT. Currently only strings or list of strings/primitives supported.
* If any other type provided - claim value will not be extracted, see IdentityProviderTest.testExtractClaims_13()
*
* @param decodedJwt - decoded JWT
* @return map of extracted user claims
*/
private Map<String, List<String>> extractUserClaims(DecodedJWT decodedJwt) {
Map<String, List<String>> userClaims = new HashMap<>();
for (Map.Entry<String, Claim> entry : decodedJwt.getClaims().entrySet()) {
String claimName = entry.getKey();
Claim claimValue = entry.getValue();
if (claimValue.asString() != null) {
userClaims.put(claimName, List.of(claimValue.asString()));
} else if (claimValue.asList(String.class) != null) {
userClaims.put(claimName, claimValue.asList(String.class));
} else {
// if claim value doesn't match supported type - add claim with empty value
userClaims.put(claimName, List.of());
}
}

return userClaims;
}

Future<ExtractedClaims> extractClaims(DecodedJWT decodedJwt) {
if (decodedJwt == null) {
return Future.failedFuture(new IllegalArgumentException("decoded JWT must not be null"));
Expand All @@ -212,7 +239,7 @@ Future<ExtractedClaims> extractClaims(DecodedJWT decodedJwt) {
}

private ExtractedClaims from(DecodedJWT jwt) {
return new ExtractedClaims(extractUserSub(jwt), extractUserRoles(jwt), extractUserHash(jwt));
return new ExtractedClaims(extractUserSub(jwt), extractUserRoles(jwt), extractUserHash(jwt), extractUserClaims(jwt));
}

boolean match(DecodedJWT jwt) {
Expand Down
46 changes: 28 additions & 18 deletions src/main/java/com/epam/aidial/core/security/RuleMatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,57 +6,67 @@

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

@UtilityClass
public class RuleMatcher {

/**
*
* @return true if any of the provided rule matched (OR condition behaviour), otherwise - false
*/
public boolean match(ProxyContext context, Collection<Rule> rules) {
ExtractedClaims claims = context.getExtractedClaims();
if (claims == null) {
return false;
}

List<String> roles = claims.userRoles();
if (roles == null || roles.isEmpty() || rules == null || rules.isEmpty()) {
return false;
}
Map<String, List<String>> userClaims = claims.userClaims();

for (Rule rule : rules) {
if (!rule.getSource().equals("roles")) {
return false;
String targetClaim = rule.getSource();
List<String> sources;
if (targetClaim.equals("roles")) {
sources = claims.userRoles();
} else {
sources = userClaims.get(targetClaim);
}

if (sources == null) {
continue;
}

List<String> targets = rule.getTargets();
boolean match = switch (rule.getFunction()) {
case TRUE -> true;
case FALSE -> false;
case EQUAL -> equal(roles, targets);
case CONTAIN -> contain(roles, targets);
case REGEX -> regex(roles, targets);
case EQUAL -> equal(sources, targets);
case CONTAIN -> contain(sources, targets);
case REGEX -> regex(sources, targets);
};

if (!match) {
return false;
if (match) {
return true;
}
}

return true;
return false;
}

private boolean equal(List<String> roles, List<String> targets) {
private boolean equal(List<String> sources, List<String> targets) {
for (String target : targets) {
if (roles.contains(target)) {
if (sources.contains(target)) {
return true;
}
}

return false;
}

private boolean contain(List<String> roles, List<String> targets) {
private boolean contain(List<String> sources, List<String> targets) {
for (String target : targets) {
for (String role : roles) {
for (String role : sources) {
if (role.contains(target)) {
return true;
}
Expand All @@ -66,10 +76,10 @@ private boolean contain(List<String> roles, List<String> targets) {
return false;
}

private boolean regex(List<String> roles, List<String> targets) {
private boolean regex(List<String> sources, List<String> targets) {
for (String target : targets) {
Pattern pattern = Pattern.compile(target);
for (String role : roles) {
for (String role : sources) {
if (pattern.matcher(role).matches()) {
return true;
}
Expand Down
3 changes: 2 additions & 1 deletion src/test/java/com/epam/aidial/core/ResourceBaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
Expand Down Expand Up @@ -144,7 +145,7 @@ void init() throws Exception {
}

if (authorization.equals("user") || authorization.equals("admin")) {
return Future.succeededFuture(new ExtractedClaims(authorization, List.of(authorization), authorization));
return Future.succeededFuture(new ExtractedClaims(authorization, List.of(authorization), authorization, Map.of()));
}

return Future.failedFuture("Not authorized");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ public void testLimit_User_LimitFound() {
config.getRoles().put("role2", role2);

ApiKeyData apiKeyData = new ApiKeyData();
ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash"), "trace-id", "span-id");
ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData,
new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash", Map.of()), "trace-id", "span-id");
Model model = new Model();
model.setName("model");
proxyContext.setDeployment(model);
Expand Down Expand Up @@ -386,7 +387,8 @@ public void testLimit_User_DefaultLimit() {
Config config = new Config();

ApiKeyData apiKeyData = new ApiKeyData();
ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1"), "user-hash"), "trace-id", "span-id");
ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData,
new ExtractedClaims("sub", List.of("role1"), "user-hash", Map.of()), "trace-id", "span-id");
Model model = new Model();
model.setName("model");
proxyContext.setDeployment(model);
Expand Down Expand Up @@ -441,7 +443,8 @@ public void testLimit_User_RequestLimit() {
config.getRoles().put("role2", role2);

ApiKeyData apiKeyData = new ApiKeyData();
ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData, new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash"), "trace-id", "span-id");
ProxyContext proxyContext = new ProxyContext(config, request, apiKeyData,
new ExtractedClaims("sub", List.of("role1", "role2"), "user-hash", Map.of()), "trace-id", "span-id");
Model model = new Model();
model.setName("model");
proxyContext.setDeployment(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.security.interfaces.RSAPublicKey;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
Expand Down Expand Up @@ -105,7 +106,7 @@ public void testExtractClaims_05() throws NoSuchAlgorithmException {
when(provider1.match(any(DecodedJWT.class))).thenReturn(false);
IdentityProvider provider2 = mock(IdentityProvider.class);
when(provider2.match(any(DecodedJWT.class))).thenReturn(true);
when(provider2.extractClaims(any(DecodedJWT.class))).thenReturn(Future.succeededFuture(new ExtractedClaims("sub", Collections.emptyList(), "hash")));
when(provider2.extractClaims(any(DecodedJWT.class))).thenReturn(Future.succeededFuture(new ExtractedClaims("sub", Collections.emptyList(), "hash", Map.of())));
List<IdentityProvider> providerList = List.of(provider1, provider2);
validator.setProviders(providerList);
KeyPair keyPair = generateRsa256Pair();
Expand All @@ -127,7 +128,7 @@ public void testExtractClaims_05() throws NoSuchAlgorithmException {
public void testExtractClaims_06() throws NoSuchAlgorithmException {
AccessTokenValidator validator = new AccessTokenValidator(idpConfig, vertx);
IdentityProvider provider = mock(IdentityProvider.class);
when(provider.extractClaims(any(DecodedJWT.class))).thenReturn(Future.succeededFuture(new ExtractedClaims("sub", Collections.emptyList(), "hash")));
when(provider.extractClaims(any(DecodedJWT.class))).thenReturn(Future.succeededFuture(new ExtractedClaims("sub", Collections.emptyList(), "hash", Map.of())));
List<IdentityProvider> providerList = List.of(provider);
validator.setProviders(providerList);
KeyPair keyPair = generateRsa256Pair();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -320,6 +321,50 @@ public void testExtractClaims_12() {
});
}

@Test
public void testExtractClaims_13() {
settings.put("disableJwtVerification", Boolean.TRUE);
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, url -> jwkProvider);
Algorithm algorithm = Algorithm.RSA256((RSAPublicKey) keyPair.getPublic(), (RSAPrivateKey) keyPair.getPrivate());

String token = JWT.create().withHeader(Map.of("kid", "kid1"))
.withClaim("roles", List.of("role"))
.withClaim("email", "test@email.com")
.withClaim("id", 15)
.withClaim("title", "title")
.withClaim("access", List.of("read", "write"))
.withClaim("expire", new Date(1713355825858L))
.withClaim("numberList", List.of(15, 17, 34))
.withClaim("map", Map.of("a", List.of("b")))
.withClaim("sub", "sub").sign(algorithm);

Future<ExtractedClaims> result = identityProvider.extractClaims(JWT.decode(token));

verifyNoInteractions(jwkProvider);

assertNotNull(result);
result.onComplete(res -> {
assertTrue(res.succeeded());
ExtractedClaims claims = res.result();
assertNotNull(claims);
assertEquals(List.of("role"), claims.userRoles());
assertEquals("sub", claims.sub());
assertNotNull(claims.userHash());
Map<String, List<String>> userClaims = claims.userClaims();
// assert user claim
assertEquals(9, userClaims.size());
assertEquals(List.of("sub"), userClaims.get("sub"));
assertEquals(List.of("read", "write"), userClaims.get("access"));
assertEquals(List.of("role"), userClaims.get("roles"));
assertEquals(List.of(), userClaims.get("expire"));
assertEquals(List.of("15", "17", "34"), userClaims.get("numberList"));
assertEquals(List.of(), userClaims.get("id"));
assertEquals(List.of("title"), userClaims.get("title"));
assertEquals(List.of(), userClaims.get("map"));
assertEquals(List.of("test@email.com"), userClaims.get("email"));
});
}

@Test
public void testMatch_Failure() {
IdentityProvider identityProvider = new IdentityProvider(settings, vertx, url -> jwkProvider);
Expand Down
79 changes: 75 additions & 4 deletions src/test/java/com/epam/aidial/core/security/RuleMatcherTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import org.mockito.Mockito;

import java.util.List;
import java.util.Map;

class RuleMatcherTest {

@Test
void testRules() {
void testUserRoleRules() {
verify(rule("roles", Rule.Function.TRUE), true, "any-role");
verify(rule("roles", Rule.Function.FALSE), false, "any-role");

Expand All @@ -29,14 +30,84 @@ void testRules() {
verify(rule("roles", Rule.Function.REGEX, "(admin|user)$"), false, "user2");
}

void verify(Rule rule, boolean expected, String... roles) {
@Test
void testUserClaimRules() {
verify(List.of(rule("title", Rule.Function.TRUE)),
List.of("user1"), Map.of("title", List.of("Software Engineer")), true);
verify(List.of(rule("title", Rule.Function.FALSE)),
List.of("user2"), Map.of("title", List.of("Software Engineer")), false);

verify(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer")),
List.of("admin"), Map.of("title", List.of("Software Engineer")), true);
verify(List.of(rule("title", Rule.Function.EQUAL, "Engineer")),
List.of("user"), Map.of("title", List.of("Software Engineer")), false);

verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
List.of("admin"), Map.of("email", List.of("foo_bar@example.com")), true);
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
List.of("user"), Map.of("email", List.of("foo_bar@mail.com")), false);
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com")),
List.of("user"), Map.of(), false);

verify(List.of(rule("title", Rule.Function.REGEX, ".*")),
List.of("admin"), Map.of("title", List.of("Developer")), true);
verify(List.of(rule("title", Rule.Function.REGEX, "(Developer|Manager)")),
List.of("user"), Map.of("title", List.of("Manager")), true);
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$")),
List.of("user"), Map.of("title", List.of("Senior Delivery Manager")), true);
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$")),
List.of("user"), Map.of("title", List.of("Manager Senior")), false);
}

@Test
void testCombinedRules() {
verify(List.of(rule("title", Rule.Function.TRUE), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("user1"), Map.of("title", List.of("Software Engineer")), true);
verify(List.of(rule("title", Rule.Function.TRUE), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("dial"), Map.of(), true);
verify(List.of(rule("title", Rule.Function.CONTAIN, "Software"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("custom"), Map.of("title", List.of("System Engineer")), false);

verify(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("admin"), Map.of("title", List.of("Software Engineer")), true);
verify(List.of(rule("title", Rule.Function.EQUAL, "Software Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("dial"), Map.of("title", List.of("Manager")), true);
verify(List.of(rule("title", Rule.Function.EQUAL, "Engineer"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("user"), Map.of("title", List.of("Software Engineer")), false);

verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("admin"), Map.of("email", List.of("foo_bar@example.com")), true);
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("dial"), Map.of("email", List.of("foo_bar@example2.com")), true);
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("user"), Map.of("email", List.of("foo_bar@mail.com")), false);
verify(List.of(rule("email", Rule.Function.CONTAIN, "@example.com"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("user"), Map.of(), false);

verify(List.of(rule("title", Rule.Function.REGEX, ".*"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("admin"), Map.of("title", List.of("Developer")), true);
verify(List.of(rule("title", Rule.Function.REGEX, "^Developer$"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("dial"), Map.of("title", List.of("Manager")), true);
verify(List.of(rule("title", Rule.Function.REGEX, "(Developer|Manager)"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("dial"), Map.of("title", List.of("Human Resource")), true);
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("user"), Map.of("title", List.of("Senior Delivery Manager")), true);
verify(List.of(rule("title", Rule.Function.REGEX, ".*(Manager|Developer)$"), rule("roles", Rule.Function.EQUAL, "dial")),
List.of("user"), Map.of("title", List.of("Manager Senior")), false);
}

void verify(List<Rule> rules, List<String> userRoles, Map<String, List<String>> userClaims, boolean expected) {
ProxyContext context = Mockito.mock(ProxyContext.class);
ExtractedClaims claims = new ExtractedClaims("sub", List.of(roles), "hash");
ExtractedClaims claims = new ExtractedClaims("sub", userRoles, "hash", userClaims);
Mockito.when(context.getExtractedClaims()).thenReturn(claims);
boolean actual = RuleMatcher.match(context, List.of(rule));
boolean actual = RuleMatcher.match(context, rules);
Assertions.assertEquals(expected, actual);
}

void verify(Rule rule, boolean expected, String... roles) {
verify(List.of(rule), List.of(roles), Map.of(), expected);
}

Rule rule(String source, Rule.Function function, String... targets) {
Rule rule = new Rule();
rule.setSource(source);
Expand Down

0 comments on commit 7a3937f

Please sign in to comment.