Skip to content

Commit

Permalink
Expose a JWKS endpoint for gateway and improve backend JWT to include…
Browse files Browse the repository at this point in the history
… kid claim
  • Loading branch information
ashera96 committed May 19, 2023
1 parent fb35b75 commit ed8c84c
Show file tree
Hide file tree
Showing 14 changed files with 523 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ public class JWTConstants {

public static final String SUB = "sub";
public static final String ORGANIZATIONS = "organizations";
public static final String GATEWAY_JWKS_API_CONTEXT = "/jwks";
public static final String GATEWAY_JWKS_API_NAME = "_JwksEndpoint_";
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.nio.charset.Charset;
import java.security.PrivateKey;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.Date;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -104,13 +105,19 @@ public String generateToken(JWTInfoDto jwtInfoDto) throws JWTGeneratorException

public String buildHeader() throws JWTGeneratorException {
String jwtHeader = null;
X509Certificate x509Certificate = (X509Certificate) jwtConfigurationDto.getPublicCert();

if (NONE.equals(signatureAlgorithm)) {
StringBuilder jwtHeaderBuilder = new StringBuilder();
jwtHeaderBuilder.append("{\"typ\":\"JWT\",");
jwtHeaderBuilder.append("\"alg\":\"");
jwtHeaderBuilder.append(JWTUtil.getJWSCompliantAlgorithmCode(NONE));
jwtHeaderBuilder.append('\"');
if (jwtConfigurationDto.useKid()) {
jwtHeaderBuilder.append(",\"kid\":\"");
jwtHeaderBuilder.append(JWTUtil.getKID(x509Certificate));
jwtHeaderBuilder.append("\"");
}
jwtHeaderBuilder.append('}');

jwtHeader = jwtHeaderBuilder.toString();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import java.security.SignatureException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -81,14 +82,14 @@ public static String generateHeader(Certificate publicCert, String signatureAlgo
/**
* Utility method to generate JWT header with public certificate thumbprint for signature verification.
*
* @param publicCert - The public certificate which needs to include in the header as thumbprint
* @param signatureAlgorithm signature algorithm which needs to include in the header
* @param useKid - Specifies whether this function should use kid as a thumbprint or x5t
* @param publicCert The public certificate which needs to include in the header as thumbprint
* @param signatureAlgorithm Signature algorithm which needs to include in the header
* @param useKid Specifies whether the header should include the kid property
* @throws JWTGeneratorException
*/

public static String generateHeader(Certificate publicCert, String signatureAlgorithm, boolean useKid) throws
JWTGeneratorException {
public static String generateHeader(Certificate publicCert, String signatureAlgorithm, boolean useKid)
throws JWTGeneratorException {

/*
* Sample header
Expand All @@ -97,21 +98,34 @@ public static String generateHeader(Certificate publicCert, String signatureAlgo
* {"typ":"JWT", "alg":"[2]", "x5t":"[1]", "x5t":"[1]"}
* */
try {
X509Certificate x509Certificate = (X509Certificate) publicCert;

//generate the SHA-1 thumbprint of the certificate
MessageDigest digestValue = MessageDigest.getInstance("SHA-1");
byte[] der = publicCert.getEncoded();
digestValue.update(der);
byte[] digestInBytes = digestValue.digest();
String publicCertThumbprint = hexify(digestInBytes);
String base64UrlEncodedThumbPrint;
base64UrlEncodedThumbPrint = java.util.Base64.getUrlEncoder()
.encodeToString(publicCertThumbprint.getBytes("UTF-8"));

StringBuilder jwtHeader = new StringBuilder();
jwtHeader.append("{\"typ\":\"JWT\",");
jwtHeader.append("\"alg\":\"");
jwtHeader.append(getJWSCompliantAlgorithmCode(signatureAlgorithm));
jwtHeader.append("\",");
jwtHeader.append("\"x5t\":\"");
jwtHeader.append(base64UrlEncodedThumbPrint);
jwtHeader.append("\"");

if (useKid) {
jwtHeader.append("\"kid\":\"");
// No padding
jwtHeader.append(generateThumbprint("SHA-256", publicCert, false));
} else {
jwtHeader.append("\"x5t\":\"");
// Has padding for legacy support
jwtHeader.append(generateThumbprint("SHA-1", publicCert, true));
jwtHeader.append(",\"kid\":\"");
jwtHeader.append(getKID(x509Certificate));
jwtHeader.append("\"");
}
jwtHeader.append("\"}");

jwtHeader.append("}");
return jwtHeader.toString();
} catch (NoSuchAlgorithmException | CertificateEncodingException | UnsupportedEncodingException e) {
throw new JWTGeneratorException("Error in generating public certificate thumbprint", e);
Expand Down Expand Up @@ -157,6 +171,19 @@ public static String hexify(byte bytes[]) {
return buf.toString();
}

/**
* Helper method to add kid claim into to JWT_HEADER.
*
* @param cert X509 certificate
* @return KID
*/
public static String getKID(X509Certificate cert) {
String serialNumber = cert.getSerialNumber().toString();
String issuerName = cert.getIssuerDN().getName();
String kid = issuerName + "#" + serialNumber;
return java.util.Base64.getUrlEncoder().withoutPadding().encodeToString(kid.getBytes(StandardCharsets.UTF_8));
}

/**
* Utility method to sign a JWT assertion with a particular signature algorithm.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@
import org.apache.commons.logging.LogFactory;
import org.apache.synapse.SynapseConstants;
import org.apache.synapse.transport.dynamicconfigurations.DynamicProfileReloaderHolder;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.api.ExceptionCodes;
import org.wso2.carbon.apimgt.api.gateway.GatewayAPIDTO;
import org.wso2.carbon.apimgt.api.gateway.GatewayContentDTO;
import org.wso2.carbon.apimgt.api.gateway.GraphQLSchemaDTO;
import org.wso2.carbon.apimgt.api.model.API;
import org.wso2.carbon.apimgt.api.model.APIIdentifier;
import org.wso2.carbon.apimgt.api.model.APIProductIdentifier;
import org.wso2.carbon.apimgt.common.gateway.constants.JWTConstants;
import org.wso2.carbon.apimgt.gateway.internal.DataHolder;
import org.wso2.carbon.apimgt.gateway.internal.ServiceReferenceHolder;
import org.wso2.carbon.apimgt.gateway.service.APIGatewayAdmin;
Expand Down Expand Up @@ -194,6 +197,11 @@ public boolean deployAllAPIsAtGatewayStartup(Set<String> assignedGatewayLabels,
throws ArtifactSynchronizerException {

boolean result = false;
try {
deployJWKSSynapseAPI(tenantDomain); // Deploy JWKS API
} catch (APIManagementException e) {
log.error("Error while deploying JWKS API for tenant domain :" + tenantDomain, e);
}

if (gatewayArtifactSynchronizerProperties.isRetrieveFromStorageEnabled()) {
if (artifactRetriever != null) {
Expand Down Expand Up @@ -440,4 +448,56 @@ public void unDeployAPI(String apiName, String version, String tenantDomain) thr
}
}
}

/**
* Deploy Synapse API for JWKS endpoint
*
* @param tenantDomain tenant domain
*/
public static void deployJWKSSynapseAPI(String tenantDomain) throws APIManagementException {
String api = org.wso2.carbon.apimgt.gateway.utils.GatewayUtils.retrieveDeployedAPI(
JWTConstants.GATEWAY_JWKS_API_NAME, null, tenantDomain);
if (api == null) {
try {
// Deploy JWKS API for tenant
MessageContext.setCurrentMessageContext(
org.wso2.carbon.apimgt.gateway.utils.GatewayUtils.createAxis2MessageContext());
PrivilegedCarbonContext.startTenantFlow();
PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantDomain(tenantDomain, true);
GatewayAPIDTO jwksAPIDto = new GatewayAPIDTO();
String jwksApiContext;
if (tenantDomain != null && !APIConstants.SUPER_TENANT_DOMAIN.equals(tenantDomain)) {
jwksApiContext = "/t/" + tenantDomain + JWTConstants.GATEWAY_JWKS_API_CONTEXT;
} else {
jwksApiContext = JWTConstants.GATEWAY_JWKS_API_CONTEXT;
}
String jwksSynapseAPI = "<api xmlns=\"http://ws.apache.org/ns/synapse\" name=\"_JwksEndpoint_\" "
+ "context=\"" + jwksApiContext + "\">\n"
+ " <resource methods=\"GET\" url-mapping=\"/*\" faultSequence=\"fault\">\n"
+ " <inSequence>\n"
+ " <respond/>\n"
+ " </inSequence>\n"
+ " </resource>\n"
+ " <handlers>\n"
+ " <handler class=\"org.wso2.carbon.apimgt.gateway.handlers.common.JwksHandler\"/>\n"
+ " </handlers>\n"
+ "</api>\n";

jwksAPIDto.setName(JWTConstants.GATEWAY_JWKS_API_NAME);
jwksAPIDto.setTenantDomain(tenantDomain);
jwksAPIDto.setApiDefinition(jwksSynapseAPI);

log.info("Deploying synapse artifacts of " + jwksAPIDto.getName());
APIGatewayAdmin apiGatewayAdmin = new APIGatewayAdmin();
apiGatewayAdmin.deployAPI(jwksAPIDto);
DataHolder.getInstance().markAPIAsDeployed(jwksAPIDto);
} catch (AxisFault axisFault) {
throw new APIManagementException("Error while retrieving JWKS API Artifact", axisFault,
ExceptionCodes.INTERNAL_ERROR);
} finally {
MessageContext.destroyCurrentMessageContext();
PrivilegedCarbonContext.endTenantFlow();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.apache.synapse.api.ApiUtils;
import org.apache.synapse.core.axis2.Axis2MessageContext;
import org.apache.synapse.rest.RESTConstants;
import org.wso2.carbon.apimgt.api.APIManagementException;
import org.wso2.carbon.apimgt.common.gateway.constants.JWTConstants;
import org.wso2.carbon.apimgt.gateway.InMemoryAPIDeployer;
import org.wso2.carbon.apimgt.gateway.internal.ServiceReferenceHolder;
import org.wso2.carbon.apimgt.gateway.utils.GatewayUtils;
Expand All @@ -34,9 +36,6 @@
import org.wso2.carbon.apimgt.keymgt.model.entity.API;
import org.wso2.carbon.inbound.endpoint.protocol.websocket.InboundWebsocketConstants;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/**
Expand All @@ -49,10 +48,23 @@ public boolean handleRequestInFlow(MessageContext messageContext) {
if (messageContext.getPropertyKeySet().contains(InboundWebsocketConstants.WEBSOCKET_SUBSCRIBER_PATH)) {
return true;
}
String path = ApiUtils.getFullRequestPath(messageContext);
String tenantDomain = GatewayUtils.getTenantDomain();

// Handle JWKS API calls
if (path.contains(JWTConstants.GATEWAY_JWKS_API_CONTEXT)) {
try {
InMemoryAPIDeployer.deployJWKSSynapseAPI(tenantDomain);
} catch(APIManagementException e){
log.error("Error while deploying JWKS API for tenant domain :" + tenantDomain, e);
}
return true;
}

org.apache.axis2.context.MessageContext axis2MessageContext =
((Axis2MessageContext) messageContext).getAxis2MessageContext();
String path = ApiUtils.getFullRequestPath(messageContext);
TreeMap<String, API> selectedAPIS = Utils.getSelectedAPIList(path, GatewayUtils.getTenantDomain());
TreeMap<String, API> selectedAPIS = Utils.getSelectedAPIList(path, tenantDomain);

if (selectedAPIS.size() > 0) {
Object transportInUrl = axis2MessageContext.getProperty(APIConstants.TRANSPORT_URL_IN);
String selectedPath = selectedAPIS.firstKey();
Expand Down
Loading

0 comments on commit ed8c84c

Please sign in to comment.