Skip to content

Commit

Permalink
Merge pull request #2621 from Thumimku/clean-oidc-session-v1
Browse files Browse the repository at this point in the history
[Clean Up] Clean OIDC session code v1
  • Loading branch information
Thumimku authored Nov 11, 2024
2 parents adc38b6 + 6303c3c commit f4d499f
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 100 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
*/
public class ClaimProviderImpl implements ClaimProvider {

private static final Log log = LogFactory.getLog(ClaimProviderImpl.class);
private static final Log LOG = LogFactory.getLog(ClaimProviderImpl.class);

@Override
public Map<String, Object> getAdditionalClaims(OAuthAuthzReqMessageContext oAuthAuthzReqMessageContext,
Expand All @@ -61,15 +61,11 @@ public Map<String, Object> getAdditionalClaims(OAuthAuthzReqMessageContext oAuth
if (previousSession == null) {
// If there is no previous browser session, generate new sid value.
claimValue = UUID.randomUUID().toString();
if (log.isDebugEnabled()) {
log.debug("sid claim is generated for auth request. ");
}
LOG.debug("sid claim is generated for auth request.");
} else {
// Previous browser session exists, get sid claim from OIDCSessionState.
claimValue = previousSession.getSidClaim();
if (log.isDebugEnabled()) {
log.debug("sid claim is found in the session state");
}
LOG.debug("sid claim is found in the session state.");
}
additionalClaims.put(OAuthConstants.OIDCClaims.SESSION_ID_CLAIM, claimValue);
oAuth2AuthorizeRespDTO.setOidcSessionId(claimValue);
Expand Down Expand Up @@ -104,16 +100,12 @@ public Map<String, Object> getAdditionalClaims(OAuthTokenReqMessageContext oAuth
claimValue = previousSession.getSidClaim();
}
} else {
if (log.isDebugEnabled()) {
log.debug("AccessCode is null. Possibly a back end grant");
}
LOG.debug("AccessCode is null. Possibly a back end grant");
return additionalClaims;
}

if (claimValue != null) {
if (log.isDebugEnabled()) {
log.debug("sid claim is found in the session state");
}
LOG.debug("sid claim is found in the session state");
additionalClaims.put("sid", claimValue);
}
return additionalClaims;
Expand All @@ -122,7 +114,7 @@ public Map<String, Object> getAdditionalClaims(OAuthTokenReqMessageContext oAuth
/**
* Return previousSessionState using opbs cookie.
*
* @param oAuthAuthzReqMessageContext
* @param oAuthAuthzReqMessageContext OAuthAuthzReqMessageContext.
* @return OIDCSession state
*/
private OIDCSessionState getSessionState(OAuthAuthzReqMessageContext oAuthAuthzReqMessageContext) {
Expand All @@ -131,10 +123,9 @@ private OIDCSessionState getSessionState(OAuthAuthzReqMessageContext oAuthAuthzR
if (cookies != null) {
for (Cookie cookie : cookies) {
if (OIDCSessionConstants.OPBS_COOKIE_ID.equals(cookie.getName())) {
OIDCSessionState previousSessionState = OIDCSessionManagementUtil.getSessionManager()
return OIDCSessionManagementUtil.getSessionManager()
.getOIDCSessionState(cookie.getValue(), oAuthAuthzReqMessageContext.
getAuthorizationReqDTO().getLoggedInTenantDomain());
return previousSessionState;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.apache.commons.logging.LogFactory;
import org.wso2.carbon.base.MultitenantConstants;
import org.wso2.carbon.core.util.KeyStoreManager;
import org.wso2.carbon.identity.application.common.model.IdentityProvider;
import org.wso2.carbon.identity.core.ServiceURLBuilder;
import org.wso2.carbon.identity.core.URLBuilderException;
import org.wso2.carbon.identity.core.util.IdentityTenantUtil;
Expand All @@ -42,8 +41,6 @@
import org.wso2.carbon.identity.oidc.session.OIDCSessionConstants;
import org.wso2.carbon.identity.oidc.session.OIDCSessionState;
import org.wso2.carbon.identity.oidc.session.util.OIDCSessionManagementUtil;
import org.wso2.carbon.idp.mgt.IdentityProviderManagementException;
import org.wso2.carbon.idp.mgt.IdentityProviderManager;
import org.wso2.carbon.utils.security.KeystoreUtils;

import java.security.interfaces.RSAPublicKey;
Expand All @@ -70,7 +67,7 @@
*/
public class DefaultLogoutTokenBuilder implements LogoutTokenBuilder {

private static final Log log = LogFactory.getLog(DefaultLogoutTokenBuilder.class);
private static final Log LOG = LogFactory.getLog(DefaultLogoutTokenBuilder.class);
private OAuthServerConfiguration config = null;
private JWSAlgorithm signatureAlgorithm = null;
private static final String OPENID_IDP_ENTITY_ID = "IdPEntityId";
Expand All @@ -85,6 +82,7 @@ public DefaultLogoutTokenBuilder() throws IdentityOAuth2Exception {
}

@Override
@Deprecated
public Map<String, String> buildLogoutToken(HttpServletRequest request)
throws IdentityOAuth2Exception, InvalidOAuthClientException {

Expand All @@ -102,8 +100,8 @@ public Map<String, String> buildLogoutToken(HttpServletRequest request)
try {
oAuthAppDO = getOAuthAppDO(clientID);
} catch (InvalidOAuthClientException e) {
if (log.isDebugEnabled()) {
log.debug("The application with client id: " + clientID
if (LOG.isDebugEnabled()) {
LOG.debug("The application with client id: " + clientID
+ " does not exists. This application may be deleted after"
+ " this session is created. So skipping it in logout token list.", e);
}
Expand Down Expand Up @@ -154,8 +152,8 @@ private void addToLogoutTokenList(Map<String, String> logoutTokenList,
try {
oAuthAppDO = getOAuthAppDO(clientID);
} catch (InvalidOAuthClientException e) {
if (log.isDebugEnabled()) {
log.debug("The application with client id: " + clientID
if (LOG.isDebugEnabled()) {
LOG.debug("The application with client id: " + clientID
+ " does not exists. This application may be deleted after"
+ " this session is created. So skipping it in logout token list.", e);
}
Expand All @@ -169,8 +167,8 @@ private void addToLogoutTokenList(Map<String, String> logoutTokenList,
getSigningTenantDomain(oAuthAppDO)).serialize();
logoutTokenList.put(logoutToken, backChannelLogoutUrl);

if (log.isDebugEnabled()) {
log.debug("Logout token created for the client: " + clientID);
if (LOG.isDebugEnabled()) {
LOG.debug("Logout token created for the client: " + clientID);
}
}
}
Expand Down Expand Up @@ -231,21 +229,21 @@ private String getClientId(HttpServletRequest request, String tenantDomain)
JWT decryptedIDToken = OIDCSessionManagementUtil.decryptWithRSA(tenantDomain, idToken);
clientId = OIDCSessionManagementUtil.extractClientIDFromDecryptedIDToken(decryptedIDToken);
} catch (ParseException e) {
if (log.isDebugEnabled()) {
log.debug("Error in extracting the client ID from the ID token : " + idToken);
if (LOG.isDebugEnabled()) {
LOG.debug("Error in extracting the client ID from the ID token : " + idToken);
}
}
return clientId;
}
clientId = getClientIdFromIDTokenHint(idToken);
} else {
log.debug("IdTokenHint is not found in the request ");
LOG.debug("IdTokenHint is not found in the request ");
return null;
}
if (validateIdTokenHint(clientId, idToken)) {
return clientId;
} else {
log.debug("Id Token is not valid");
LOG.debug("Id Token is not valid");
return null;
}
}
Expand Down Expand Up @@ -311,16 +309,6 @@ private String getSidClaim(OIDCSessionState sessionState) {
return sidClaim;
}

private IdentityProvider getResidentIdp(String tenantDomain) throws IdentityOAuth2Exception {

try {
return IdentityProviderManager.getInstance().getResidentIdP(tenantDomain);
} catch (IdentityProviderManagementException e) {
String errorMsg = String.format(ERROR_GET_RESIDENT_IDP, tenantDomain);
throw new IdentityOAuth2Exception(errorMsg, e);
}
}

/**
* Returning issuer of the tenant domain.
*
Expand Down Expand Up @@ -428,8 +416,8 @@ private String getClientIdFromIDTokenHint(String idTokenHint) {
try {
clientId = extractClientFromIdToken(idTokenHint);
} catch (ParseException e) {
if (log.isDebugEnabled()) {
log.debug("Error while decoding the ID Token Hint: " + idTokenHint, e);
if (LOG.isDebugEnabled()) {
LOG.debug("Error while decoding the ID Token Hint: " + idTokenHint, e);
}
}
}
Expand Down Expand Up @@ -482,12 +470,12 @@ private Boolean validateIdTokenHint(String clientId, String idToken) throws Iden

return signedJWT.verify(verifier);
} catch (JOSEException | ParseException e) {
if (log.isDebugEnabled()) {
log.debug("Error occurred while validating id token signature.", e);
if (LOG.isDebugEnabled()) {
LOG.debug("Error occurred while validating id token signature.", e);
}
return false;
} catch (Exception e) {
log.error("Error occurred while validating id token signature.", e);
LOG.error("Error occurred while validating id token signature.", e);
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
*/
public class LogoutRequestSender {

private static final Log log = LogFactory.getLog(LogoutRequestSender.class);
private static final Log LOG = LogFactory.getLog(LogoutRequestSender.class);
private static ExecutorService threadPool = Executors.newFixedThreadPool(2);
private static LogoutRequestSender instance = new LogoutRequestSender();
private static final String LOGOUT_TOKEN = "logout_token";
Expand Down Expand Up @@ -85,7 +85,7 @@ public void sendLogoutRequests(HttpServletRequest request) {
if (opbsCookie != null) {
sendLogoutRequests(opbsCookie.getValue());
} else {
log.error("No opbscookie exists in the request");
LOG.error("No opbscookie exists in the request");
}
}

Expand Down Expand Up @@ -118,11 +118,8 @@ public void sendLogoutRequests(String opbsCookieId, String tenantDomain) {
for (Map.Entry<String, String> logoutTokenMap : logoutTokenList.entrySet()) {
String logoutToken = logoutTokenMap.getKey();
String bcLogoutUrl = logoutTokenMap.getValue();
LOG.debug("A logoutReqSenderTask will be assigned to the thread pool");
threadPool.submit(new LogoutReqSenderTask(logoutToken, bcLogoutUrl));
if (log.isDebugEnabled()) {
log.debug("A logoutReqSenderTask is assigned to the thread pool");

}
}
}
}
Expand All @@ -140,10 +137,10 @@ private Map<String, String> getLogoutTokenList(String opbsCookie, String tenantD
DefaultLogoutTokenBuilder logoutTokenBuilder = new DefaultLogoutTokenBuilder();
logoutTokenList = logoutTokenBuilder.buildLogoutToken(opbsCookie, tenantDomain);
} catch (IdentityOAuth2Exception e) {
log.error("Error while initializing " + DefaultLogoutTokenBuilder.class, e);
LOG.error("Error while initializing " + DefaultLogoutTokenBuilder.class, e);
} catch (InvalidOAuthClientException e) {
if (log.isDebugEnabled()) {
log.debug("Error while obtaining logout token list for the obpsCookie: " + opbsCookie +
if (LOG.isDebugEnabled()) {
LOG.debug("Error while obtaining logout token list for the obpsCookie: " + opbsCookie +
"& tenant domain: " + tenantDomain, e);
}
}
Expand All @@ -169,8 +166,8 @@ public LogoutReqSenderTask(String logoutToken, String backChannelLogouturl) {
@Override
public void run() {

if (log.isDebugEnabled()) {
log.debug("Starting backchannel logout request to: " + backChannelLogouturl);
if (LOG.isDebugEnabled()) {
LOG.debug("Starting backchannel logout request to: " + backChannelLogouturl);
}

List<NameValuePair> logoutReqParams = new ArrayList<NameValuePair>();
Expand All @@ -194,15 +191,15 @@ public void run() {
try {
httpPost.setEntity(new UrlEncodedFormEntity(logoutReqParams));
} catch (UnsupportedEncodingException e) {
log.error("Error while sending logout token", e);
LOG.error("Error while sending logout token", e);
}
HttpResponse response = httpClient.execute(httpPost);
if (log.isDebugEnabled()) {
log.debug("Backchannel logout response: " + response.getStatusLine());
if (LOG.isDebugEnabled()) {
LOG.debug("Backchannel logout response: " + response.getStatusLine());
}

} catch (IOException e) {
log.error("Error sending logout requests to: " + backChannelLogouturl, e);
LOG.error("Error sending logout requests to: " + backChannelLogouturl, e);
}
}
}
Expand Down
Loading

0 comments on commit f4d499f

Please sign in to comment.