diff --git a/src/sqs.test.ts b/src/sqs.test.ts
index a9946f9..2abfd4c 100644
--- a/src/sqs.test.ts
+++ b/src/sqs.test.ts
@@ -300,7 +300,8 @@ describe('SQSMessageHandler', () => {
       body: JSON.stringify({ name: 'test-event-8' }),
     },
   ];
-  const messageHandler: SQSMessageAction<{ name: string }, any> = (
+
+  const errorMessageHandler: SQSMessageAction<{ name: string }, any> = (
     ctx,
     message,
   ) => {
@@ -310,6 +311,13 @@ describe('SQSMessageHandler', () => {
     }
   };
 
+  const successMessageHandler: SQSMessageAction<
+    { name: string },
+    any
+  > = () => {
+    return Promise.resolve();
+  };
+
   test('throws on unprocessed events by default', async () => {
     expect.assertions(2);
     const handler = new SQSMessageHandler({
@@ -318,7 +326,7 @@
       createRunContext: () => ({}),
       concurrency: 2,
     })
-      .onMessage(messageHandler)
+      .onMessage(errorMessageHandler)
       .lambda();
 
     try {
@@ -347,7 +355,7 @@
       // when using concurrency.
       concurrency: 2,
     })
-      .onMessage(messageHandler)
+      .onMessage(errorMessageHandler)
      .lambda();
 
     const result = await handler(
@@ -357,14 +365,61 @@
       {} as any,
     );
 
+    const batchItemFailures = [
+      { itemIdentifier: 'message-3' },
+      { itemIdentifier: 'message-4' },
+      { itemIdentifier: 'message-7' },
+      { itemIdentifier: 'message-8' },
+    ];
+
     expect(result).toEqual({
-      batchItemFailures: [
-        { itemIdentifier: 'message-3' },
-        { itemIdentifier: 'message-4' },
-        { itemIdentifier: 'message-7' },
-        { itemIdentifier: 'message-8' },
-      ],
+      batchItemFailures,
     });
+    expect(logger.info).not.toHaveBeenCalledWith(
+      'Successfully processed all SQS messages',
+    );
+    expect(logger.info).toHaveBeenCalledWith(
+      {
+        batchItemFailures,
+      },
+      'Sending SQS partial batch response',
+    );
+  });
+
+  test('returns nothing when all events are processed successfully', async () => {
+    const handler = new SQSMessageHandler({
+      logger,
+      parseMessage: testSerializer.parseMessage,
+      createRunContext: () => ({}),
+      usePartialBatchResponses: true,
+      // Make sure partial batch responses are returned in order even
+      // when using concurrency.
+      concurrency: 2,
+    })
+      .onMessage(successMessageHandler)
+      .lambda();
+
+    const result = await handler(
+      {
+        Records: records,
+      } as any,
+      {} as any,
+    );
+
+    expect(result).toEqual(undefined);
+    expect(logger.error).not.toHaveBeenCalledWith(
+      expect.any(Object),
+      'Failed to fully process message group',
+    );
+    expect(logger.info).toHaveBeenCalledWith(
+      'Successfully processed all SQS messages',
+    );
+    expect(logger.info).not.toHaveBeenCalledWith(
+      {
+        batchItemFailures: expect.any(Array),
+      },
+      'Sending SQS partial batch response',
+    );
   });
 });
diff --git a/src/sqs.ts b/src/sqs.ts
index a7a8d83..a4a0ba5 100644
--- a/src/sqs.ts
+++ b/src/sqs.ts
@@ -199,21 +199,23 @@ export class SQSMessageHandler {
       },
     );
 
-    if (!processingResult.hasUnprocessedRecords) {
+    const unprocessedRecordsByGroupIdEntries = Object.entries(
+      processingResult.unprocessedRecordsByGroupId,
+    );
+
+    if (!unprocessedRecordsByGroupIdEntries.length) {
       context.logger.info('Successfully processed all SQS messages');
+      return;
     }
 
     if (!this.config.usePartialBatchResponses) {
       processingResult.throwOnUnprocessedRecords();
-      return;
     }
 
     // SQS partial batching expects that you return an ordered list of
     // failures. We map through each group and add them to the batch item
     // failures in order for each group.
-    const batchItemFailures = Object.entries(
-      processingResult.unprocessedRecords,
-    )
+    const batchItemFailures = unprocessedRecordsByGroupIdEntries
       .map(([groupId, record]) => {
         const [failedRecord, ...subsequentUnprocessedRecords] = record.items;
         context.logger.error(
diff --git a/src/utils.ts b/src/utils.ts
index 3a199b4..8b23611 100644
--- a/src/utils.ts
+++ b/src/utils.ts
@@ -2,7 +2,6 @@ import { LoggerInterface } from '@lifeomic/logging';
 import { Context } from 'aws-lambda';
 import pMap from 'p-map';
 import groupBy from 'lodash/groupBy';
-import zipObject from 'lodash/zipObject';
 
 export type BaseContext = {
   logger: LoggerInterface;
@@ -91,27 +90,30 @@ export const processWithOrdering = async <T>(
   process: (item: T) => Promise<void>,
 ) => {
   const groupedItems = groupBy(params.items, params.orderBy);
-  const listIds = Object.keys(groupedItems);
-  const lists = Object.values(groupedItems);
-  const unprocessedRecordsByListId = zipObject<{ error: any; items: T[] }>(
-    listIds,
-    lists.map(() => ({ error: null, items: [] })),
-  );
+  const groupIds = Object.keys(groupedItems);
+  const groups = Object.values(groupedItems);
+  const unprocessedRecordsByGroupId: Record<
+    string,
+    {
+      error: any;
+      items: T[];
+    }
+  > = {};
 
   await pMap(
-    lists,
-    async (list, listIndex) => {
-      for (let i = 0; i < list.length; i++) {
-        const item = list[i];
+    groups,
+    async (group, groupIndex) => {
+      for (let i = 0; i < group.length; i++) {
+        const item = group[i];
         try {
           await process(item);
         } catch (error) {
           // Keep track of all unprocessed items and stop processing the current
-          // list as soon as we encounter the first error.
-          unprocessedRecordsByListId[listIds[listIndex]] = {
+          // group as soon as we encounter the first error.
+          unprocessedRecordsByGroupId[groupIds[groupIndex]] = {
             error,
-            items: list.slice(i),
+            items: group.slice(i),
           };
           return;
         }
@@ -122,15 +124,13 @@
     },
   );
 
-  const aggregateErrors = Object.values(unprocessedRecordsByListId)
-    .map((record) => record.error)
-    .filter(Boolean)
-    .flat();
-
   return {
-    hasUnprocessedRecords: aggregateErrors.length > 0,
-    unprocessedRecords: unprocessedRecordsByListId,
+    unprocessedRecordsByGroupId,
     throwOnUnprocessedRecords: () => {
+      const aggregateErrors = Object.values(unprocessedRecordsByGroupId).map(
+        (record) => record.error,
+      );
+
       if (aggregateErrors.length) {
         throw new AggregateError(aggregateErrors);
       }
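
Note for reviewers: the sketch below summarizes the behavioral contract after this change. `ProcessingResult`, `UnprocessedGroup`, `SQSBatchResponse`, `respond`, and the `messageId` field are illustrative stand-ins, not the library's exported API; the real logic lives in `src/sqs.ts` and `src/utils.ts`.

```ts
// Illustrative stand-in for the shape returned by processWithOrdering after
// this change: only groups that actually failed get an entry, so an empty
// object means a fully successful batch.
type UnprocessedGroup<T> = { error: any; items: T[] };

type ProcessingResult<T> = {
  unprocessedRecordsByGroupId: Record<string, UnprocessedGroup<T>>;
  throwOnUnprocessedRecords: () => void;
};

// The response shape Lambda expects when ReportBatchItemFailures is enabled.
type SQSBatchResponse = { batchItemFailures: { itemIdentifier: string }[] };

const respond = <T extends { messageId: string }>(
  result: ProcessingResult<T>,
  usePartialBatchResponses: boolean,
): SQSBatchResponse | undefined => {
  const entries = Object.entries(result.unprocessedRecordsByGroupId);

  // New early return: a clean batch now yields undefined (and logs success)
  // instead of falling through to build a partial batch response.
  if (!entries.length) return undefined;

  // Without partial batch responses, an AggregateError is thrown instead.
  if (!usePartialBatchResponses) result.throwOnUnprocessedRecords();

  // Each failed group contributes, in order, the record that threw plus all
  // records behind it that were skipped.
  return {
    batchItemFailures: entries.flatMap(([, group]) =>
      group.items.map((item) => ({ itemIdentifier: item.messageId })),
    ),
  };
};
```

This also shows why `filter(Boolean)` could be dropped from the aggregate-error computation: the map is only populated for groups with a real error, and computing `aggregateErrors` lazily inside `throwOnUnprocessedRecords` skips that work on the happy path.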