diff --git a/ranger-authn/pom.xml b/ranger-authn/pom.xml index 9c0ad8dc59..3d75939692 100644 --- a/ranger-authn/pom.xml +++ b/ranger-authn/pom.xml @@ -86,6 +86,12 @@ ${nimbus-jose-jwt.version} + + org.springframework.security + spring-security-core + ${springframework.security.version} + + org.junit.jupiter diff --git a/ranger-authn/src/main/java/org/apache/ranger/authz/authority/JwtAuthority.java b/ranger-authn/src/main/java/org/apache/ranger/authz/authority/JwtAuthority.java new file mode 100644 index 0000000000..d184324009 --- /dev/null +++ b/ranger-authn/src/main/java/org/apache/ranger/authz/authority/JwtAuthority.java @@ -0,0 +1,25 @@ +package org.apache.ranger.authz.authority; + +import java.util.Set; +import org.springframework.security.core.GrantedAuthority; + +public final class JwtAuthority implements GrantedAuthority { + private static final long serialVersionUID = 12323L; + private final String role; + private final Set groups; + + public JwtAuthority(String role, Set groups) { + this.role = role; + this.groups = groups; + } + + public String getAuthority() { + return this.role; + } + + public Set getGroups() { return this.groups; } + + public String toString() { + return this.role; + } +} \ No newline at end of file diff --git a/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java b/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java index 17063cedfb..7a9cdfa247 100644 --- a/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java +++ b/ranger-authn/src/main/java/org/apache/ranger/authz/handler/jwt/RangerJwtAuthHandler.java @@ -20,6 +20,8 @@ import java.net.URL; import java.text.ParseException; +import java.util.Set; +import java.util.HashSet; import java.util.Arrays; import java.util.Date; import java.util.List; @@ -43,12 +45,14 @@ import com.nimbusds.jose.proc.JWSVerificationKeySelector; import com.nimbusds.jose.proc.SecurityContext; import com.nimbusds.jwt.SignedJWT; +import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.proc.ConfigurableJWTProcessor; public abstract class RangerJwtAuthHandler implements RangerAuthHandler { private static final Logger LOG = LoggerFactory.getLogger(RangerJwtAuthHandler.class); private JWSVerifier verifier = null; + protected SignedJWT signedJWT = null; private String jwksProviderUrl = null; public static final String TYPE = "ranger-jwt"; // Constant that identifies the authentication mechanism. public static final String KEY_PROVIDER_URL = "jwks.provider-url"; // JWKS provider URL @@ -56,7 +60,10 @@ public abstract class RangerJwtAuthHandler implements RangerAuthHandler { public static final String KEY_JWT_COOKIE_NAME = "jwt.cookie-name"; // JWT cookie name public static final String KEY_JWT_AUDIENCES = "jwt.audiences"; public static final String JWT_AUTHZ_PREFIX = "Bearer "; + public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM = "custom.jwt.claim.group.key"; + public static final String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT = "knox.groups"; + public String CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = null; protected List audiences = null; protected JWKSource keySource = null; @@ -76,6 +83,7 @@ public void initialize(final Properties config) throws Exception { // optional configurations String pemPublicKey = config.getProperty(KEY_JWT_PUBLIC_KEY); + CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE = config.getProperty(CUSTOM_JWT_CLAIM_GROUP_KEY_PARAM, CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE_DEFAULT); // setup JWT provider public key if configured if (StringUtils.isNotBlank(pemPublicKey)) { @@ -112,30 +120,36 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str if (StringUtils.isNotBlank(serializedJWT)) { try { - final SignedJWT jwtToken = SignedJWT.parse(serializedJWT); - boolean valid = validateToken(jwtToken); + signedJWT = SignedJWT.parse(serializedJWT); + JWTClaimsSet claimsSet = getJWTClaimsSet(); + + if(LOG.isDebugEnabled()){ + LOG.debug("RangerJwtAuthHandler.authenticate(): JWTClaimsSet - {}", claimsSet); + } + + boolean valid = validateToken(); if (valid) { String userName; if (StringUtils.isNotBlank(doAsUser)) { userName = doAsUser.trim(); } else { - userName = jwtToken.getJWTClaimsSet().getSubject(); + userName = claimsSet.getSubject(); } if (LOG.isDebugEnabled()) { LOG.debug("RangerJwtAuthHandler.authenticate(): Issuing AuthenticationToken for user: [{}]", userName); - LOG.debug("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]", jwtToken.getJWTClaimsSet().getSubject(), doAsUser); + LOG.debug("RangerJwtAuthHandler.authenticate(): Authentication successful for user [{}] and doAs user is [{}]", claimsSet.getSubject(), doAsUser); } token = new AuthenticationToken(userName, userName, TYPE); } else { - LOG.warn("RangerJwtAuthHandler.authenticate(): Validation failed for JWT token: [{}] ", jwtToken.serialize()); + LOG.warn("RangerJwtAuthHandler.authenticate(): Validation failed for JWT: [{}] ", signedJWT.serialize()); } - } catch (ParseException pe) { - LOG.warn("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT token", pe); + } catch (ParseException | RuntimeException exp) { + LOG.warn("RangerJwtAuthHandler.authenticate(): Unable to parse the JWT", exp); } } else { - LOG.warn("RangerJwtAuthHandler.authenticate(): JWT token not found."); + LOG.warn("RangerJwtAuthHandler.authenticate(): JWT not found"); } } @@ -145,6 +159,31 @@ protected AuthenticationToken authenticate(final String jwtAuthHeader, final Str return token; } + + protected JWTClaimsSet getJWTClaimsSet() throws ParseException { + return signedJWT.getJWTClaimsSet(); + } + + public Set getGroupsFromClaimSet() { + List groupsClaim = null; + try { + groupsClaim = (List) getJWTClaimsSet().getClaim(CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE); + } catch (ParseException e) { + LOG.error("Unable to parse JWT claim set", e); + } + + if (groupsClaim == null) { + LOG.warn("No group claim found!"); + return new HashSet<>(); + } + + Set groups = new HashSet<>(groupsClaim); + if (LOG.isDebugEnabled()) { + LOG.debug("Groups present in Claim [{}]: {}", CUSTOM_JWT_CLAIM_GROUP_KEY_VALUE, groups); + } + return groups; + } + protected String getJWT(final String jwtAuthHeader, final String jwtCookie) { String serializedJWT = null; @@ -171,19 +210,18 @@ protected String getJWT(final String jwtAuthHeader, final String jwtCookie) { * implementation through submethods used within but also allows for the * override of the entire token validation algorithm. * - * @param jwtToken the token to validate * @return true if valid */ - protected boolean validateToken(final SignedJWT jwtToken) { - boolean expValid = validateExpiration(jwtToken); + protected boolean validateToken() throws ParseException { + boolean expValid = validateExpiration(); boolean sigValid = false; boolean audValid = false; if (expValid) { - sigValid = validateSignature(jwtToken); + sigValid = validateSignature(); if (sigValid) { - audValid = validateAudiences(jwtToken); + audValid = validateAudiences(); } } @@ -195,41 +233,40 @@ protected boolean validateToken(final SignedJWT jwtToken) { } /** - * Verify the signature of the JWT token in this method. This method depends on + * Verify the signature of the JWT in this method. This method depends on * the public key that was established during init based upon the provisioned * public key. Override this method in subclasses in order to customize the * signature verification behavior. * - * @param jwtToken the token that contains the signature to be validated * @return valid true if signature verifies successfully; false otherwise */ - protected boolean validateSignature(final SignedJWT jwtToken) { + protected boolean validateSignature() { boolean valid = false; - if (JWSObject.State.SIGNED == jwtToken.getState()) { + if (JWSObject.State.SIGNED == signedJWT.getState()) { if (LOG.isDebugEnabled()) { - LOG.debug("JWT token is in a SIGNED state"); + LOG.debug("JWT is in a SIGNED state"); } - if (jwtToken.getSignature() != null) { + if (signedJWT.getSignature() != null) { try { if (StringUtils.isNotBlank(jwksProviderUrl)) { - JWSKeySelector keySelector = new JWSVerificationKeySelector<>(jwtToken.getHeader().getAlgorithm(), keySource); + JWSKeySelector keySelector = new JWSVerificationKeySelector<>(signedJWT.getHeader().getAlgorithm(), keySource); // Create a JWT processor for the access tokens ConfigurableJWTProcessor jwtProcessor = getJwtProcessor(keySelector); // Process the token - jwtProcessor.process(jwtToken, null); + jwtProcessor.process(signedJWT, null); valid = true; if (LOG.isDebugEnabled()) { - LOG.debug("JWT token has been successfully verified."); + LOG.debug("JWT has been successfully verified."); } } else if (verifier != null) { - if (jwtToken.verify(verifier)) { + if (signedJWT.verify(verifier)) { valid = true; if (LOG.isDebugEnabled()) { - LOG.debug("JWT token has been successfully verified."); + LOG.debug("JWT has been successfully verified."); } } else { LOG.warn("JWT signature verification failed."); @@ -257,61 +294,51 @@ protected boolean validateSignature(final SignedJWT jwtToken) { * token claims list for audience. Override this method in subclasses in order * to customize the audience validation behavior. * - * @param jwtToken the JWT token where the allowed audiences will be found * @return true if an expected audience is present, otherwise false */ - protected boolean validateAudiences(final SignedJWT jwtToken) { + protected boolean validateAudiences() throws ParseException { boolean valid = false; - try { - List tokenAudienceList = jwtToken.getJWTClaimsSet().getAudience(); - // if there were no expected audiences configured then just - // consider any audience acceptable - if (audiences == null) { - valid = true; - } else { - // if any of the configured audiences is found then consider it - // acceptable - for (String aud : tokenAudienceList) { - if (audiences.contains(aud)) { - if (LOG.isDebugEnabled()) { - LOG.debug("JWT token audience has been successfully validated."); - } - valid = true; - break; + JWTClaimsSet claimsSet = getJWTClaimsSet(); + List tokenAudienceList = claimsSet.getAudience(); + // if there were no expected audiences configured then just consider any audience acceptable + if (audiences == null) { + valid = true; + } else { + // if any of the configured audiences is found then consider it acceptable + for (String aud : tokenAudienceList) { + if (audiences.contains(aud)) { + if (LOG.isDebugEnabled()) { + LOG.debug("JWT audience has been successfully validated."); } - } - if (!valid) { - LOG.warn("JWT audience validation failed."); + valid = true; + break; } } - } catch (ParseException pe) { - LOG.warn("Unable to parse the JWT token.", pe); + } + + if (!valid) { + LOG.warn("JWT audience validation failed."); } return valid; } /** - * Validate that the expiration time of the JWT token has not been violated. If - * it has then throw an AuthenticationException. Override this method in + * Validate that the expiration time of the JWT has not been violated. If + * it has, then throw an AuthenticationException. Override this method in * subclasses in order to customize the expiration validation behavior. * - * @param jwtToken the token that contains the expiration date to validate * @return valid true if the token has not expired; false otherwise */ - protected boolean validateExpiration(final SignedJWT jwtToken) { + protected boolean validateExpiration() throws ParseException { boolean valid = false; - try { - Date expires = jwtToken.getJWTClaimsSet().getExpirationTime(); - if (expires == null || new Date().before(expires)) { - valid = true; - if (LOG.isDebugEnabled()) { - LOG.debug("JWT token expiration date has been successfully validated."); - } - } else { - LOG.warn("JWT token provided is expired."); + Date expires = getJWTClaimsSet().getExpirationTime(); + if (expires == null || new Date().before(expires)) { + valid = true; + if (LOG.isDebugEnabled()) { + LOG.debug("JWT expiration date has been successfully validated."); } - } catch (ParseException pe) { - LOG.warn("Failed to validate JWT expiry.", pe); + } else { + LOG.warn("JWT provided has expired."); } return valid; diff --git a/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java b/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java index f14adaaa8d..27acdd7dd7 100644 --- a/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java +++ b/security-admin/src/main/java/org/apache/ranger/security/web/filter/RangerJwtAuthFilter.java @@ -19,6 +19,7 @@ package org.apache.ranger.security.web.filter; import java.io.IOException; +import java.util.Set; import java.util.Arrays; import java.util.List; import java.util.Properties; @@ -33,6 +34,7 @@ import javax.servlet.http.HttpServletRequest; import org.apache.log4j.Logger; +import org.apache.ranger.authz.authority.JwtAuthority; import org.apache.ranger.authz.handler.RangerAuth; import org.apache.ranger.authz.handler.jwt.RangerDefaultJwtAuthHandler; import org.apache.ranger.authz.handler.jwt.RangerJwtAuthHandler; @@ -42,7 +44,6 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.GrantedAuthority; -import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.userdetails.User; import org.springframework.security.core.userdetails.UserDetails; @@ -101,7 +102,8 @@ public void doFilter(ServletRequest request, ServletResponse response, FilterCha RangerAuth rangerAuth = authenticate(httpServletRequest); if (rangerAuth != null) { - final List grantedAuths = Arrays.asList(new SimpleGrantedAuthority(DEFAULT_RANGER_ROLE)); + final Set groups = getGroupsFromClaimSet(); + final List grantedAuths = Arrays.asList(new JwtAuthority(DEFAULT_RANGER_ROLE, groups)); final UserDetails principal = new User(rangerAuth.getUserName(), "", grantedAuths); final Authentication finalAuthentication = new UsernamePasswordAuthenticationToken(principal, "", grantedAuths); final WebAuthenticationDetails webDetails = new WebAuthenticationDetails(httpServletRequest);