Skip to content

Commit

Permalink
fix(core): add cache-control header to cognito identity client (aws-a…
Browse files Browse the repository at this point in the history
  • Loading branch information
haverchuck authored Dec 12, 2022
1 parent cde60fc commit dfbabaf
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 37 deletions.
70 changes: 43 additions & 27 deletions packages/amazon-cognito-identity-js/src/Client.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,20 @@ export default class Client {
requestWithRetry(operation, params, callback) {
const MAX_DELAY_IN_MILLIS = 5 * 1000;

jitteredExponentialRetry((p) => new Promise((res, rej) => {
this.request(operation, p, (error, result) => {
if (error) {
rej(error);
} else {
res(result);
}
});
}), [params], MAX_DELAY_IN_MILLIS)
jitteredExponentialRetry(
p =>
new Promise((res, rej) => {
this.request(operation, p, (error, result) => {
if (error) {
rej(error);
} else {
res(result);
}
});
}),
[params],
MAX_DELAY_IN_MILLIS
)
.then(result => callback(null, result))
.catch(error => callback(error));
}
Expand Down Expand Up @@ -112,9 +117,9 @@ export default class Client {
// Taken from aws-sdk-js/lib/protocol/json.js
// eslint-disable-next-line no-underscore-dangle
const code = (data.__type || data.code).split('#').pop();
const error = new Error(data.message || data.Message || null)
error.name = code
error.code = code
const error = new Error(data.message || data.Message || null);
error.name = code;
error.code = code;
return callback(error);
})
.catch(err => {
Expand All @@ -126,17 +131,19 @@ export default class Client {
) {
try {
const code = response.headers.get('x-amzn-errortype').split(':')[0];
const error = new Error(response.status ? response.status.toString() : null)
error.code = code
error.name = code
error.statusCode = response.status
const error = new Error(
response.status ? response.status.toString() : null
);
error.code = code;
error.name = code;
error.statusCode = response.status;
return callback(error);
} catch (ex) {
return callback(err);
}
// otherwise check if error is Network error
} else if (err instanceof Error && err.message === 'Network error') {
err.code = 'NetworkError'
err.code = 'NetworkError';
}
return callback(err);
});
Expand All @@ -146,7 +153,7 @@ export default class Client {
const logger = {
debug: () => {
// Intentionally blank. This package doesn't have logging
}
},
};

/**
Expand All @@ -159,7 +166,7 @@ class NonRetryableError extends Error {
}
}

const isNonRetryableError = (obj) => {
const isNonRetryableError = obj => {
const key = 'nonRetryable';
return obj && obj[key];
};
Expand All @@ -169,9 +176,13 @@ function retry(functionToRetry, args, delayFn, attempt = 1) {
throw Error('functionToRetry must be a function');
}

logger.debug(`${functionToRetry.name} attempt #${attempt} with args: ${JSON.stringify(args)}`);
logger.debug(
`${functionToRetry.name} attempt #${attempt} with args: ${JSON.stringify(
args
)}`
);

return functionToRetry(...args).catch((err) => {
return functionToRetry(...args).catch(err => {
logger.debug(`error on ${functionToRetry.name}`, err);

if (isNonRetryableError(err)) {
Expand All @@ -184,12 +195,13 @@ function retry(functionToRetry, args, delayFn, attempt = 1) {
logger.debug(`${functionToRetry.name} retrying in ${retryIn} ms`);

if (retryIn !== false) {
return new Promise(res => setTimeout(res, retryIn))
.then(() => retry(functionToRetry, args, delayFn, attempt + 1))
return new Promise(res => setTimeout(res, retryIn)).then(() =>
retry(functionToRetry, args, delayFn, attempt + 1)
);
} else {
throw err;
}
})
});
}

function jitteredBackoff(maxDelayMs) {
Expand All @@ -203,6 +215,10 @@ function jitteredBackoff(maxDelayMs) {
}

const MAX_DELAY_MS = 5 * 60 * 1000;
function jitteredExponentialRetry(functionToRetry, args, maxDelayMs = MAX_DELAY_MS) {
return retry(functionToRetry, args, jitteredBackoff(maxDelayMs))
};
function jitteredExponentialRetry(
functionToRetry,
args,
maxDelayMs = MAX_DELAY_MS
) {
return retry(functionToRetry, args, jitteredBackoff(maxDelayMs));
}
3 changes: 3 additions & 0 deletions packages/core/__tests__/Credentials-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ describe('Credentials test', () => {
};
}
},
middlewareStack: {
add: (next, _) => {},
},
};
});

Expand Down
82 changes: 82 additions & 0 deletions packages/core/__tests__/Util-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ import Reachability from '../src/Util/Reachability';
import { ConsoleLogger as Logger } from '../src/Logger';
import { urlSafeDecode, urlSafeEncode } from '../src/Util/StringUtils';
import { DateUtils } from '../src/Util/DateUtils';
import {
createCognitoIdentityClient,
middlewareArgs,
} from '../src/Util/CognitoIdentityClient';
import { BuildMiddleware, HttpRequest } from '@aws-sdk/types';
import {
GetCredentialsForIdentityCommand,
GetIdCommand,
} from '@aws-sdk/client-cognito-identity';

Logger.LOG_LEVEL = 'DEBUG';

Expand Down Expand Up @@ -51,6 +60,79 @@ describe('Util', () => {
});
});

describe('cognito identity client test', () => {
test('client should be instantiated', async () => {
const cognitoClient = createCognitoIdentityClient({
region: 'us-west-1',
});
expect(cognitoClient).toBeTruthy();
expect.assertions(1);
});

test('middlewareArgs helper should merge headers into request object', async () => {
const args = middlewareArgs({
request: {
headers: {
'test-header': '1234',
},
},
input: {},
});
expect(args.request.headers['test-header']).toEqual('1234');
expect(args.request.headers['cache-control']).toEqual('no-store');
expect.assertions(2);
});

test('headers should be added by middleware on GetIdCommand', async () => {
const requestCacheHeaderValidator: BuildMiddleware<any, any> =
next => async args => {
// middleware intercept the request and return it early
const request = args.request as HttpRequest;
const { headers } = request;
expect(headers['cache-control']).toEqual('no-store');
return { output: {} as any, response: {} as any };
};

const client = createCognitoIdentityClient({ region: 'us-west-1' });
client.middlewareStack.addRelativeTo(requestCacheHeaderValidator, {
relation: 'after',
toMiddleware: 'cacheControlMiddleWare',
});

await client.send(
new GetIdCommand({
IdentityPoolId: 'us-west-1:12345678-1234-1234-1234-123456789000',
})
);
expect.assertions(1);
});

test('headers should be added by middleware on GetCredentialsForIdentityCommand', async () => {
const requestCacheHeaderValidator: BuildMiddleware<any, any> =
next => async args => {
// middleware intercept the request and return it early
const request = args.request as HttpRequest;
const { headers } = request;
expect(headers['cache-control']).toEqual('no-store');
return { output: {} as any, response: {} as any };
};

const client = createCognitoIdentityClient({ region: 'us-west-1' });
client.middlewareStack.addRelativeTo(requestCacheHeaderValidator, {
relation: 'after',
toMiddleware: 'cacheControlMiddleWare',
});
await client.send(
new GetCredentialsForIdentityCommand({
IdentityId: '1234',
Logins: {},
})
);

expect.assertions(1);
});
});

test('jitteredExponential retry happy case', async () => {
const resolveAt = 3;
let attempts = 0;
Expand Down
18 changes: 8 additions & 10 deletions packages/core/src/Credentials.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { makeQuerablePromise } from './JS';
import { FacebookOAuth, GoogleOAuth } from './OAuthHelper';
import { jitteredExponentialRetry } from './Util';
import { ICredentials } from './types';
import { getAmplifyUserAgent } from './Platform';
import { Amplify } from './Amplify';
import {
fromCognitoIdentity,
Expand All @@ -13,13 +12,13 @@ import {
FromCognitoIdentityPoolParameters,
} from '@aws-sdk/credential-provider-cognito-identity';
import {
CognitoIdentityClient,
GetIdCommand,
GetCredentialsForIdentityCommand,
} from '@aws-sdk/client-cognito-identity';
import { CredentialProvider } from '@aws-sdk/types';
import { parseAWSExports } from './parseAWSExports';
import { Hub } from './Hub';
import { createCognitoIdentityClient } from './Util/CognitoIdentityClient';

const logger = new Logger('Credentials');

Expand Down Expand Up @@ -265,7 +264,8 @@ export class CredentialsClass {
parseAWSExports(this._config || {}).Auth
);
}
const { identityPoolId, region, mandatorySignIn, identityPoolRegion } = this._config;
const { identityPoolId, region, mandatorySignIn, identityPoolRegion } =
this._config;

if (mandatorySignIn) {
return Promise.reject(
Expand All @@ -291,9 +291,8 @@ export class CredentialsClass {

const identityId = (this._identityId = await this._getGuestIdentityId());

const cognitoClient = new CognitoIdentityClient({
const cognitoClient = createCognitoIdentityClient({
region: identityPoolRegion || region,
customUserAgent: getAmplifyUserAgent(),
});

let credentials = undefined;
Expand Down Expand Up @@ -408,9 +407,8 @@ export class CredentialsClass {
);
}

const cognitoClient = new CognitoIdentityClient({
const cognitoClient = createCognitoIdentityClient({
region: identityPoolRegion || region,
customUserAgent: getAmplifyUserAgent(),
});

let credentials = undefined;
Expand All @@ -435,7 +433,8 @@ export class CredentialsClass {
private _setCredentialsFromSession(session): Promise<ICredentials> {
logger.debug('set credentials from session');
const idToken = session.getIdToken().getJwtToken();
const { region, userPoolId, identityPoolId, identityPoolRegion } = this._config;
const { region, userPoolId, identityPoolId, identityPoolRegion } =
this._config;
if (!identityPoolId) {
logger.debug('No Cognito Federated Identity pool provided');
return Promise.reject('No Cognito Federated Identity pool provided');
Expand All @@ -450,9 +449,8 @@ export class CredentialsClass {
const logins = {};
logins[key] = idToken;

const cognitoClient = new CognitoIdentityClient({
const cognitoClient = createCognitoIdentityClient({
region: identityPoolRegion || region,
customUserAgent: getAmplifyUserAgent(),
});

/*
Expand Down
48 changes: 48 additions & 0 deletions packages/core/src/Util/CognitoIdentityClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

import {
CognitoIdentityClient,
CognitoIdentityClientConfig,
} from '@aws-sdk/client-cognito-identity';
import { Provider } from '@aws-sdk/types';
import { getAmplifyUserAgent } from '../Platform';

/**
* Returns a CognitoIdentityClient with middleware
* @param {CognitoIdentityClientConfig} config
* @return {CognitoIdentityClient}
*/
export function createCognitoIdentityClient(
config: CognitoIdentityClientConfig
): CognitoIdentityClient {
const client = new CognitoIdentityClient({
region: config.region,
customUserAgent: getAmplifyUserAgent(),
});

client.middlewareStack.add(
(next, _) => (args: any) => {
return next(middlewareArgs(args));
},
{
step: 'build',
name: 'cacheControlMiddleWare',
}
);

return client;
}

export function middlewareArgs(args: { request: any; input: any }) {
return {
...args,
request: {
...args.request,
headers: {
...args.request.headers,
'cache-control': 'no-store',
},
},
};
}

0 comments on commit dfbabaf

Please sign in to comment.