Skip to content

Commit

Permalink
add new option for id token request to require all scope items to be …
Browse files Browse the repository at this point in the history
…present

Signed-off-by: Henry Avetisyan <hga@yahooinc.com>
  • Loading branch information
havetisyan committed Jul 17, 2024
1 parent fc8b1a8 commit a979629
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 54 deletions.
4 changes: 2 additions & 2 deletions clients/go/zts/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1157,9 +1157,9 @@ func (client ZTSClient) PostAccessTokenRequest(request AccessTokenRequest) (*Acc
}
}

func (client ZTSClient) GetOIDCResponse(responseType string, clientId ServiceName, redirectUri string, scope string, state EntityName, nonce EntityName, keyType SimpleName, fullArn *bool, expiryTime *int32, output SimpleName, roleInAudClaim *bool) (*OIDCResponse, string, error) {
func (client ZTSClient) GetOIDCResponse(responseType string, clientId ServiceName, redirectUri string, scope string, state EntityName, nonce EntityName, keyType SimpleName, fullArn *bool, expiryTime *int32, output SimpleName, roleInAudClaim *bool, allScopePresent *bool) (*OIDCResponse, string, error) {
var data *OIDCResponse
url := client.URL + "/oauth2/auth" + encodeParams(encodeStringParam("response_type", string(responseType), ""), encodeStringParam("client_id", string(clientId), ""), encodeStringParam("redirect_uri", string(redirectUri), ""), encodeStringParam("scope", string(scope), ""), encodeStringParam("state", string(state), ""), encodeStringParam("nonce", string(nonce), ""), encodeStringParam("keyType", string(keyType), ""), encodeOptionalBoolParam("fullArn", fullArn), encodeOptionalInt32Param("expiryTime", expiryTime), encodeStringParam("output", string(output), ""), encodeOptionalBoolParam("roleInAudClaim", roleInAudClaim))
url := client.URL + "/oauth2/auth" + encodeParams(encodeStringParam("response_type", string(responseType), ""), encodeStringParam("client_id", string(clientId), ""), encodeStringParam("redirect_uri", string(redirectUri), ""), encodeStringParam("scope", string(scope), ""), encodeStringParam("state", string(state), ""), encodeStringParam("nonce", string(nonce), ""), encodeStringParam("keyType", string(keyType), ""), encodeOptionalBoolParam("fullArn", fullArn), encodeOptionalInt32Param("expiryTime", expiryTime), encodeStringParam("output", string(output), ""), encodeOptionalBoolParam("roleInAudClaim", roleInAudClaim), encodeOptionalBoolParam("allScopePresent", allScopePresent))
resp, err := client.httpGet(url, nil)
if err != nil {
return nil, "", err
Expand Down
1 change: 1 addition & 0 deletions clients/go/zts/zts_schema.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -3307,6 +3307,29 @@ public OIDCResponse getIDToken(String domainName, List<String> roleNames, String
*/
public OIDCResponse getIDToken(String responseType, String clientId, String redirectUri, String scope, String state,
String keyType, Boolean fullArn, Integer expiryTime, boolean ignoreCache) {
return getIDToken(responseType, clientId, redirectUri, scope, state, keyType, fullArn,
expiryTime, false, ignoreCache);
}

/**
* For the specified requester(user/service) return the corresponding Access Token that
* includes the list of roles that the principal has access to in the specified domain
* @param responseType response object type - only id_token is supported for now
* @param clientId name of the audience service name (e.g. sys.auth.gcp)
* @param redirectUri the redirect uri for the request
* @param scope the scope of the request e.g. "openid sports.api:roles.hockey-writers"
* @param state the state component of the location header. could be empty
* @param keyType the private key type to sign the token - possible values are "RSA" or "EC"
* @param fullArn boolean flag indicating whether the groups claim in the token contains only the
* role names or the full names including domains (e.g. sports.api:role.hockey-writers).
* @param expiryTime (optional) specifies that the returned Access must be
* at least valid for specified number of seconds. Pass 0 to use
* server default timeout.
* @param allRolesPresent boolean flag indicating that all roles specifies in the scope must be present
* @return ZTS generated ID Token String. ZTSClientException will be thrown in case of failure
*/
public OIDCResponse getIDToken(String responseType, String clientId, String redirectUri, String scope, String state,
String keyType, Boolean fullArn, Integer expiryTime, Boolean allRolesPresent, boolean ignoreCache) {

// check for required attributes

Expand Down Expand Up @@ -3347,7 +3370,8 @@ public OIDCResponse getIDToken(String responseType, String clientId, String redi
try {
Map<String, List<String>> responseHeaders = new HashMap<>();
oidcResponse = ztsClient.getOIDCResponse(responseType, clientId, redirectUri, scope,
state, Crypto.randomSalt(), keyType, fullArn, expiryTime, "json", false, responseHeaders);
state, Crypto.randomSalt(), keyType, fullArn, expiryTime, "json", false,
allRolesPresent, responseHeaders);

} catch (ResourceException ex) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ public AccessTokenResponse postAccessTokenRequest(String request) throws URISynt
}
}

public OIDCResponse getOIDCResponse(String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim, java.util.Map<String, java.util.List<String>> headers) throws URISyntaxException, IOException {
public OIDCResponse getOIDCResponse(String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim, Boolean allScopePresent, java.util.Map<String, java.util.List<String>> headers) throws URISyntaxException, IOException {
UriTemplateBuilder uriTemplateBuilder = new UriTemplateBuilder(baseUrl, "/oauth2/auth");
URIBuilder uriBuilder = new URIBuilder(uriTemplateBuilder.getUri());
if (responseType != null) {
Expand Down Expand Up @@ -969,6 +969,9 @@ public OIDCResponse getOIDCResponse(String responseType, String clientId, String
if (roleInAudClaim != null) {
uriBuilder.setParameter("roleInAudClaim", String.valueOf(roleInAudClaim));
}
if (allScopePresent != null) {
uriBuilder.setParameter("allScopePresent", String.valueOf(allScopePresent));
}
HttpUriRequest httpUriRequest = RequestBuilder.get()
.setUri(uriBuilder.build())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ public AccessTokenResponse postAccessTokenRequest(String request) {
@Override
public OIDCResponse getOIDCResponse(String responseType, String clientId, String redirectUri, String scope,
String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime,
String output, Boolean roleInAudClaim, Map<String, List<String>> headers)
String output, Boolean roleInAudClaim, Boolean allScopePresent, Map<String, List<String>> headers)
throws URISyntaxException, IOException {

// some exception test cases based on the state value
Expand Down
1 change: 1 addition & 0 deletions core/zts/src/main/java/com/yahoo/athenz/zts/ZTSSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -1086,6 +1086,7 @@ private static Schema build() {
.queryParam("expiryTime", "expiryTime", "Int32", null, "optional expiry period specified in seconds")
.queryParam("output", "output", "SimpleName", null, "optional output format of json")
.queryParam("roleInAudClaim", "roleInAudClaim", "Bool", false, "flag to indicate to include role name in the audience claim only if we have a single role in response")
.queryParam("allScopePresent", "allScopePresent", "Bool", false, "flag to indicate that all requested roles/groups in the scope must be present in the response otherwise return an error")
.output("Location", "location", "String", "return location header with id token")
.auth("", "", true)
.expected("OK")
Expand Down
3 changes: 2 additions & 1 deletion core/zts/src/main/rdl/OAuth.rdli
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ resource AccessTokenResponse POST "/oauth2/token" {
}

// Fetch OAuth OpenID Connect ID Token
resource OIDCResponse GET "/oauth2/auth?response_type={responseType}&client_id={clientId}&redirect_uri={redirectUri}&scope={scope}&state={state}&nonce={nonce}&keyType={keyType}&fullArn={fullArn}&expiryTime={expiryTime}&output={output}&roleInAudClaim={roleInAudClaim}" {
resource OIDCResponse GET "/oauth2/auth?response_type={responseType}&client_id={clientId}&redirect_uri={redirectUri}&scope={scope}&state={state}&nonce={nonce}&keyType={keyType}&fullArn={fullArn}&expiryTime={expiryTime}&output={output}&roleInAudClaim={roleInAudClaim}&allScopePresent={allScopePresent}" {
String responseType; //response type - currently only supporting id tokens - id_token
ServiceName clientId; //client id - must be valid athenz service identity name
String redirectUri; //redirect uri for the response
Expand All @@ -57,6 +57,7 @@ resource OIDCResponse GET "/oauth2/auth?response_type={responseType}&client_id={
Int32 expiryTime (optional); //optional expiry period specified in seconds
SimpleName output (optional); //optional output format of json
Bool roleInAudClaim (optional, default=false); //flag to indicate to include role name in the audience claim only if we have a single role in response
Bool allScopePresent (optional, default=false); //flag to indicate that all requested roles/groups in the scope must be present in the response otherwise return an error
String location (header="Location", out); //return location header with id token
authenticate;
expected OK, FOUND;
Expand Down
4 changes: 2 additions & 2 deletions libs/go/athenzutils/idtoken.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"time"
)

func FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, scope, nonce, state, keyType string, fullArn *bool, proxy bool, expireTime *int32, roleInAudClaim *bool) (string, error) {
func FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, scope, nonce, state, keyType string, fullArn *bool, proxy bool, expireTime *int32, roleInAudClaim, allScopesPresent *bool) (string, error) {

client, err := ZtsClient(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, proxy)
if err != nil {
Expand All @@ -22,7 +22,7 @@ func FetchIdToken(ztsURL, svcKeyFile, svcCertFile, svcCACertFile, clientId, scop
client.DisableRedirect = true

// request an id token
response, _, err := client.GetOIDCResponse("id_token", zts.ServiceName(clientId), "", scope, zts.EntityName(state), zts.EntityName(nonce), zts.SimpleName(keyType), fullArn, expireTime, "json", roleInAudClaim)
response, _, err := client.GetOIDCResponse("id_token", zts.ServiceName(clientId), "", scope, zts.EntityName(state), zts.EntityName(nonce), zts.SimpleName(keyType), fullArn, expireTime, "json", roleInAudClaim, allScopesPresent)
if err != nil {
return "", err
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ public final class ZTSConsts {
public static final String ZTS_EXTERNAL_ATTR_SCOPE = "athenzScope";
public static final String ZTS_EXTERNAL_ATTR_FULL_ARN = "athenzFullArn";
public static final String ZTS_EXTERNAL_ATTR_ISSUER_OPTION = "athenzIssuerOption";
public static final String ZTS_EXTERNAL_ATTR_ALL_SCOPE_PRESENT = "athenzAllScopePresent";

public static final String ZTS_ISSUER_TYPE_OPENID = "openid";
public static final String ZTS_ISSUER_TYPE_OIDC_PORT = "oidc_port";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public interface ZTSHandler {
OAuthConfig getOAuthConfig(ResourceContext context);
JWKList getJWKList(ResourceContext context, Boolean rfc, String service);
AccessTokenResponse postAccessTokenRequest(ResourceContext context, String request);
Response getOIDCResponse(ResourceContext context, String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim);
Response getOIDCResponse(ResourceContext context, String responseType, String clientId, String redirectUri, String scope, String state, String nonce, String keyType, Boolean fullArn, Integer expiryTime, String output, Boolean roleInAudClaim, Boolean allScopePresent);
RoleCertificate postRoleCertificateRequestExt(ResourceContext context, RoleCertificateRequest req);
RoleAccess getRolesRequireRoleCert(ResourceContext context, String principal);
Workloads getWorkloadsByService(ResourceContext context, String domainName, String serviceName);
Expand Down
51 changes: 42 additions & 9 deletions servers/zts/src/main/java/com/yahoo/athenz/zts/ZTSImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -1998,8 +1998,8 @@ String getQueryLogData(final String request) {

@Override
public Response getOIDCResponse(ResourceContext ctx, String responseType, String clientId, String redirectUri,
String scope, String state, String nonce, String keyType, Boolean fullArn,
Integer timeout, String output, Boolean roleInAudClaim) {
String scope, String state, String nonce, String keyType, Boolean fullArn,
Integer timeout, String output, Boolean roleInAudClaim, Boolean allScopePresent) {

final String caller = ctx.getApiName();

Expand Down Expand Up @@ -2084,12 +2084,12 @@ public Response getOIDCResponse(ResourceContext ctx, String responseType, String
if (tokenRequest.isGroupsScope()) {

idTokenGroups = processIdTokenGroups(principalName, tokenRequest, domainName, true,
principalDomain, caller);
allScopePresent, principalDomain, caller);

} else if (tokenRequest.isRolesScope()) {

idTokenGroups = processIdTokenRoles(principalName, tokenRequest, domainName, fullArn,
principalDomain, caller);
allScopePresent, principalDomain, caller);
}

long iat = System.currentTimeMillis() / 1000;
Expand Down Expand Up @@ -2146,8 +2146,9 @@ String getIdTokenAudience(final String clientId, Boolean includeGroup, List<Stri
clientId + ":" + idTokenGroups.get(0) : clientId;
}

List<String> processIdTokenGroups(final String principalName, IdTokenRequest tokenRequest, final String clientIdDomainName,
Boolean fullArn, final String principalDomain, final String caller) {
List<String> processIdTokenGroups(final String principalName, IdTokenRequest tokenRequest,
final String clientIdDomainName, Boolean fullArn, Boolean allScopePresent,
final String principalDomain, final String caller) {

List<String> tokenGroups;
Set<String> domainNames = tokenRequest.getDomainNames();
Expand All @@ -2172,6 +2173,18 @@ List<String> processIdTokenGroups(final String principalName, IdTokenRequest tok
}
List<String> groups = processDomainIdTokenGroups(principalName, domainName,
groupNames, fullArn, principalDomain, caller);

// if we're asked to verify all scopes are present, then we need to
// make sure the number of groups returned matches the number of groups
// requested. If not, then we'll return an error

if (allScopePresent == Boolean.TRUE && groupNames != null) {
if (groups == null || groups.size() != groupNames.size()) {
throw forbiddenError("principal not included in all requested groups", caller,
clientIdDomainName, principalDomain);
}
}

if (groups != null) {
tokenGroups.addAll(groups);
}
Expand Down Expand Up @@ -2214,8 +2227,9 @@ List<String> processDomainIdTokenGroups(final String principalName, final String
return getIdTokenGroupsFromGroups(groups, domainName, fullArn);
}

List<String> processIdTokenRoles(final String principalName, IdTokenRequest tokenRequest, final String clientIdDomainName,
Boolean fullArn, final String principalDomain, final String caller) {
List<String> processIdTokenRoles(final String principalName, IdTokenRequest tokenRequest,
final String clientIdDomainName, Boolean fullArn, Boolean allScopePresent,
final String principalDomain, final String caller) {

List<String> tokenRoles;
Set<String> domainNames = tokenRequest.getDomainNames();
Expand All @@ -2234,12 +2248,25 @@ List<String> processIdTokenRoles(final String principalName, IdTokenRequest toke
boolean rolesRequested = false;
tokenRoles = new ArrayList<>();
for (String domainName : domainNames) {

String[] roleNames = tokenRequest.getRoleNames(domainName);
if (roleNames != null) {
rolesRequested = true;
}
List<String> roles = processDomainIdTokenRoles(principalName, domainName,
roleNames, fullArn, principalDomain, caller);

// if we're asked to verify all scopes are present, then we need to
// make sure the number of roles returned matches the number of roles
// requested. If not, then we'll return an error

if (allScopePresent == Boolean.TRUE && roleNames != null) {
if (roles == null || roles.size() != roleNames.length) {
throw forbiddenError("principal not included in all requested roles", caller,
clientIdDomainName, principalDomain);
}
}

if (roles != null) {
tokenRoles.addAll(roles);
}
Expand Down Expand Up @@ -4990,6 +5017,12 @@ public ExternalCredentialsResponse postExternalCredentialsRequest(ResourceContex
fullArn = Boolean.parseBoolean(fullArnValue);
}

boolean allScopePresent = false;
final String allScopePresentValue = extCredsAttributes.get(ZTSConsts.ZTS_EXTERNAL_ATTR_ALL_SCOPE_PRESENT);
if (!StringUtil.isEmpty(allScopePresentValue)) {
allScopePresent = Boolean.parseBoolean(allScopePresentValue);
}

// get our principal's name

final Principal principal = ((RsrcCtxWrapper) ctx).principal();
Expand Down Expand Up @@ -5033,7 +5066,7 @@ public ExternalCredentialsResponse postExternalCredentialsRequest(ResourceContex
// either groups or roles for our response

List<String> idTokenGroups = processIdTokenRoles(principalName, tokenRequest,
clientIdDomain, fullArn, principalDomain, caller);
clientIdDomain, fullArn, allScopePresent, principalDomain, caller);

long iat = System.currentTimeMillis() / 1000;

Expand Down
Loading

0 comments on commit a979629

Please sign in to comment.