diff --git a/packages/api/package.json b/packages/api/package.json index 129ff269ec5e..8f9099709eed 100644 --- a/packages/api/package.json +++ b/packages/api/package.json @@ -36,6 +36,7 @@ "jsonwebtoken": "8.5.1", "jwks-rsa": "2.0.5", "md5": "2.3.0", + "node-fetch": "2.6.7", "pascalcase": "1.0.0", "pino": "7.8.0", "uuid": "8.3.2" diff --git a/packages/api/src/__tests__/normalizeRequest.test.ts b/packages/api/src/__tests__/normalizeRequest.test.ts new file mode 100644 index 000000000000..c48a7282d100 --- /dev/null +++ b/packages/api/src/__tests__/normalizeRequest.test.ts @@ -0,0 +1,93 @@ +import type { APIGatewayProxyEvent } from 'aws-lambda' +import { Headers } from 'node-fetch' + +import { normalizeRequest } from '../transforms' + +export const createMockedEvent = ( + httpMethod = 'POST', + body: any = undefined, + isBase64Encoded = false +): APIGatewayProxyEvent => { + return { + body, + headers: {}, + multiValueHeaders: {}, + httpMethod, + isBase64Encoded, + path: '/MOCK_PATH', + pathParameters: null, + queryStringParameters: null, + multiValueQueryStringParameters: null, + stageVariables: null, + requestContext: { + accountId: 'MOCKED_ACCOUNT', + apiId: 'MOCKED_API_ID', + authorizer: { name: 'MOCKED_AUTHORIZER' }, + protocol: 'HTTP', + identity: { + accessKey: null, + accountId: null, + apiKey: null, + apiKeyId: null, + caller: null, + clientCert: null, + cognitoAuthenticationProvider: null, + cognitoAuthenticationType: null, + cognitoIdentityId: null, + cognitoIdentityPoolId: null, + principalOrgId: null, + sourceIp: '123.123.123.123', + user: null, + userAgent: null, + userArn: null, + }, + httpMethod: 'POST', + path: '/MOCK_PATH', + stage: 'MOCK_STAGE', + requestId: 'MOCKED_REQUEST_ID', + requestTimeEpoch: 1, + resourceId: 'MOCKED_RESOURCE_ID', + resourcePath: 'MOCKED_RESOURCE_PATH', + }, + resource: 'MOCKED_RESOURCE', + } +} + +test('Normalizes an aws event with base64', () => { + const corsEventB64 = createMockedEvent( + 'POST', + Buffer.from(JSON.stringify({ bazinga: 'hello_world' }), 'utf8').toString( + 'base64' + ), + true + ) + + expect(normalizeRequest(corsEventB64)).toEqual({ + headers: new Headers(corsEventB64.headers), + method: 'POST', + query: null, + body: { + bazinga: 'hello_world', + }, + }) +}) + +test('Handles CORS requests with and without b64 encoded', () => { + const corsEventB64 = createMockedEvent('OPTIONS', undefined, true) + + expect(normalizeRequest(corsEventB64)).toEqual({ + headers: new Headers(corsEventB64.headers), // headers returned as symbol + method: 'OPTIONS', + query: null, + body: undefined, + }) + + const corsEventWithoutB64 = createMockedEvent('OPTIONS', undefined, false) + + expect(normalizeRequest(corsEventWithoutB64)).toEqual({ + headers: new Headers(corsEventB64.headers), // headers returned as symbol + method: 'OPTIONS', + query: null, + body: undefined, + }) +}) diff --git a/packages/api/src/cors.ts b/packages/api/src/cors.ts new file mode 100644 index 000000000000..18009b46a84d --- /dev/null +++ b/packages/api/src/cors.ts @@ -0,0 +1,98 @@ +import type { Request } from 'graphql-helix' +import { Headers } from 'node-fetch' + +export type CorsConfig = { + origin?: boolean | string | string[] + methods?: string | string[] + allowedHeaders?: string | string[] + exposedHeaders?: string | string[] + credentials?: boolean + maxAge?: number +} + +export type CorsHeaders = Record +export type CorsContext = ReturnType + +export function createCorsContext(cors: CorsConfig | undefined) { + // Taken from apollo-server-env + // @see: https://github.com/apollographql/apollo-server/blob/9267a79b974e397e87ad9ee408b65c46751e4565/packages/apollo-server-env/src/polyfills/fetch.js#L1 + const corsHeaders = new Headers() + + if (cors) { + if (cors.methods) { + if (typeof cors.methods === 'string') { + corsHeaders.set('access-control-allow-methods', cors.methods) + } else if (Array.isArray(cors.methods)) { + corsHeaders.set('access-control-allow-methods', cors.methods.join(',')) + } + } + + if (cors.allowedHeaders) { + if (typeof cors.allowedHeaders === 'string') { + corsHeaders.set('access-control-allow-headers', cors.allowedHeaders) + } else if (Array.isArray(cors.allowedHeaders)) { + corsHeaders.set( + 'access-control-allow-headers', + cors.allowedHeaders.join(',') + ) + } + } + + if (cors.exposedHeaders) { + if (typeof cors.exposedHeaders === 'string') { + corsHeaders.set('access-control-expose-headers', cors.exposedHeaders) + } else if (Array.isArray(cors.exposedHeaders)) { + corsHeaders.set( + 'access-control-expose-headers', + cors.exposedHeaders.join(',') + ) + } + } + + if (cors.credentials) { + corsHeaders.set('access-control-allow-credentials', 'true') + } + if (typeof cors.maxAge === 'number') { + corsHeaders.set('access-control-max-age', cors.maxAge.toString()) + } + } + + return { + shouldHandleCors(request: Request) { + return request.method === 'OPTIONS' + }, + getRequestHeaders(request: Request): CorsHeaders { + const eventHeaders = new Headers( + request.headers as Record + ) + const requestCorsHeaders = new Headers(corsHeaders) + + if (cors && cors.origin) { + const requestOrigin = eventHeaders.get('origin') + if (typeof cors.origin === 'string') { + requestCorsHeaders.set('access-control-allow-origin', cors.origin) + } else if ( + requestOrigin && + (typeof cors.origin === 'boolean' || + (Array.isArray(cors.origin) && + requestOrigin && + cors.origin.includes(requestOrigin))) + ) { + requestCorsHeaders.set('access-control-allow-origin', requestOrigin) + } + + const requestAccessControlRequestHeaders = eventHeaders.get( + 'access-control-request-headers' + ) + if (!cors.allowedHeaders && requestAccessControlRequestHeaders) { + requestCorsHeaders.set( + 'access-control-allow-headers', + requestAccessControlRequestHeaders + ) + } + } + + return Object.fromEntries(requestCorsHeaders.entries()) + }, + } +} diff --git a/packages/api/src/functions/dbAuth/DbAuthHandler.ts b/packages/api/src/functions/dbAuth/DbAuthHandler.ts index 9378cb633d2f..b2e008d42dbc 100644 --- a/packages/api/src/functions/dbAuth/DbAuthHandler.ts +++ b/packages/api/src/functions/dbAuth/DbAuthHandler.ts @@ -4,6 +4,14 @@ import CryptoJS from 'crypto-js' import md5 from 'md5' import { v4 as uuidv4 } from 'uuid' +import { + CorsConfig, + CorsContext, + CorsHeaders, + createCorsContext, +} from '../../cors' +import { normalizeRequest } from '../../transforms' + import * as DbAuthError from './errors' import { decryptSession, getSession } from './shared' @@ -30,6 +38,16 @@ interface DbAuthHandlerOptions { resetToken: string resetTokenExpiresAt: string } + /** + * Object containing cookie config options + */ + cookie?: { + Path?: string + HttpOnly?: boolean + Secure?: boolean + SameSite?: string + Domain?: string + } /** * Object containing forgot password options */ @@ -100,6 +118,11 @@ interface DbAuthHandlerOptions { usernameTaken?: string } } + + /** + * CORS settings, same as in createGraphqlHandler + */ + cors?: CorsConfig } interface SignupHandlerOptions { @@ -140,6 +163,7 @@ export class DbAuthHandler { hasInvalidSession: boolean session: SessionRecord | undefined sessionCsrfToken: string | undefined + corsContext: CorsContext | undefined // class constant: list of auth methods that are supported static get METHODS(): AuthMethodNames[] { @@ -168,6 +192,7 @@ export class DbAuthHandler { } // class constant: all the attributes of the cookie other than the value itself + // DEPRECATED: Remove once deprecation warning is removed from _cookieAttributes() static get COOKIE_META() { const meta = [`Path=/`, 'HttpOnly', 'SameSite=Strict'] @@ -223,6 +248,10 @@ export class DbAuthHandler { this.headerCsrfToken = this.event.headers['csrf-token'] this.hasInvalidSession = false + if (options.cors) { + this.corsContext = createCorsContext(options.cors) + } + try { const [session, csrfToken] = decryptSession( getSession(this.event.headers['cookie']) @@ -243,10 +272,26 @@ export class DbAuthHandler { // Actual function that triggers everything else to happen: `login`, `signup`, // etc. is called from here, after some checks to make sure the request is good async invoke() { + const request = normalizeRequest(this.event) + let corsHeaders = {} + if (this.corsContext) { + corsHeaders = this.corsContext.getRequestHeaders(request) + // Return CORS headers for OPTIONS requests + if (this.corsContext.shouldHandleCors(request)) { + return this._buildResponseWithCorsHeaders( + { body: '', statusCode: 200 }, + corsHeaders + ) + } + } + // if there was a problem decryption the session, just return the logout // response immediately if (this.hasInvalidSession) { - return this._ok(...this._logoutResponse()) + return this._buildResponseWithCorsHeaders( + this._ok(...this._logoutResponse()), + corsHeaders + ) } try { @@ -254,12 +299,12 @@ export class DbAuthHandler { // get the auth method the incoming request is trying to call if (!DbAuthHandler.METHODS.includes(method)) { - return this._notFound() + return this._buildResponseWithCorsHeaders(this._notFound(), corsHeaders) } // make sure it's using the correct verb, GET vs POST if (this.event.httpMethod !== DbAuthHandler.VERBS[method]) { - return this._notFound() + return this._buildResponseWithCorsHeaders(this._notFound(), corsHeaders) } // call whatever auth method was requested and return the body and headers @@ -267,12 +312,18 @@ export class DbAuthHandler { method ]() - return this._ok(body, headers, options) + return this._buildResponseWithCorsHeaders( + this._ok(body, headers, options), + corsHeaders + ) } catch (e: any) { if (e instanceof DbAuthError.WrongVerbError) { - return this._notFound() + return this._buildResponseWithCorsHeaders(this._notFound(), corsHeaders) } else { - return this._badRequest(e.message || e) + return this._buildResponseWithCorsHeaders( + this._badRequest(e.message || e), + corsHeaders + ) } } } @@ -517,10 +568,35 @@ export class DbAuthHandler { // pass the argument `expires` set to "now" to get the attributes needed to expire // the session, or "future" (or left out completely) to set to `_futureExpiresDate` _cookieAttributes({ expires = 'future' }: { expires?: 'now' | 'future' }) { - const meta = JSON.parse(JSON.stringify(DbAuthHandler.COOKIE_META)) + let meta - if (process.env.NODE_ENV !== 'development') { - meta.push('Secure') + // DEPRECATED: Remove deprecation logic after a few releases, assume this.options.cookie contains config + if (!this.options.cookie) { + console.warn( + `\n[Deprecation Notice] dbAuth cookie config has moved to\n api/src/function/auth.js for better customization.\n See https://redwoodjs.com/docs/authentication#cookie-config\n` + ) + meta = JSON.parse(JSON.stringify(DbAuthHandler.COOKIE_META)) + + if (process.env.NODE_ENV !== 'development') { + meta.push('Secure') + } + } else { + const cookieOptions = this.options.cookie || {} + meta = Object.keys(cookieOptions) + .map((key) => { + const optionValue = + cookieOptions[key as keyof DbAuthHandlerOptions['cookie']] + + // Convert the options to valid cookie string + if (optionValue === true) { + return key + } else if (optionValue === false) { + return null + } else { + return `${key}=${optionValue}` + } + }) + .filter((v) => v) } const expiresAt = @@ -787,4 +863,21 @@ export class DbAuthHandler { headers: { 'Content-Type': 'application/json' }, } } + + _buildResponseWithCorsHeaders( + response: { + body?: string + statusCode: number + headers?: Record + }, + corsHeaders: CorsHeaders + ) { + return { + ...response, + headers: { + ...(response.headers || {}), + ...corsHeaders, + }, + } + } } diff --git a/packages/api/src/functions/dbAuth/__tests__/DbAuthHandler.test.js b/packages/api/src/functions/dbAuth/__tests__/DbAuthHandler.test.js index bb6e4c26d267..d2a5beec0da0 100644 --- a/packages/api/src/functions/dbAuth/__tests__/DbAuthHandler.test.js +++ b/packages/api/src/functions/dbAuth/__tests__/DbAuthHandler.test.js @@ -104,6 +104,8 @@ let event, context, options describe('dbAuth', () => { beforeEach(() => { + // hide deprecation warnings during test + jest.spyOn(console, 'warn').mockImplementation(() => {}) // encryption key so results are consistent regardless of settings in .env process.env.SESSION_SECRET = 'nREjs1HPS7cFia6tQHK70EWGtfhOgbqJQKsHQz3S' delete process.env.DBAUTH_COOKIE_DOMAIN @@ -164,6 +166,7 @@ describe('dbAuth', () => { }) afterEach(async () => { + jest.spyOn(console, 'warn').mockRestore() await db.user.deleteMany({ where: { email: 'rob@redwoodjs.com' }, }) @@ -514,6 +517,29 @@ describe('dbAuth', () => { expect(response.body).toEqual('{"error":"Logout error"}') }) + it('handlers CORS OPTIONS request', async () => { + event.httpMethod = 'OPTIONS' + event.body = JSON.stringify({ method: 'auth' }) + + const dbAuth = new DbAuthHandler(event, context, { + ...options, + cors: { + origin: 'https://www.myRedwoodWebSide.com', + credentials: true, + }, + }) + dbAuth.logout = jest.fn(() => { + throw Error('Logout error') + }) + const response = await dbAuth.invoke() + + expect(response.statusCode).toEqual(200) + expect(response.headers['access-control-allow-credentials']).toBe('true') + expect(response.headers['access-control-allow-origin']).toBe( + 'https://www.myRedwoodWebSide.com' + ) + }) + it('calls the appropriate auth function', async () => { event.body = JSON.stringify({ method: 'logout' }) event.httpMethod = 'POST' @@ -1145,6 +1171,7 @@ describe('dbAuth', () => { }) describe('_cookieAttributes', () => { + // DEPRECATED: cookie config should come from options object now it('returns an array of attributes for the session cookie', () => { const dbAuth = new DbAuthHandler( { headers: { referer: 'http://test.host' } }, @@ -1163,6 +1190,7 @@ describe('dbAuth', () => { expect(attributes[4]).toMatch(UTC_DATE_REGEX) }) + // DEPRECATED: Secure will be set or not in cookie config options it('does not include the Secure attribute when in development environment', () => { const oldEnv = process.env.NODE_ENV process.env.NODE_ENV = 'development' @@ -1177,6 +1205,7 @@ describe('dbAuth', () => { process.env.NODE_ENV = oldEnv }) + // DEPRECATED: Domain will be set or not in cookie config options it('includes a Domain in the cookie if DBAUTH_COOKIE_DOMAIN is set', () => { process.env.DBAUTH_COOKIE_DOMAIN = 'site.test' @@ -1185,6 +1214,83 @@ describe('dbAuth', () => { expect(attributes[3]).toEqual('Domain=site.test') }) + + it('returns an array of attributes for the session cookie', () => { + const dbAuth = new DbAuthHandler( + { headers: { referer: 'http://test.host' } }, + context, + { + ...options, + cookie: { + Path: '/', + HttpOnly: true, + SameSite: 'Strict', + Secure: true, + Domain: 'example.com', + }, + } + ) + const attributes = dbAuth._cookieAttributes({}) + + expect(attributes.length).toEqual(6) + expect(attributes[0]).toEqual('Path=/') + expect(attributes[1]).toEqual('HttpOnly') + expect(attributes[2]).toEqual('SameSite=Strict') + expect(attributes[3]).toEqual('Secure') + expect(attributes[4]).toEqual('Domain=example.com') + expect(attributes[5]).toMatch(`Expires=`) + expect(attributes[5]).toMatch(UTC_DATE_REGEX) + }) + + it('includes just a key if option set to `true`', () => { + const dbAuth = new DbAuthHandler(event, context, { + ...options, + cookie: { Secure: true }, + }) + const attributes = dbAuth._cookieAttributes({}) + + expect(attributes[0]).toEqual('Secure') + }) + + it('does not include a key if option set to `false`', () => { + const dbAuth = new DbAuthHandler(event, context, { + ...options, + cookie: { Secure: false }, + }) + const attributes = dbAuth._cookieAttributes({}) + + expect(attributes[0]).not.toEqual('Secure') + }) + + it('includes key=value if property value is set', () => { + const dbAuth = new DbAuthHandler(event, context, { + ...options, + cookie: { Domain: 'example.com' }, + }) + const attributes = dbAuth._cookieAttributes({}) + + expect(attributes[0]).toEqual('Domain=example.com') + }) + + it('includes no cookie attributes if cookie options are empty', () => { + const dbAuth = new DbAuthHandler(event, context, { + ...options, + cookie: {}, + }) + const attributes = dbAuth._cookieAttributes({}) + + expect(attributes.length).toEqual(1) + expect(attributes[0]).toMatch(/Expires=/) + }) + + // DEPRECATED: can't test until deprecated functionality is removed + // it('includes no cookie attributes if cookie options not set', () => { + // const dbAuth = new DbAuthHandler(event, context, options) + // const attributes = dbAuth._cookieAttributes({}) + + // expect(attributes.length).toEqual(1) + // expect(attributes[0]).toMatch(/Expires=/) + // }) }) describe('_createSessionHeader()', () => { diff --git a/packages/api/src/functions/dbAuth/errors.ts b/packages/api/src/functions/dbAuth/errors.ts index 7855798ba495..de5622579fb1 100644 --- a/packages/api/src/functions/dbAuth/errors.ts +++ b/packages/api/src/functions/dbAuth/errors.ts @@ -135,7 +135,7 @@ export class CsrfTokenMismatchError extends Error { export class SessionDecryptionError extends Error { constructor() { - super('Session has potentially be tampered with') + super('Session has potentially been tampered with') this.name = 'SessionDecryptionError' } } diff --git a/packages/api/src/index.ts b/packages/api/src/index.ts index eb72c0518cf1..81f5bbed76db 100644 --- a/packages/api/src/index.ts +++ b/packages/api/src/index.ts @@ -5,6 +5,9 @@ export { dbAuthSession } from './functions/dbAuth/shared' export * from './validations/validations' export * from './validations/errors' +export * from './transforms' +export * from './cors' + // @NOTE: use require, to avoid messing around with tsconfig and nested output dirs const packageJson = require('../package.json') export const prismaVersion = packageJson?.dependencies['@prisma/client'] diff --git a/packages/api/src/transforms.ts b/packages/api/src/transforms.ts new file mode 100644 index 000000000000..c38ea7892764 --- /dev/null +++ b/packages/api/src/transforms.ts @@ -0,0 +1,37 @@ +import type { APIGatewayProxyEvent } from 'aws-lambda' +import { Headers } from 'node-fetch' + +// This is the same interface used by graphql-helix +// But not importing here to avoid adding a dependency +export interface Request { + body?: any + headers: Headers + method: string + query: any +} + +/** + * Extracts and parses body payload from event with base64 encoding check + */ +export const parseEventBody = (event: APIGatewayProxyEvent) => { + if (!event.body) { + return + } + + if (event.isBase64Encoded) { + return JSON.parse(Buffer.from(event.body, 'base64').toString('utf-8')) + } else { + return JSON.parse(event.body) + } +} + +export function normalizeRequest(event: APIGatewayProxyEvent): Request { + const body = parseEventBody(event) + + return { + headers: new Headers(event.headers as Record), + method: event.httpMethod, + query: event.queryStringParameters, + body, + } +} diff --git a/packages/auth/src/AuthProvider.tsx b/packages/auth/src/AuthProvider.tsx index d2db9ac90651..504a79840865 100644 --- a/packages/auth/src/AuthProvider.tsx +++ b/packages/auth/src/AuthProvider.tsx @@ -4,6 +4,7 @@ import { createAuthClient } from './authClients' import type { AuthClient, SupportedAuthTypes, + SupportedAuthConfig, SupportedAuthClients, SupportedUserMetadata, } from './authClients' @@ -75,11 +76,19 @@ type AuthProviderProps = | { client: SupportedAuthClients type: Omit + config?: never skipFetchCurrentUser?: boolean } | { client?: never - type: 'dbAuth' | 'clerk' + type: 'clerk' + config?: never + skipFetchCurrentUser?: boolean + } + | { + client?: never + type: 'dbAuth' + config?: SupportedAuthConfig skipFetchCurrentUser?: boolean } @@ -123,7 +132,8 @@ export class AuthProvider extends React.Component< super(props) this.rwClient = createAuthClient( props.client as SupportedAuthClients, - props.type as SupportedAuthTypes + props.type as SupportedAuthTypes, + props.config as SupportedAuthConfig ) } @@ -141,6 +151,8 @@ export class AuthProvider extends React.Component< const token = await this.getToken() const response = await global.fetch(this.getApiGraphQLUrl(), { method: 'POST', + // TODO: how can user configure this? inherit same `config` options given to auth client? + credentials: 'include', headers: { 'content-type': 'application/json', 'auth-provider': this.rwClient.type, diff --git a/packages/auth/src/__tests__/AuthProvider.test.tsx b/packages/auth/src/__tests__/AuthProvider.test.tsx index e563dcbb5fed..7af2cfcdfcc6 100644 --- a/packages/auth/src/__tests__/AuthProvider.test.tsx +++ b/packages/auth/src/__tests__/AuthProvider.test.tsx @@ -806,6 +806,6 @@ test('proxies validateResetToken() calls to client', async () => { ) - // for whatever reason, forgotPassword is invoked twice + // for whatever reason, validateResetToken is invoked twice expect.assertions(2) }) diff --git a/packages/auth/src/__tests__/AuthProviderConfig.test.tsx b/packages/auth/src/__tests__/AuthProviderConfig.test.tsx new file mode 100644 index 000000000000..c621914c39e1 --- /dev/null +++ b/packages/auth/src/__tests__/AuthProviderConfig.test.tsx @@ -0,0 +1,50 @@ +import { render } from '@testing-library/react' +import '@testing-library/jest-dom/extend-expect' + +import { createAuthClient } from '../authClients' +import { AuthProvider } from '../AuthProvider' + +jest.mock('../authClients', () => { + return { + createAuthClient: jest.fn().mockImplementation((...args) => { + return args + }), + } +}) + +describe('AuthProvider options', () => { + it('forwards config options to auth client', () => { + const TestAuthConsumer = () => { + return null + } + + render( + + + + ) + + expect(createAuthClient).toBeCalledWith(undefined, 'dbAuth', { + fetchConfig: { credentials: 'include' }, + }) + }) + + it('does not forward if no config present', () => { + const TestAuthConsumer = () => { + return null + } + + render( + + + + ) + + expect(createAuthClient).toBeCalledWith(undefined, 'dbAuth', undefined) + }) +}) diff --git a/packages/auth/src/authClients/__tests__/dbAuth.test.jsx b/packages/auth/src/authClients/__tests__/dbAuth.test.jsx new file mode 100644 index 000000000000..91ddf77b2952 --- /dev/null +++ b/packages/auth/src/authClients/__tests__/dbAuth.test.jsx @@ -0,0 +1,129 @@ +import { dbAuth } from '../dbAuth' + +global.RWJS_API_DBAUTH_URL = '/.redwood/functions' + +jest.mock('node-fetch', () => { + return +}) + +beforeAll(() => { + global.fetch = jest.fn().mockImplementation(() => { + return { text: () => '', json: () => ({}) } + }) +}) + +beforeEach(() => { + global.fetch.mockClear() +}) + +describe('dbAuth', () => { + it('sets a default credentials value if not included', async () => { + const client = dbAuth(() => null) + await client.getToken() + + expect(global.fetch).toBeCalledWith( + `${global.RWJS_API_DBAUTH_URL}?method=getToken`, + { + credentials: 'same-origin', + } + ) + }) + + it('passes through fetchOptions to forgotPasswrd calls', async () => { + const client = dbAuth(() => null, { + fetchConfig: { credentials: 'include' }, + }) + await client.forgotPassword('username') + + expect(global.fetch).toBeCalledWith( + global.RWJS_API_DBAUTH_URL, + expect.objectContaining({ + credentials: 'include', + }) + ) + }) + + it('passes through fetchOptions to getToken calls', async () => { + const client = dbAuth(() => null, { + fetchConfig: { credentials: 'include' }, + }) + await client.getToken() + + expect(global.fetch).toBeCalledWith( + `${global.RWJS_API_DBAUTH_URL}?method=getToken`, + { + credentials: 'include', + } + ) + }) + + it('passes through fetchOptions to login calls', async () => { + const client = dbAuth(() => null, { + fetchConfig: { credentials: 'include' }, + }) + await client.login({ username: 'username', password: 'password' }) + + expect(global.fetch).toBeCalledWith( + global.RWJS_API_DBAUTH_URL, + expect.objectContaining({ + credentials: 'include', + }) + ) + }) + + it('passes through fetchOptions to logout calls', async () => { + const client = dbAuth(() => null, { + fetchConfig: { credentials: 'include' }, + }) + await client.logout() + + expect(global.fetch).toBeCalledWith( + global.RWJS_API_DBAUTH_URL, + expect.objectContaining({ + credentials: 'include', + }) + ) + }) + + it('passes through fetchOptions to resetPassword calls', async () => { + const client = dbAuth(() => null, { + fetchConfig: { credentials: 'include' }, + }) + await client.resetPassword({}) + + expect(global.fetch).toBeCalledWith( + global.RWJS_API_DBAUTH_URL, + expect.objectContaining({ + credentials: 'include', + }) + ) + }) + + it('passes through fetchOptions to signup calls', async () => { + const client = dbAuth(() => null, { + fetchConfig: { credentials: 'include' }, + }) + await client.signup({}) + + expect(global.fetch).toBeCalledWith( + global.RWJS_API_DBAUTH_URL, + expect.objectContaining({ + credentials: 'include', + }) + ) + }) + + it('passes through fetchOptions to validateResetToken calls', async () => { + const client = dbAuth(() => null, { + fetchConfig: { credentials: 'include' }, + }) + await client.validateResetToken('token') + + expect(global.fetch).toBeCalledWith( + global.RWJS_API_DBAUTH_URL, + expect.objectContaining({ + credentials: 'include', + }) + ) + }) +}) diff --git a/packages/auth/src/authClients/dbAuth.ts b/packages/auth/src/authClients/dbAuth.ts index 5a4b65424648..e35ff21323e1 100644 --- a/packages/auth/src/authClients/dbAuth.ts +++ b/packages/auth/src/authClients/dbAuth.ts @@ -14,9 +14,21 @@ export type SignupAttributes = Record & LoginAttributes export type DbAuth = () => null -export const dbAuth = (): AuthClient => { +export type DbAuthConfig = { + fetchConfig: { + credentials: 'include' | 'same-origin' + } +} + +export const dbAuth = ( + _client: DbAuth, + config: DbAuthConfig = { fetchConfig: { credentials: 'same-origin' } } +): AuthClient => { + const { credentials } = config.fetchConfig + const forgotPassword = async (username: string) => { const response = await fetch(global.RWJS_API_DBAUTH_URL, { + credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ username, method: 'forgotPassword' }), @@ -26,7 +38,8 @@ export const dbAuth = (): AuthClient => { const getToken = async () => { const response = await fetch( - `${global.RWJS_API_DBAUTH_URL}?method=getToken` + `${global.RWJS_API_DBAUTH_URL}?method=getToken`, + { credentials } ) const token = await response.text() @@ -40,6 +53,7 @@ export const dbAuth = (): AuthClient => { const login = async (attributes: LoginAttributes) => { const { username, password } = attributes const response = await fetch(global.RWJS_API_DBAUTH_URL, { + credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ username, password, method: 'login' }), @@ -49,6 +63,7 @@ export const dbAuth = (): AuthClient => { const logout = async () => { await fetch(global.RWJS_API_DBAUTH_URL, { + credentials, method: 'POST', body: JSON.stringify({ method: 'logout' }), }) @@ -57,6 +72,7 @@ export const dbAuth = (): AuthClient => { const resetPassword = async (attributes: ResetPasswordAttributes) => { const response = await fetch(global.RWJS_API_DBAUTH_URL, { + credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ ...attributes, method: 'resetPassword' }), @@ -66,6 +82,7 @@ export const dbAuth = (): AuthClient => { const signup = async (attributes: SignupAttributes) => { const response = await fetch(global.RWJS_API_DBAUTH_URL, { + credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ ...attributes, method: 'signup' }), @@ -75,6 +92,7 @@ export const dbAuth = (): AuthClient => { const validateResetToken = async (resetToken: string | null) => { const response = await fetch(global.RWJS_API_DBAUTH_URL, { + credentials, method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ resetToken, method: 'validateResetToken' }), diff --git a/packages/auth/src/authClients/index.ts b/packages/auth/src/authClients/index.ts index 9469465a2bbe..8d1601c65903 100644 --- a/packages/auth/src/authClients/index.ts +++ b/packages/auth/src/authClients/index.ts @@ -10,7 +10,7 @@ import type { Clerk, ClerkUser } from './clerk' import { custom } from './custom' import type { Custom } from './custom' import { dbAuth } from './dbAuth' -import type { DbAuth } from './dbAuth' +import type { DbAuth, DbAuthConfig } from './dbAuth' import { ethereum } from './ethereum' import type { Ethereum, EthereumUser } from './ethereum' import { firebase } from './firebase' @@ -62,6 +62,8 @@ export type SupportedAuthClients = export type SupportedAuthTypes = keyof typeof typesToClients +export type SupportedAuthConfig = DbAuthConfig + export type { Auth0User } export type { AzureActiveDirectoryUser } export type { DbAuth } @@ -102,7 +104,8 @@ export interface AuthClient { export const createAuthClient = ( client: SupportedAuthClients, - type: SupportedAuthTypes + type: SupportedAuthTypes, + config?: SupportedAuthConfig ): AuthClient => { if (!typesToClients[type]) { throw new Error( @@ -111,5 +114,6 @@ export const createAuthClient = ( ).join(', ')}` ) } - return typesToClients[type](client) + + return typesToClients[type](client, config) } diff --git a/packages/cli/src/commands/setup/auth/templates/dbAuth.function.ts.template b/packages/cli/src/commands/setup/auth/templates/dbAuth.function.ts.template index 5eb8d43d2d08..92eee495e351 100644 --- a/packages/cli/src/commands/setup/auth/templates/dbAuth.function.ts.template +++ b/packages/cli/src/commands/setup/auth/templates/dbAuth.function.ts.template @@ -140,6 +140,19 @@ export const handler = async (event, context) => { resetTokenExpiresAt: 'resetTokenExpiresAt', }, + // Specifies attributes on the cookie that dbAuth sets in order to remember + // who is logged in. See https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#restrict_access_to_cookies + cookie: { + HttpOnly: true, + Path: '/', + SameSite: 'Strict', + Secure: true, + + // If you need to allow other domains (besides the api side) access to + // the dbAuth session cookie: + // Domain: 'example.com', + }, + forgotPassword: forgotPasswordOptions, login: loginOptions, resetPassword: resetPasswordOptions, diff --git a/packages/graphql-server/src/functions/__tests__/normalizeRequest.test.ts b/packages/graphql-server/src/functions/__tests__/normalizeRequest.test.ts index ec72016fad2e..b4c7e2e2c41c 100644 --- a/packages/graphql-server/src/functions/__tests__/normalizeRequest.test.ts +++ b/packages/graphql-server/src/functions/__tests__/normalizeRequest.test.ts @@ -1,6 +1,7 @@ import type { APIGatewayProxyEvent } from 'aws-lambda' +import { Headers } from 'node-fetch' -import { normalizeRequest } from '../graphql' +import { normalizeRequest } from '@redwoodjs/api' export const createMockedEvent = ( httpMethod = 'POST', @@ -62,7 +63,7 @@ test('Normalizes an aws event with base64', () => { ) expect(normalizeRequest(corsEventB64)).toEqual({ - headers: corsEventB64.headers, + headers: new Headers(corsEventB64.headers), method: 'POST', query: null, body: { @@ -75,7 +76,7 @@ test('Handles CORS requests with and without b64 encoded', () => { const corsEventB64 = createMockedEvent('OPTIONS', undefined, true) expect(normalizeRequest(corsEventB64)).toEqual({ - headers: corsEventB64.headers, + headers: new Headers(corsEventB64.headers), method: 'OPTIONS', query: null, body: undefined, @@ -84,7 +85,7 @@ test('Handles CORS requests with and without b64 encoded', () => { const corsEventWithoutB64 = createMockedEvent('OPTIONS', undefined, false) expect(normalizeRequest(corsEventWithoutB64)).toEqual({ - headers: corsEventB64.headers, + headers: new Headers(corsEventB64.headers), method: 'OPTIONS', query: null, body: undefined, diff --git a/packages/graphql-server/src/functions/graphql.ts b/packages/graphql-server/src/functions/graphql.ts index e67e3e348774..9f69e0e337a8 100644 --- a/packages/graphql-server/src/functions/graphql.ts +++ b/packages/graphql-server/src/functions/graphql.ts @@ -13,7 +13,7 @@ import { useDisableIntrospection } from '@envelop/disable-introspection' import { useFilterAllowedOperations } from '@envelop/filter-operation-type' import { useParserCache } from '@envelop/parser-cache' import { useValidationCache } from '@envelop/validation-cache' -import { RedwoodError } from '@redwoodjs/api' +import { normalizeRequest, RedwoodError } from '@redwoodjs/api' import type { APIGatewayProxyEvent, APIGatewayProxyResult, @@ -23,12 +23,11 @@ import { GraphQLError, GraphQLSchema, OperationTypeNode } from 'graphql' import { getGraphQLParameters, processRequest, - Request, shouldRenderGraphiQL, } from 'graphql-helix' import { renderPlaygroundPage } from 'graphql-playground-html' -import { createCorsContext } from '../cors' +import { createCorsContext } from '@redwoodjs/api' import { makeDirectivesForPlugin } from '../directives/makeDirectives' import { getAsyncStoreInstance } from '../globalContext' import { createHealthcheckContext } from '../healthcheck' @@ -45,32 +44,6 @@ import { useRedwoodPopulateContext } from '../plugins/useRedwoodPopulateContext' import { ValidationError } from '../errors' import type { GraphQLHandlerOptions } from './types' -/** - * Extracts and parses body payload from event with base64 encoding check - * - */ -const parseEventBody = (event: APIGatewayProxyEvent) => { - if (!event.body) { - return - } - - if (event.isBase64Encoded) { - return JSON.parse(Buffer.from(event.body, 'base64').toString('utf-8')) - } else { - return JSON.parse(event.body) - } -} - -export function normalizeRequest(event: APIGatewayProxyEvent): Request { - const body = parseEventBody(event) - - return { - headers: event.headers || {}, - method: event.httpMethod, - query: event.queryStringParameters, - body, - } -} /* * Prevent unexpected error messages from leaking to the GraphQL clients. diff --git a/packages/graphql-server/src/functions/types.ts b/packages/graphql-server/src/functions/types.ts index 76039780629b..352a486fee98 100644 --- a/packages/graphql-server/src/functions/types.ts +++ b/packages/graphql-server/src/functions/types.ts @@ -5,10 +5,10 @@ import { IExecutableSchemaDefinition } from '@graphql-tools/schema' import type { APIGatewayProxyEvent, Context as LambdaContext } from 'aws-lambda' import type { AuthContextPayload } from '@redwoodjs/api' +import { CorsConfig } from '@redwoodjs/api' import { DirectiveGlobImports } from 'src/directives/makeDirectives' -import { CorsConfig } from '../cors' import { OnHealthcheckFn } from '../healthcheck' import { LoggerConfig } from '../plugins/useRedwoodLogger' import { SdlGlobImports, ServicesGlobImports } from '../types' diff --git a/packages/graphql-server/src/healthcheck.ts b/packages/graphql-server/src/healthcheck.ts index 9ca6c05875f0..360532a8d0cf 100644 --- a/packages/graphql-server/src/healthcheck.ts +++ b/packages/graphql-server/src/healthcheck.ts @@ -1,7 +1,7 @@ import type { APIGatewayProxyEvent } from 'aws-lambda' import { Request } from 'graphql-helix' -import { CorsContext } from './cors' +import { CorsContext } from '@redwoodjs/api' const HEALTH_CHECK_PATH = '/health' diff --git a/yarn.lock b/yarn.lock index d0ec7de6d77c..30476d8eea56 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5730,6 +5730,7 @@ __metadata: jsonwebtoken: 8.5.1 jwks-rsa: 2.0.5 md5: 2.3.0 + node-fetch: 2.6.7 pascalcase: 1.0.0 pino: 7.8.0 split2: 4.1.0