diff --git a/ranger-authn/pom.xml b/ranger-authn/pom.xml
index 9c0ad8dc59..3d75939692 100644
--- a/ranger-authn/pom.xml
+++ b/ranger-authn/pom.xml
@@ -86,6 +86,12 @@
${nimbus-jose-jwt.version}
+
+ org.springframework.security
+ spring-security-core
+ ${springframework.security.version}
+
+
org.junit.jupiter
diff --git a/ranger-authn/src/main/java/org/apache/ranger/authz/authority/JwtAuthority.java b/ranger-authn/src/main/java/org/apache/ranger/authz/authority/JwtAuthority.java
new file mode 100644
index 0000000000..d184324009
--- /dev/null
+++ b/ranger-authn/src/main/java/org/apache/ranger/authz/authority/JwtAuthority.java
@@ -0,0 +1,25 @@
+package org.apache.ranger.authz.authority;
+
+import java.util.Set;
+import org.springframework.security.core.GrantedAuthority;
+
+public final class JwtAuthority implements GrantedAuthority {
+ private static final long serialVersionUID = 12323L;
+ private final String role;
+ private final Set groups;
+
+ public JwtAuthority(String role, Set groups) {
+ this.role = role;
+ this.groups = groups;
+ }
+
+ public String getAuthority() {
+ return this.role;
+ }
+
+ public Set getGroups() { return this.groups; }
+
+ public String toString() {
+ return this.role;
+ }
+}
\ No newline at end of file
diff --git a/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java b/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java
index 17063cedfb..7a9cdfa247 100644
--- a/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java
+++ b/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java
@@ -20,6 +20,8 @@
import java.net.URL;
import java.text.ParseException;
+import java.util.Set;
+import java.util.HashSet;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
@@ -43,12 +45,14 @@
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.SignedJWT;
+import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
public abstract class RangerJwtAuthHandler implements RangerAuthHandler {
private static final Logger LOG = LoggerFactory.getLogger(RangerJwtAuthHandler.class);
private JWSVerifier verifier = null;
+ protected SignedJWT signedJWT = null;
private String jwksProviderUrl = null;
public static final String TYPE = "ranger-jwt"; // Constant that identifies the authentication mechanism.
public static final String KEY_PROVIDER_URL = "jwks.provider-url"; // JWKS provider URL
@@ -56,7 +60,10 @@ public abstract class RangerJwtAuthHandler implements RangerAuthHandler {
public static final String KEY_JWT_COOKIE_NAME = "jwt.cookie-name"; // JWT cookie name
public static final String KEY_JWT_AUDIENCES = "jwt.audiences";
public static final String JWT_AUTHZ_PREFIX = "Bearer ";
+ public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM = "custom.jwt.claim.group.key";
+ public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT = "knox.groups";
+ public String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = null;
protected List audiences = null;
protected JWKSource keySource = null;
@@ -76,6 +83,7 @@ public void initialize(final Properties config) throws Exception {
// optional configurations
String pemPublicKey = config.getProperty(KEY_JWT_PUBLIC_KEY);
+ CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = config.getProperty(CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM, CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT);
// setup JWT provider public key if configured
if (StringUtils.isNotBlank(pemPublicKey)) {
@@ -112,30 +120,36 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str
if (StringUtils.isNotBlank(serializedJWT)) {
try {
- final SignedJWT jwtToken = SignedJWT.parse(serializedJWT);
- boolean valid = validateToken(jwtToken);
+ signedJWT = SignedJWT.parse(serializedJWT);
+ JWTClaimsSet claimsSet = getJWTClaimsSet();
+
+ if(LOG.isDebugEnabled()){
+ LOG.debug("RangerJwtAuthHandler.authenticate(): JWTClaimsSet - {}", claimsSet);
+ }
+
+ boolean valid = validateToken();
if (valid) {
String userName;
if (StringUtils.isNotBlank(doAsUser)) {
userName = doAsUser.trim();
} else {
- userName = jwtToken.getJWTClaimsSet().getSubject();
+ userName = claimsSet.getSubject();
}
if (LOG.isDebugEnabled()) {
LOG.debug("RangerJwtAuthHandler.authenticate(): Issuing AuthenticationToken for user: [{}]", userName);
- LOG.debug("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]", jwtToken.getJWTClaimsSet().getSubject(), doAsUser);
+ LOG.debug("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]", claimsSet.getSubject(), doAsUser);
}
token = new AuthenticationToken(userName, userName, TYPE);
} else {
- LOG.warn("RangerJwtAuthHandler.authenticate(): Validation failed for JWT token: [{}] ", jwtToken.serialize());
+ LOG.warn("RangerJwtAuthHandler.authenticate(): Validation failed for JWT: [{}] ", signedJWT.serialize());
}
- } catch (ParseException pe) {
- LOG.warn("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT token", pe);
+ } catch (ParseException | RuntimeException exp) {
+ LOG.warn("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT", exp);
}
} else {
- LOG.warn("RangerJwtAuthHandler.authenticate(): JWT token not found.");
+ LOG.warn("RangerJwtAuthHandler.authenticate(): JWT not found");
}
}
@@ -145,6 +159,31 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str
return token;
}
+
+ protected JWTClaimsSet getJWTClaimsSet() throws ParseException {
+ return signedJWT.getJWTClaimsSet();
+ }
+
+ public Set getGroupsFromClaimSet() {
+ List groupsClaim = null;
+ try {
+ groupsClaim = (List) getJWTClaimsSet().getClaim(CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE);
+ } catch (ParseException e) {
+ LOG.error("Unable to parse JWT claim set", e);
+ }
+
+ if (groupsClaim == null) {
+ LOG.warn("No group claim found!");
+ return new HashSet<>();
+ }
+
+ Set groups = new HashSet<>(groupsClaim);
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Groups present in Claim [{}]: {}", CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE, groups);
+ }
+ return groups;
+ }
+
protected String getJWT(final String jwtAuthHeader, final String jwtCookie) {
String serializedJWT = null;
@@ -171,19 +210,18 @@ protected String getJWT(final String jwtAuthHeader, final String jwtCookie) {
* implementation through submethods used within but also allows for the
* override of the entire token validation algorithm.
*
- * @param jwtToken the token to validate
* @return true if valid
*/
- protected boolean validateToken(final SignedJWT jwtToken) {
- boolean expValid = validateExpiration(jwtToken);
+ protected boolean validateToken() throws ParseException {
+ boolean expValid = validateExpiration();
boolean sigValid = false;
boolean audValid = false;
if (expValid) {
- sigValid = validateSignature(jwtToken);
+ sigValid = validateSignature();
if (sigValid) {
- audValid = validateAudiences(jwtToken);
+ audValid = validateAudiences();
}
}
@@ -195,41 +233,40 @@ protected boolean validateToken(final SignedJWT jwtToken) {
}
/**
- * Verify the signature of the JWT token in this method. This method depends on
+ * Verify the signature of the JWT in this method. This method depends on
* the public key that was established during init based upon the provisioned
* public key. Override this method in subclasses in order to customize the
* signature verification behavior.
*
- * @param jwtToken the token that contains the signature to be validated
* @return valid true if signature verifies successfully; false otherwise
*/
- protected boolean validateSignature(final SignedJWT jwtToken) {
+ protected boolean validateSignature() {
boolean valid = false;
- if (JWSObject.State.SIGNED == jwtToken.getState()) {
+ if (JWSObject.State.SIGNED == signedJWT.getState()) {
if (LOG.isDebugEnabled()) {
- LOG.debug("JWT token is in a SIGNED state");
+ LOG.debug("JWT is in a SIGNED state");
}
- if (jwtToken.getSignature() != null) {
+ if (signedJWT.getSignature() != null) {
try {
if (StringUtils.isNotBlank(jwksProviderUrl)) {
- JWSKeySelector keySelector = new JWSVerificationKeySelector<>(jwtToken.getHeader().getAlgorithm(), keySource);
+ JWSKeySelector keySelector = new JWSVerificationKeySelector<>(signedJWT.getHeader().getAlgorithm(), keySource);
// Create a JWT processor for the access tokens
ConfigurableJWTProcessor jwtProcessor = getJwtProcessor(keySelector);
// Process the token
- jwtProcessor.process(jwtToken, null);
+ jwtProcessor.process(signedJWT, null);
valid = true;
if (LOG.isDebugEnabled()) {
- LOG.debug("JWT token has been successfully verified.");
+ LOG.debug("JWT has been successfully verified.");
}
} else if (verifier != null) {
- if (jwtToken.verify(verifier)) {
+ if (signedJWT.verify(verifier)) {
valid = true;
if (LOG.isDebugEnabled()) {
- LOG.debug("JWT token has been successfully verified.");
+ LOG.debug("JWT has been successfully verified.");
}
} else {
LOG.warn("JWT signature verification failed.");
@@ -257,61 +294,51 @@ protected boolean validateSignature(final SignedJWT jwtToken) {
* token claims list for audience. Override this method in subclasses in order
* to customize the audience validation behavior.
*
- * @param jwtToken the JWT token where the allowed audiences will be found
* @return true if an expected audience is present, otherwise false
*/
- protected boolean validateAudiences(final SignedJWT jwtToken) {
+ protected boolean validateAudiences() throws ParseException {
boolean valid = false;
- try {
- List tokenAudienceList = jwtToken.getJWTClaimsSet().getAudience();
- // if there were no expected audiences configured then just
- // consider any audience acceptable
- if (audiences == null) {
- valid = true;
- } else {
- // if any of the configured audiences is found then consider it
- // acceptable
- for (String aud : tokenAudienceList) {
- if (audiences.contains(aud)) {
- if (LOG.isDebugEnabled()) {
- LOG.debug("JWT token audience has been successfully validated.");
- }
- valid = true;
- break;
+ JWTClaimsSet claimsSet = getJWTClaimsSet();
+ List tokenAudienceList = claimsSet.getAudience();
+ // if there were no expected audiences configured then just consider any audience acceptable
+ if (audiences == null) {
+ valid = true;
+ } else {
+ // if any of the configured audiences is found then consider it acceptable
+ for (String aud : tokenAudienceList) {
+ if (audiences.contains(aud)) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("JWT audience has been successfully validated.");
}
- }
- if (!valid) {
- LOG.warn("JWT audience validation failed.");
+ valid = true;
+ break;
}
}
- } catch (ParseException pe) {
- LOG.warn("Unable to parse the JWT token.", pe);
+ }
+
+ if (!valid) {
+ LOG.warn("JWT audience validation failed.");
}
return valid;
}
/**
- * Validate that the expiration time of the JWT token has not been violated. If
- * it has then throw an AuthenticationException. Override this method in
+ * Validate that the expiration time of the JWT has not been violated. If
+ * it has, then throw an AuthenticationException. Override this method in
* subclasses in order to customize the expiration validation behavior.
*
- * @param jwtToken the token that contains the expiration date to validate
* @return valid true if the token has not expired; false otherwise
*/
- protected boolean validateExpiration(final SignedJWT jwtToken) {
+ protected boolean validateExpiration() throws ParseException {
boolean valid = false;
- try {
- Date expires = jwtToken.getJWTClaimsSet().getExpirationTime();
- if (expires == null || new Date().before(expires)) {
- valid = true;
- if (LOG.isDebugEnabled()) {
- LOG.debug("JWT token expiration date has been successfully validated.");
- }
- } else {
- LOG.warn("JWT token provided is expired.");
+ Date expires = getJWTClaimsSet().getExpirationTime();
+ if (expires == null || new Date().before(expires)) {
+ valid = true;
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("JWT expiration date has been successfully validated.");
}
- } catch (ParseException pe) {
- LOG.warn("Failed to validate JWT expiry.", pe);
+ } else {
+ LOG.warn("JWT provided has expired.");
}
return valid;
diff --git a/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java b/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java
index f14adaaa8d..27acdd7dd7 100644
--- a/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java
+++ b/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java
@@ -19,6 +19,7 @@
package org.apache.ranger.security.web.filter;
import java.io.IOException;
+import java.util.Set;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
@@ -33,6 +34,7 @@
import javax.servlet.http.HttpServletRequest;
import org.apache.log4j.Logger;
+import org.apache.ranger.authz.authority.JwtAuthority;
import org.apache.ranger.authz.handler.RangerAuth;
import org.apache.ranger.authz.handler.jwt.RangerDefaultJwtAuthHandler;
import org.apache.ranger.authz.handler.jwt.RangerJwtAuthHandler;
@@ -42,7 +44,6 @@
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
-import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.User;
import org.springframework.security.core.userdetails.UserDetails;
@@ -101,7 +102,8 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha
RangerAuth rangerAuth = authenticate(httpServletRequest);
if (rangerAuth != null) {
- final List grantedAuths = Arrays.asList(new SimpleGrantedAuthority(DEFAULT_RANGER_ROLE));
+ final Set groups = getGroupsFromClaimSet();
+ final List grantedAuths = Arrays.asList(new JwtAuthority(DEFAULT_RANGER_ROLE, groups));
final UserDetails principal = new User(rangerAuth.getUserName(), "", grantedAuths);
final Authentication finalAuthentication = new UsernamePasswordAuthenticationToken(principal, "", grantedAuths);
final WebAuthenticationDetails webDetails = new WebAuthenticationDetails(httpServletRequest);