diff --git a/src/core.ts b/src/core.ts index 70a00da..f40cb3a 100644 --- a/src/core.ts +++ b/src/core.ts @@ -1,7 +1,7 @@ import * as qs from 'qs'; import { VERSION } from './version'; import { Stream } from './streaming'; -import { APIError, APIConnectionError, APIConnectionTimeoutError } from './error'; +import { APIError, APIConnectionError, APIConnectionTimeoutError, APIUserAbortError } from './error'; import type { Readable } from '@tryfinch/finch-api/_shims/node-readable'; import { getDefaultAgent, type Agent } from '@tryfinch/finch-api/_shims/agent'; import { @@ -183,6 +183,9 @@ export abstract class APIClient { ...(body && { body: body as any }), headers: reqHeaders, ...(httpAgent && { agent: httpAgent }), + // @ts-ignore node-fetch uses a custom AbortSignal type that is + // not compatible with standard web types + signal: options.signal ?? null, }; this.validateHeaders(reqHeaders, headers); @@ -220,8 +223,15 @@ export abstract class APIClient { const response = await this.fetchWithTimeout(url, req, timeout, controller).catch(castToError); if (response instanceof Error) { - if (retriesRemaining) return this.retryRequest(options, retriesRemaining); - if (response.name === 'AbortError') throw new APIConnectionTimeoutError(); + if (options.signal?.aborted) { + throw new APIUserAbortError(); + } + if (retriesRemaining) { + return this.retryRequest(options, retriesRemaining); + } + if (response.name === 'AbortError') { + throw new APIConnectionTimeoutError(); + } throw new APIConnectionError({ cause: response }); } @@ -561,6 +571,7 @@ export type RequestOptions | Readable> stream?: boolean | undefined; timeout?: number; httpAgent?: Agent; + signal?: AbortSignal | undefined | null; idempotencyKey?: string; }; @@ -578,6 +589,7 @@ const requestOptionsKeys: KeysEnum = { stream: true, timeout: true, httpAgent: true, + signal: true, idempotencyKey: true, }; diff --git a/src/error.ts b/src/error.ts index fa360e6..45971af 100644 --- a/src/error.ts +++ b/src/error.ts @@ -67,6 +67,14 @@ export class APIError extends Error { } } +export class APIUserAbortError extends APIError { + override readonly status: undefined = undefined; + + constructor({ message }: { message?: string } = {}) { + super(undefined, undefined, message || 'Request was aborted.', undefined); + } +} + export class APIConnectionError extends APIError { override readonly status: undefined = undefined; diff --git a/src/index.ts b/src/index.ts index f99ac43..982c278 100644 --- a/src/index.ts +++ b/src/index.ts @@ -201,6 +201,7 @@ export class Finch extends Core.APIClient { static APIError = Errors.APIError; static APIConnectionError = Errors.APIConnectionError; static APIConnectionTimeoutError = Errors.APIConnectionTimeoutError; + static APIUserAbortError = Errors.APIUserAbortError; static NotFoundError = Errors.NotFoundError; static ConflictError = Errors.ConflictError; static RateLimitError = Errors.RateLimitError; @@ -215,6 +216,7 @@ export const { APIError, APIConnectionError, APIConnectionTimeoutError, + APIUserAbortError, NotFoundError, ConflictError, RateLimitError, diff --git a/tests/index.test.ts b/tests/index.test.ts index 2fab3ea..dcf8771 100644 --- a/tests/index.test.ts +++ b/tests/index.test.ts @@ -1,8 +1,9 @@ // File generated from our OpenAPI spec by Stainless. -import { Headers } from '@tryfinch/finch-api/core'; import Finch from '@tryfinch/finch-api'; -import { Response } from '@tryfinch/finch-api/_shims/fetch'; +import { APIUserAbortError } from '@tryfinch/finch-api'; +import { Headers } from '@tryfinch/finch-api/core'; +import { Response, fetch as defaultFetch } from '@tryfinch/finch-api/_shims/fetch'; describe('instantiate client', () => { const env = process.env; @@ -95,6 +96,32 @@ describe('instantiate client', () => { expect(response).toEqual({ url: 'http://localhost:5000/foo', custom: true }); }); + test('custom signal', async () => { + const client = new Finch({ + baseURL: 'http://127.0.0.1:4010', + accessToken: 'my access token', + fetch: (...args) => { + return new Promise((resolve, reject) => + setTimeout( + () => + defaultFetch(...args) + .then(resolve) + .catch(reject), + 300, + ), + ); + }, + }); + + const controller = new AbortController(); + setTimeout(() => controller.abort(), 200); + + const spy = jest.spyOn(client, 'request'); + + await expect(client.get('/foo', { signal: controller.signal })).rejects.toThrowError(APIUserAbortError); + expect(spy).toHaveBeenCalledTimes(1); + }); + describe('baseUrl', () => { test('trailing slash', () => { const client = new Finch({