Skip to content

Commit

Permalink
feat(credential-providers): pass caller client options to fromTempora…
Browse files Browse the repository at this point in the history
…ryCredentials inner STSClient (#6838)

* feat(credential-providers): pass caller client options to fromTemporaryCredentials inner STSClient

* chore: update lockfile

* chore(client-sts): rename variable

* test(credential-providers): add unit test fromTempCreds
  • Loading branch information
kuhe authored Jan 22, 2025
1 parent 9199f2f commit 0d0b14e
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 326 deletions.
1 change: 1 addition & 0 deletions .prettierignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ CHANGELOG.md
**/*.hbs
**/*/report.md
clients/*/src/endpoint/ruleset.ts
packages/nested-clients/src/submodules/*/endpoint/ruleset.ts
**/*.java
10 changes: 6 additions & 4 deletions clients/client-sts/src/defaultStsRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ const resolveRegion = async (
*/
export const getDefaultRoleAssumer = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumer => {
let stsClient: STSClient;
let closureSourceCreds: AwsCredentialIdentity;
Expand All @@ -104,7 +104,8 @@ export const getDefaultRoleAssumer = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
// A hack to make sts client uses the credential in current closure.
credentialDefaultProvider: () => async () => closureSourceCreds,
region: resolvedRegion,
Expand Down Expand Up @@ -146,7 +147,7 @@ export type RoleAssumerWithWebIdentity = (
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumerWithWebIdentity => {
let stsClient: STSClient;
return async (params) => {
Expand All @@ -164,7 +165,8 @@ export const getDefaultRoleAssumerWithWebIdentity = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
region: resolvedRegion,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const resolveRegion = async (
*/
export const getDefaultRoleAssumer = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumer => {
let stsClient: STSClient;
let closureSourceCreds: AwsCredentialIdentity;
Expand All @@ -101,7 +101,8 @@ export const getDefaultRoleAssumer = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
// A hack to make sts client uses the credential in current closure.
credentialDefaultProvider: () => async () => closureSourceCreds,
region: resolvedRegion,
Expand Down Expand Up @@ -143,7 +144,7 @@ export type RoleAssumerWithWebIdentity = (
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: STSRoleAssumerOptions,
stsClientCtor: new (options: STSClientConfig) => STSClient
STSClient: new (options: STSClientConfig) => STSClient
): RoleAssumerWithWebIdentity => {
let stsClient: STSClient;
return async (params) => {
Expand All @@ -161,7 +162,8 @@ export const getDefaultRoleAssumerWithWebIdentity = (
);
const isCompatibleRequestHandler = !isH2(requestHandler);

stsClient = new stsClientCtor({
stsClient = new STSClient({
profile: stsOptions?.parentClientConfig?.profile,
region: resolvedRegion,
requestHandler: isCompatibleRequestHandler ? (requestHandler as any) : undefined,
logger: logger as any,
Expand Down
1 change: 1 addition & 0 deletions packages/credential-providers/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"@aws-sdk/credential-provider-web-identity": "*",
"@aws-sdk/nested-clients": "*",
"@aws-sdk/types": "*",
"@smithy/core": "^3.0.0",
"@smithy/credential-provider-imds": "^4.0.0",
"@smithy/property-provider": "^4.0.0",
"@smithy/types": "^4.0.0",
Expand Down
99 changes: 89 additions & 10 deletions packages/credential-providers/src/fromTemporaryCredentials.base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,39 @@ import type {
CredentialProviderOptions,
RuntimeConfigAwsCredentialIdentityProvider,
} from "@aws-sdk/types";
import { normalizeProvider } from "@smithy/core";
import { CredentialsProviderError } from "@smithy/property-provider";
import { AwsCredentialIdentity, AwsCredentialIdentityProvider, Pluggable } from "@smithy/types";
import { AwsCredentialIdentity, AwsCredentialIdentityProvider, Logger, Pluggable, RequestHandler } from "@smithy/types";

export interface FromTemporaryCredentialsOptions extends CredentialProviderOptions {
params: Omit<AssumeRoleCommandInput, "RoleSessionName"> & { RoleSessionName?: string };
masterCredentials?: AwsCredentialIdentity | AwsCredentialIdentityProvider;
clientConfig?: STSClientConfig;
logger?: Logger;
clientPlugins?: Pluggable<any, any>[];
mfaCodeProvider?: (mfaSerial: string) => Promise<string>;
}

const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";

export const fromTemporaryCredentials = (
options: FromTemporaryCredentialsOptions,
credentialDefaultProvider?: () => AwsCredentialIdentityProvider
): RuntimeConfigAwsCredentialIdentityProvider => {
let stsClient: STSClient;
return async (awsIdentityProperties: AwsIdentityProperties = {}): Promise<AwsCredentialIdentity> => {
options.logger?.debug("@aws-sdk/credential-providers - fromTemporaryCredentials (STS)");
const { callerClientConfig } = awsIdentityProperties;
const logger = options.logger ?? callerClientConfig?.logger;
logger?.debug("@aws-sdk/credential-providers - fromTemporaryCredentials (STS)");

const params = { ...options.params, RoleSessionName: options.params.RoleSessionName ?? "aws-sdk-js-" + Date.now() };
if (params?.SerialNumber) {
if (!options.mfaCodeProvider) {
throw new CredentialsProviderError(
`Temporary credential requires multi-factor authentication, but no MFA code callback was provided.`,
{
tryNextLink: false,
logger: options.logger,
logger,
}
);
}
Expand All @@ -42,14 +49,68 @@ export const fromTemporaryCredentials = (
const defaultCredentialsOrError =
typeof credentialDefaultProvider === "function" ? credentialDefaultProvider() : undefined;

const { callerClientConfig } = awsIdentityProperties;
const credentialSources = [
options.masterCredentials,
options.clientConfig?.credentials,
/**
* Important (!): callerClientConfig?.credentials is not a valid
* credential source for this provider, because this function
* is the caller client's credential provider function.
*/
void callerClientConfig?.credentials,
callerClientConfig?.credentialDefaultProvider?.(),
defaultCredentialsOrError,
];
let credentialSource = "STS client default credentials";
if (credentialSources[0]) {
credentialSource = "options.masterCredentials";
} else if (credentialSources[1]) {
credentialSource = "options.clientConfig.credentials";
} else if (credentialSources[2]) {
// This branch is not possible, see above void note.
// This code is here to prevent accidental attempts to utilize
// the invalid credential source.
credentialSource = "caller client's credentials";
throw new Error("fromTemporaryCredentials recursion in callerClientConfig.credentials");
} else if (credentialSources[3]) {
credentialSource = "caller client's credentialDefaultProvider";
} else if (credentialSources[4]) {
credentialSource = "AWS SDK default credentials";
}

const regionSources = [options.clientConfig?.region, callerClientConfig?.region, ASSUME_ROLE_DEFAULT_REGION];
let regionSource = "default partition's default region";
if (regionSources[0]) {
regionSource = "options.clientConfig.region";
} else if (regionSources[1]) {
regionSource = "caller client's region";
}

const requestHandlerSources = [
filterRequestHandler(options.clientConfig?.requestHandler),
filterRequestHandler(callerClientConfig?.requestHandler),
];
let requestHandlerSource = "STS default requestHandler";
if (requestHandlerSources[0]) {
requestHandlerSource = "options.clientConfig.requestHandler";
} else if (requestHandlerSources[1]) {
requestHandlerSource = "caller client's requestHandler";
}

logger?.debug?.(
`@aws-sdk/credential-providers - fromTemporaryCredentials STS client init with ` +
`${regionSource}=${await normalizeProvider(
coalesce(regionSources)
)()}, ${credentialSource}, ${requestHandlerSource}.`
);

stsClient = new STSClient({
...options.clientConfig,
credentials:
options.masterCredentials ??
options.clientConfig?.credentials ??
callerClientConfig?.credentialDefaultProvider?.() ??
defaultCredentialsOrError,
credentials: coalesce(credentialSources),
logger,
profile: options.clientConfig?.profile ?? callerClientConfig?.profile,
region: coalesce(regionSources),
requestHandler: coalesce(requestHandlerSources),
});
}
if (options.clientPlugins) {
Expand All @@ -60,7 +121,7 @@ export const fromTemporaryCredentials = (
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new CredentialsProviderError(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`, {
logger: options.logger,
logger,
});
}
return {
Expand All @@ -73,3 +134,21 @@ export const fromTemporaryCredentials = (
};
};
};

/**
* @internal
*/
const filterRequestHandler = (requestHandler: STSClientConfig["requestHandler"]): undefined | typeof requestHandler => {
return (requestHandler as RequestHandler<any, any>)?.metadata?.handlerProtocol === "h2" ? undefined : requestHandler;
};

/**
* @internal
*/
const coalesce = (args: any) => {
for (const item of args) {
if (item !== undefined) {
return item;
}
}
};
42 changes: 42 additions & 0 deletions packages/credential-providers/src/fromTemporaryCredentials.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ describe("fromTemporaryCredentials", () => {
await provider();
expect(vi.mocked(STSClient as any)).toHaveBeenCalledWith({
credentials: masterCredentials,
logger: void 0,
profile: void 0,
region: "us-east-1",
requestHandler: void 0,
});
expect(mockUsePlugin).toHaveBeenCalledTimes(1);
expect(mockUsePlugin).toHaveBeenNthCalledWith(1, plugin);
Expand Down Expand Up @@ -193,6 +197,44 @@ describe("fromTemporaryCredentials", () => {
});
});

it("uses caller client options if not overridden with provider client options", async () => {
const provider = fromTemporaryCredentialsNode({
params: {
RoleArn,
RoleSessionName,
},
});
const logger = {
debug() {},
info() {},
warn() {},
error() {},
};
const credentials = {
accessKeyId: "",
secretAccessKey: "",
};
const credentialProvider = async () => credentials;
const regionProvider = async () => "B";
await provider({
callerClientConfig: {
profile: "A",
region: regionProvider,
logger,
requestHandler: Symbol.for("requestHandler") as any,
credentialDefaultProvider: () => credentialProvider,
},
});
expect(vi.mocked(STSClient as any).mock.calls[0][0]).toEqual({
profile: "A",
region: regionProvider,
logger,
requestHandler: Symbol.for("requestHandler") as any,
// mockImpl resolved the credentials.
credentials,
});
});

it("should allow assume roles assuming roles assuming roles ad infinitum", async () => {
const roleArnOf = (id: string) => `arn:aws:iam::123456789:role/${id}`;
const idOf = (roleArn: string) => roleArn.split("/")?.[1] ?? "UNKNOWN";
Expand Down
Loading

0 comments on commit 0d0b14e

Please sign in to comment.