diff --git a/src/browser/setupWorker/start/createFallbackRequestListener.ts b/src/browser/setupWorker/start/createFallbackRequestListener.ts index 37a3473ef..1afee6f8d 100644 --- a/src/browser/setupWorker/start/createFallbackRequestListener.ts +++ b/src/browser/setupWorker/start/createFallbackRequestListener.ts @@ -8,7 +8,7 @@ import { XMLHttpRequestInterceptor } from '@mswjs/interceptors/XMLHttpRequest' import { SetupWorkerInternalContext, StartOptions } from '../glossary' import type { RequiredDeep } from '~/core/typeUtils' import { handleRequest } from '~/core/utils/handleRequest' -import { toRequestHandlersOnly } from '~/core/utils/internal/toRequestHandlersOnly' +import { isHandlerKind } from '~/core/utils/internal/isHandlerKind' export function createFallbackRequestListener( context: SetupWorkerInternalContext, @@ -25,7 +25,7 @@ export function createFallbackRequestListener( const response = await handleRequest( request, requestId, - context.getRequestHandlers().filter(toRequestHandlersOnly), + context.getRequestHandlers().filter(isHandlerKind('RequestHandler')), options, context.emitter, { diff --git a/src/browser/setupWorker/start/createRequestListener.ts b/src/browser/setupWorker/start/createRequestListener.ts index 9e2a7aee0..ec96603ae 100644 --- a/src/browser/setupWorker/start/createRequestListener.ts +++ b/src/browser/setupWorker/start/createRequestListener.ts @@ -13,7 +13,7 @@ import { handleRequest } from '~/core/utils/handleRequest' import { RequiredDeep } from '~/core/typeUtils' import { devUtils } from '~/core/utils/internal/devUtils' import { toResponseInit } from '~/core/utils/toResponseInit' -import { toRequestHandlersOnly } from '~/core/utils/internal/toRequestHandlersOnly' +import { isHandlerKind } from '~/core/utils/internal/isHandlerKind' export const createRequestListener = ( context: SetupWorkerInternalContext, @@ -44,7 +44,7 @@ export const createRequestListener = ( await handleRequest( request, requestId, - context.getRequestHandlers().filter(toRequestHandlersOnly), + context.getRequestHandlers().filter(isHandlerKind('RequestHandler')), options, context.emitter, { diff --git a/src/core/handlers/RequestHandler.ts b/src/core/handlers/RequestHandler.ts index bd610ea04..0a5e6f83d 100644 --- a/src/core/handlers/RequestHandler.ts +++ b/src/core/handlers/RequestHandler.ts @@ -7,6 +7,7 @@ import { import type { ResponseResolutionContext } from '../utils/executeHandlers' import type { MaybePromise } from '../typeUtils' import { StrictRequest, StrictResponse } from '..//HttpResponse' +import type { HandlerKind } from './common' export type DefaultRequestMultipartBody = Record< string, @@ -117,7 +118,7 @@ export abstract class RequestHandler< StrictRequest >() - private readonly __kind: 'RequestHandler' + private readonly __kind: HandlerKind public info: HandlerInfo & RequestHandlerInternalInfo /** diff --git a/src/core/handlers/WebSocketHandler.ts b/src/core/handlers/WebSocketHandler.ts index 6413640a8..f37f1bd6f 100644 --- a/src/core/handlers/WebSocketHandler.ts +++ b/src/core/handlers/WebSocketHandler.ts @@ -8,6 +8,7 @@ import { matchRequestUrl, } from '../utils/matching/matchRequestUrl' import { getCallFrame } from '../utils/internal/getCallFrame' +import type { HandlerKind } from './common' type WebSocketHandlerParsedResult = { match: Match @@ -28,7 +29,7 @@ const kStopPropagationPatched = Symbol('kStopPropagationPatched') const KOnStopPropagation = Symbol('KOnStopPropagation') export class WebSocketHandler { - private readonly __kind: 'WebSocketHandler' + private readonly __kind: HandlerKind public id: string public callFrame?: string @@ -40,7 +41,7 @@ export class WebSocketHandler { this[kEmitter] = new Emitter() this.callFrame = getCallFrame(new Error()) - this.__kind = 'WebSocketHandler' + this.__kind = 'EventHandler' } public parse(args: { diff --git a/src/core/handlers/common.ts b/src/core/handlers/common.ts new file mode 100644 index 000000000..ef0d1018a --- /dev/null +++ b/src/core/handlers/common.ts @@ -0,0 +1 @@ +export type HandlerKind = 'RequestHandler' | 'EventHandler' diff --git a/src/core/utils/internal/isHandlerKind.test.ts b/src/core/utils/internal/isHandlerKind.test.ts new file mode 100644 index 000000000..84486fbe9 --- /dev/null +++ b/src/core/utils/internal/isHandlerKind.test.ts @@ -0,0 +1,64 @@ +import { GraphQLHandler } from '../../handlers/GraphQLHandler' +import { HttpHandler } from '../../handlers/HttpHandler' +import { RequestHandler } from '../../handlers/RequestHandler' +import { WebSocketHandler } from '../../handlers/WebSocketHandler' +import { isHandlerKind } from './isHandlerKind' + +it('returns true if expected a request handler and given a request handler', () => { + expect( + isHandlerKind('RequestHandler')(new HttpHandler('*', '*', () => {})), + ).toBe(true) + + expect( + isHandlerKind('RequestHandler')( + new GraphQLHandler('all', '*', '*', () => {}), + ), + ).toBe(true) +}) + +it('returns true if expected a request handler and given a custom request handler', () => { + class MyHandler extends RequestHandler { + constructor() { + super({ info: { header: '*' }, resolver: () => {} }) + } + predicate = () => false + log() {} + } + + expect(isHandlerKind('RequestHandler')(new MyHandler())).toBe(true) +}) + +it('returns false if expected a request handler but given event handler', () => { + expect(isHandlerKind('RequestHandler')(new WebSocketHandler('*'))).toBe(false) +}) + +it('returns false if expected a request handler but given arbitrary object', () => { + expect(isHandlerKind('RequestHandler')(undefined)).toBe(false) + expect(isHandlerKind('RequestHandler')(null)).toBe(false) + expect(isHandlerKind('RequestHandler')({})).toBe(false) + expect(isHandlerKind('RequestHandler')([])).toBe(false) + expect(isHandlerKind('RequestHandler')(123)).toBe(false) + expect(isHandlerKind('RequestHandler')('hello')).toBe(false) +}) + +it('returns true if expected an event handler and given an event handler', () => { + expect(isHandlerKind('EventHandler')(new WebSocketHandler('*'))).toBe(true) +}) + +it('returns true if expected an event handler and given a custom event handler', () => { + class MyEventHandler extends WebSocketHandler { + constructor() { + super('*') + } + } + expect(isHandlerKind('EventHandler')(new MyEventHandler())).toBe(true) +}) + +it('returns false if expected an event handler but given arbitrary object', () => { + expect(isHandlerKind('EventHandler')(undefined)).toBe(false) + expect(isHandlerKind('EventHandler')(null)).toBe(false) + expect(isHandlerKind('EventHandler')({})).toBe(false) + expect(isHandlerKind('EventHandler')([])).toBe(false) + expect(isHandlerKind('EventHandler')(123)).toBe(false) + expect(isHandlerKind('EventHandler')('hello')).toBe(false) +}) diff --git a/src/core/utils/internal/isHandlerKind.ts b/src/core/utils/internal/isHandlerKind.ts new file mode 100644 index 000000000..d877bc847 --- /dev/null +++ b/src/core/utils/internal/isHandlerKind.ts @@ -0,0 +1,21 @@ +import type { HandlerKind } from '../../handlers/common' +import type { RequestHandler } from '../../handlers/RequestHandler' +import type { WebSocketHandler } from '../../handlers/WebSocketHandler' + +/** + * A filter function that ensures that the provided argument + * is a handler of the given kind. This helps differentiate + * between different kinds of handlers, e.g. request and event handlers. + */ +export function isHandlerKind(kind: K) { + return ( + input: unknown, + ): input is K extends 'EventHandler' ? WebSocketHandler : RequestHandler => { + return ( + input != null && + typeof input === 'object' && + '__kind' in input && + input.__kind === kind + ) + } +} diff --git a/src/core/utils/internal/toRequestHandlersOnly.test.ts b/src/core/utils/internal/toRequestHandlersOnly.test.ts deleted file mode 100644 index 5bbfb638e..000000000 --- a/src/core/utils/internal/toRequestHandlersOnly.test.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { GraphQLHandler } from '../../handlers/GraphQLHandler' -import { HttpHandler } from '../../handlers/HttpHandler' -import { RequestHandler } from '../../handlers/RequestHandler' -import { WebSocketHandler } from '../../handlers/WebSocketHandler' -import { toRequestHandlersOnly } from './toRequestHandlersOnly' - -it('returns true for HttpHandler', () => { - expect(toRequestHandlersOnly(new HttpHandler('*', '*', () => {}))).toBe(true) -}) - -it('returns true for GraphQLHandler', () => { - expect( - toRequestHandlersOnly(new GraphQLHandler('all', '*', '*', () => {})), - ).toBe(true) -}) - -it('returns true for a custom RequestHandler', () => { - class MyHandler extends RequestHandler { - constructor() { - super({ info: { header: '*' }, resolver: () => {} }) - } - predicate = () => false - log() {} - } - - expect(toRequestHandlersOnly(new MyHandler())).toBe(true) -}) - -it('returns false for a WebSocketHandler', () => { - expect(toRequestHandlersOnly(new WebSocketHandler('*'))).toBe(false) -}) - -it('returns false for an arbitrary values', () => { - expect(toRequestHandlersOnly(undefined)).toBe(false) - expect(toRequestHandlersOnly(null)).toBe(false) - expect(toRequestHandlersOnly({})).toBe(false) - expect(toRequestHandlersOnly([])).toBe(false) - expect(toRequestHandlersOnly(123)).toBe(false) - expect(toRequestHandlersOnly('hello')).toBe(false) -}) diff --git a/src/core/utils/internal/toRequestHandlersOnly.ts b/src/core/utils/internal/toRequestHandlersOnly.ts deleted file mode 100644 index 7f003fe2b..000000000 --- a/src/core/utils/internal/toRequestHandlersOnly.ts +++ /dev/null @@ -1,15 +0,0 @@ -import { RequestHandler } from '../../handlers/RequestHandler' - -/** - * A filter function that ensures that the provided argument - * is an instance of a `RequestHandler` class. This helps filter - * out other handlers, like `WebSocketHandler`. - */ -export function toRequestHandlersOnly(input: unknown): input is RequestHandler { - return ( - input != null && - typeof input === 'object' && - '__kind' in input && - input.__kind === 'RequestHandler' - ) -} diff --git a/src/core/ws/handleWebSocketEvent.ts b/src/core/ws/handleWebSocketEvent.ts index a20bd6ec4..919ae0e89 100644 --- a/src/core/ws/handleWebSocketEvent.ts +++ b/src/core/ws/handleWebSocketEvent.ts @@ -6,6 +6,7 @@ import { onUnhandledRequest, UnhandledRequestStrategy, } from '../utils/request/onUnhandledRequest' +import { isHandlerKind } from '../utils/internal/isHandlerKind' interface HandleWebSocketEventOptions { getUnhandledRequestStrategy: () => UnhandledRequestStrategy @@ -30,7 +31,7 @@ export function handleWebSocketEvent(options: HandleWebSocketEventOptions) { for (const handler of handlers) { if ( - handler instanceof WebSocketHandler && + isHandlerKind('EventHandler')(handler) && handler.predicate({ event: connectionEvent, parsedResult: handler.parse({ diff --git a/src/node/SetupServerCommonApi.ts b/src/node/SetupServerCommonApi.ts index 8d50e5c27..0d2104119 100644 --- a/src/node/SetupServerCommonApi.ts +++ b/src/node/SetupServerCommonApi.ts @@ -20,7 +20,7 @@ import { InternalError, devUtils } from '~/core/utils/internal/devUtils' import type { SetupServerCommon } from './glossary' import { handleWebSocketEvent } from '~/core/ws/handleWebSocketEvent' import { webSocketInterceptor } from '~/core/ws/webSocketInterceptor' -import { toRequestHandlersOnly } from '~/core/utils/internal/toRequestHandlersOnly' +import { isHandlerKind } from '~/core/utils/internal/isHandlerKind' export const DEFAULT_LISTEN_OPTIONS: RequiredDeep = { onUnhandledRequest: 'warn', @@ -64,7 +64,7 @@ export class SetupServerCommonApi requestId, this.handlersController .currentHandlers() - .filter(toRequestHandlersOnly), + .filter(isHandlerKind('RequestHandler')), this.resolvedOptions, this.emitter, )