diff --git a/src/jest.js b/src/jest.js index dc735dd..205e98e 100644 --- a/src/jest.js +++ b/src/jest.js @@ -15,6 +15,7 @@ const { getCallerLocation, installLocationInNextTest } = createCallerLocationHoo expect.extend(matchers) let defaultTimeout = jestConfig().testTimeout // overridable via jest.setTimeout() +const defaultConcurrency = jestConfig().maxConcurrency function parseArgs(list, targs) { if (!(Object.isFrozen(list) && list.length === targs.length + 1)) return list // template check @@ -79,7 +80,33 @@ const makeEach = const forceExit = process.execArgv.map((x) => x.replaceAll('_', '-')).includes('--test-force-exit') -const describe = (...args) => nodeDescribe(...args) +const inConcurrent = [] +const concurrent = [] +const describe = (...args) => { + const fn = args.pop() + const optionsConcurrent = args?.at(-1)?.concurrency > 1 + if (optionsConcurrent) inConcurrent.push(fn) + const res = nodeDescribe(...args, async () => { + const res = fn() + + // We do only block-level concurrency, not file-level + if (concurrent.length === 1) { + test(...concurrent[0]) + concurrent.length = 0 + } else if (concurrent.length > 0) { + const queue = [...concurrent] + concurrent.length = 0 + describe('concurrent', { concurrency: defaultConcurrency }, () => { + for (const args of queue) test(...args) + }) + } + + return res + }) + if (optionsConcurrent) inConcurrent.pop() + return res +} + const test = (name, fn, testTimeout) => { const timeout = testTimeout ?? defaultTimeout installLocationInNextTest(getCallerLocation()) @@ -99,6 +126,8 @@ Also, using expect.assertions() to ensure the planned number of assertions is be describe.each = makeEach(describe) test.each = makeEach(test) +test.concurrent = (...args) => (inConcurrent.length > 0 ? test(...args) : concurrent.push(args)) +test.concurrent.each = makeEach(test.concurrent) describe.skip = (...args) => nodeDescribe.skip(...args) test.skip = (...args) => nodeTest.skip(...args)