Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent instanceof handler check failures between different MSW versions #2349

Merged
merged 6 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { XMLHttpRequestInterceptor } from '@mswjs/interceptors/XMLHttpRequest'
import { SetupWorkerInternalContext, StartOptions } from '../glossary'
import type { RequiredDeep } from '~/core/typeUtils'
import { handleRequest } from '~/core/utils/handleRequest'
import { isHandlerKind } from '~/core/utils/internal/isHandlerKind'

export function createFallbackRequestListener(
context: SetupWorkerInternalContext,
Expand All @@ -24,7 +25,7 @@ export function createFallbackRequestListener(
const response = await handleRequest(
request,
requestId,
context.getRequestHandlers(),
context.getRequestHandlers().filter(isHandlerKind('RequestHandler')),
options,
context.emitter,
{
Expand Down
9 changes: 2 additions & 7 deletions src/browser/setupWorker/start/createRequestListener.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ import {
} from './utils/createMessageChannel'
import { parseWorkerRequest } from '../../utils/parseWorkerRequest'
import { RequestHandler } from '~/core/handlers/RequestHandler'
import { HttpHandler } from '~/core/handlers/HttpHandler'
import { GraphQLHandler } from '~/core/handlers/GraphQLHandler'
import { handleRequest } from '~/core/utils/handleRequest'
import { RequiredDeep } from '~/core/typeUtils'
import { devUtils } from '~/core/utils/internal/devUtils'
import { toResponseInit } from '~/core/utils/toResponseInit'
import { isHandlerKind } from '~/core/utils/internal/isHandlerKind'

export const createRequestListener = (
context: SetupWorkerInternalContext,
Expand Down Expand Up @@ -45,11 +44,7 @@ export const createRequestListener = (
await handleRequest(
request,
requestId,
context.getRequestHandlers().filter((handler) => {
return (
handler instanceof HttpHandler || handler instanceof GraphQLHandler
)
}),
context.getRequestHandlers().filter(isHandlerKind('RequestHandler')),
options,
context.emitter,
{
Expand Down
4 changes: 4 additions & 0 deletions src/core/handlers/RequestHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
import type { ResponseResolutionContext } from '../utils/executeHandlers'
import type { MaybePromise } from '../typeUtils'
import { StrictRequest, StrictResponse } from '..//HttpResponse'
import type { HandlerKind } from './common'

export type DefaultRequestMultipartBody = Record<
string,
Expand Down Expand Up @@ -117,6 +118,8 @@ export abstract class RequestHandler<
StrictRequest<DefaultBodyType>
>()

private readonly __kind: HandlerKind

public info: HandlerInfo & RequestHandlerInternalInfo
/**
* Indicates whether this request handler has been used
Expand Down Expand Up @@ -151,6 +154,7 @@ export abstract class RequestHandler<
}

this.isUsed = false
this.__kind = 'RequestHandler'
}

/**
Expand Down
4 changes: 4 additions & 0 deletions src/core/handlers/WebSocketHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
matchRequestUrl,
} from '../utils/matching/matchRequestUrl'
import { getCallFrame } from '../utils/internal/getCallFrame'
import type { HandlerKind } from './common'

type WebSocketHandlerParsedResult = {
match: Match
Expand All @@ -28,6 +29,8 @@ const kStopPropagationPatched = Symbol('kStopPropagationPatched')
const KOnStopPropagation = Symbol('KOnStopPropagation')

export class WebSocketHandler {
private readonly __kind: HandlerKind

public id: string
public callFrame?: string

Expand All @@ -38,6 +41,7 @@ export class WebSocketHandler {

this[kEmitter] = new Emitter()
this.callFrame = getCallFrame(new Error())
this.__kind = 'EventHandler'
}

public parse(args: {
Expand Down
1 change: 1 addition & 0 deletions src/core/handlers/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export type HandlerKind = 'RequestHandler' | 'EventHandler'
6 changes: 1 addition & 5 deletions src/core/utils/executeHandlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export interface ResponseResolutionContext {
* Returns the execution result object containing any matching request
* handler and any mocked response it returned.
*/
export const executeHandlers = async <Handlers extends Array<unknown>>({
export const executeHandlers = async <Handlers extends Array<RequestHandler>>({
request,
requestId,
handlers,
Expand All @@ -33,10 +33,6 @@ export const executeHandlers = async <Handlers extends Array<unknown>>({
let result: RequestHandlerExecutionResult<any> | null = null

for (const handler of handlers) {
if (!(handler instanceof RequestHandler)) {
continue
}

result = await handler.run({ request, requestId, resolutionContext })

// If the handler produces some result for this request,
Expand Down
64 changes: 64 additions & 0 deletions src/core/utils/internal/isHandlerKind.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import { GraphQLHandler } from '../../handlers/GraphQLHandler'
import { HttpHandler } from '../../handlers/HttpHandler'
import { RequestHandler } from '../../handlers/RequestHandler'
import { WebSocketHandler } from '../../handlers/WebSocketHandler'
import { isHandlerKind } from './isHandlerKind'

it('returns true if expected a request handler and given a request handler', () => {
expect(
isHandlerKind('RequestHandler')(new HttpHandler('*', '*', () => {})),
).toBe(true)

expect(
isHandlerKind('RequestHandler')(
new GraphQLHandler('all', '*', '*', () => {}),
),
).toBe(true)
})

it('returns true if expected a request handler and given a custom request handler', () => {
class MyHandler extends RequestHandler {
constructor() {
super({ info: { header: '*' }, resolver: () => {} })
}
predicate = () => false
log() {}
}

expect(isHandlerKind('RequestHandler')(new MyHandler())).toBe(true)
})

it('returns false if expected a request handler but given event handler', () => {
expect(isHandlerKind('RequestHandler')(new WebSocketHandler('*'))).toBe(false)
})

it('returns false if expected a request handler but given arbitrary object', () => {
expect(isHandlerKind('RequestHandler')(undefined)).toBe(false)
expect(isHandlerKind('RequestHandler')(null)).toBe(false)
expect(isHandlerKind('RequestHandler')({})).toBe(false)
expect(isHandlerKind('RequestHandler')([])).toBe(false)
expect(isHandlerKind('RequestHandler')(123)).toBe(false)
expect(isHandlerKind('RequestHandler')('hello')).toBe(false)
})

it('returns true if expected an event handler and given an event handler', () => {
expect(isHandlerKind('EventHandler')(new WebSocketHandler('*'))).toBe(true)
})

it('returns true if expected an event handler and given a custom event handler', () => {
class MyEventHandler extends WebSocketHandler {
constructor() {
super('*')
}
}
expect(isHandlerKind('EventHandler')(new MyEventHandler())).toBe(true)
})

it('returns false if expected an event handler but given arbitrary object', () => {
expect(isHandlerKind('EventHandler')(undefined)).toBe(false)
expect(isHandlerKind('EventHandler')(null)).toBe(false)
expect(isHandlerKind('EventHandler')({})).toBe(false)
expect(isHandlerKind('EventHandler')([])).toBe(false)
expect(isHandlerKind('EventHandler')(123)).toBe(false)
expect(isHandlerKind('EventHandler')('hello')).toBe(false)
})
21 changes: 21 additions & 0 deletions src/core/utils/internal/isHandlerKind.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import type { HandlerKind } from '../../handlers/common'
import type { RequestHandler } from '../../handlers/RequestHandler'
import type { WebSocketHandler } from '../../handlers/WebSocketHandler'

/**
* A filter function that ensures that the provided argument
* is a handler of the given kind. This helps differentiate
* between different kinds of handlers, e.g. request and event handlers.
*/
export function isHandlerKind<K extends HandlerKind>(kind: K) {
return (
input: unknown,
): input is K extends 'EventHandler' ? WebSocketHandler : RequestHandler => {
return (
input != null &&
typeof input === 'object' &&
'__kind' in input &&
input.__kind === kind
)
}
}
3 changes: 2 additions & 1 deletion src/core/ws/handleWebSocketEvent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
onUnhandledRequest,
UnhandledRequestStrategy,
} from '../utils/request/onUnhandledRequest'
import { isHandlerKind } from '../utils/internal/isHandlerKind'

interface HandleWebSocketEventOptions {
getUnhandledRequestStrategy: () => UnhandledRequestStrategy
Expand All @@ -30,7 +31,7 @@ export function handleWebSocketEvent(options: HandleWebSocketEventOptions) {

for (const handler of handlers) {
if (
handler instanceof WebSocketHandler &&
isHandlerKind('EventHandler')(handler) &&
handler.predicate({
event: connectionEvent,
parsedResult: handler.parse({
Expand Down
12 changes: 4 additions & 8 deletions src/node/SetupServerCommonApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ import type { LifeCycleEventsMap, SharedOptions } from '~/core/sharedOptions'
import { SetupApi } from '~/core/SetupApi'
import { handleRequest } from '~/core/utils/handleRequest'
import type { RequestHandler } from '~/core/handlers/RequestHandler'
import { HttpHandler } from '~/core/handlers/HttpHandler'
import { GraphQLHandler } from '~/core/handlers/GraphQLHandler'
import type { WebSocketHandler } from '~/core/handlers/WebSocketHandler'
import { mergeRight } from '~/core/utils/internal/mergeRight'
import { InternalError, devUtils } from '~/core/utils/internal/devUtils'
import type { SetupServerCommon } from './glossary'
import { handleWebSocketEvent } from '~/core/ws/handleWebSocketEvent'
import { webSocketInterceptor } from '~/core/ws/webSocketInterceptor'
import { isHandlerKind } from '~/core/utils/internal/isHandlerKind'

export const DEFAULT_LISTEN_OPTIONS: RequiredDeep<SharedOptions> = {
onUnhandledRequest: 'warn',
Expand Down Expand Up @@ -63,12 +62,9 @@ export class SetupServerCommonApi
const response = await handleRequest(
request,
requestId,
this.handlersController.currentHandlers().filter((handler) => {
return (
handler instanceof HttpHandler ||
handler instanceof GraphQLHandler
)
}),
this.handlersController
.currentHandlers()
.filter(isHandlerKind('RequestHandler')),
this.resolvedOptions,
this.emitter,
)
Expand Down
Loading