Skip to content

Commit

Permalink
fix: fix worker node cross tasks stealing
Browse files Browse the repository at this point in the history
Signed-off-by: Jérôme Benoit <jerome.benoit@piment-noir.org>
  • Loading branch information
jerome-benoit committed Dec 24, 2023
1 parent e7fc096 commit 5bfb3e8
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 20 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ and this project adheres to

## [Unreleased]

### Fixed

- Avoid cross task stealing between worker nodes.
- Ensure that at most half of the pool worker nodes can steal tasks.

## [0.1.9] - 2023-12-22

### Changed
Expand Down
39 changes: 33 additions & 6 deletions src/pools/abstract-pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,14 @@ export abstract class AbstractPool<
: accumulator,
0,
),
...(this.opts.enableTasksQueue === true &&
{
stealingWorkerNodes: this.workerNodes.reduce(
(accumulator, workerNode) =>
workerNode.info.stealing ? accumulator + 1 : accumulator,
0,
),
}),
busyWorkerNodes: this.workerNodes.reduce(
(accumulator, _workerNode, workerNodeKey) =>
this.isWorkerNodeBusy(workerNodeKey) ? accumulator + 1 : accumulator,
Expand Down Expand Up @@ -1375,6 +1383,10 @@ export abstract class AbstractPool<
})
}

/**
 * Tells whether task stealing is currently pointless for this pool.
 *
 * Used as an early-exit guard by the task stealing and back pressure
 * handlers before they attempt to move tasks between worker nodes.
 *
 * @returns `true` when the pool has at most one worker node, or when the
 * pool info reports zero queued tasks; `false` otherwise.
 */
private cannotStealTask(): boolean {
  // With zero or one worker node there is no other node to steal from;
  // bail out before the pool info is read.
  if (this.workerNodes.length <= 1) {
    return true
  }
  // Nothing is queued anywhere in the pool, so there is nothing to steal.
  return this.info.queuedTasks === 0
}

private handleTask(workerNodeKey: number, task: Task<Data>): void {
if (this.shallExecuteTask(workerNodeKey)) {
this.executeTask(workerNodeKey, task)
Expand All @@ -1387,7 +1399,7 @@ export abstract class AbstractPool<
if (workerNodeKey === -1) {
return
}
if (this.workerNodes.length <= 1) {
if (this.cannotStealTask()) {
return
}
while (this.tasksQueueSize(workerNodeKey) > 0) {
Expand Down Expand Up @@ -1481,22 +1493,29 @@ export abstract class AbstractPool<
event: CustomEvent<WorkerNodeEventDetail>,
previousStolenTask?: Task<Data>,
): void => {
if (this.workerNodes.length <= 1) {
return
}
const { workerNodeKey } = event.detail
if (workerNodeKey == null) {
throw new Error(
'WorkerNode event detail workerNodeKey attribute must be defined',
'WorkerNode event detail workerNodeKey property must be defined',
)
}
if (
this.cannotStealTask() || (this.info.stealingWorkerNodes as number) >
Math.floor(this.workerNodes.length / 2)
) {
if (previousStolenTask != null) {
this.getWorkerInfo(workerNodeKey).stealing = false
}
return
}
const workerNodeTasksUsage = this.workerNodes[workerNodeKey].usage.tasks
if (
previousStolenTask != null &&
workerNodeTasksUsage.sequentiallyStolen > 0 &&
(workerNodeTasksUsage.executing > 0 ||
this.tasksQueueSize(workerNodeKey) > 0)
) {
this.getWorkerInfo(workerNodeKey).stealing = false
for (
const taskName of this.workerNodes[workerNodeKey].info
.taskFunctionNames as string[]
Expand All @@ -1511,6 +1530,7 @@ export abstract class AbstractPool<
)
return
}
this.getWorkerInfo(workerNodeKey).stealing = true
const stolenTask = this.workerNodeStealTask(workerNodeKey)
if (
this.shallUpdateTaskFunctionWorkerUsage(workerNodeKey) &&
Expand Down Expand Up @@ -1556,6 +1576,7 @@ export abstract class AbstractPool<
const sourceWorkerNode = workerNodes.find(
(sourceWorkerNode, sourceWorkerNodeKey) =>
sourceWorkerNode.info.ready &&
!sourceWorkerNode.info.stealing &&
sourceWorkerNodeKey !== workerNodeKey &&
sourceWorkerNode.usage.tasks.queued > 0,
)
Expand All @@ -1576,7 +1597,10 @@ export abstract class AbstractPool<
private readonly handleBackPressureEvent = (
event: CustomEvent<WorkerNodeEventDetail>,
): void => {
if (this.workerNodes.length <= 1) {
if (
this.cannotStealTask() || (this.info.stealingWorkerNodes as number) >
Math.floor(this.workerNodes.length / 2)
) {
return
}
const { workerId } = event.detail
Expand All @@ -1596,16 +1620,19 @@ export abstract class AbstractPool<
if (
sourceWorkerNode.usage.tasks.queued > 0 &&
workerNode.info.ready &&
!workerNode.info.stealing &&
workerNode.info.id !== workerId &&
workerNode.usage.tasks.queued <
(this.opts.tasksQueueOptions?.size as number) - sizeOffset
) {
this.getWorkerInfo(workerNodeKey).stealing = true
const task = sourceWorkerNode.popTask() as Task<Data>
this.handleTask(workerNodeKey, task)
this.updateTaskStolenStatisticsWorkerUsage(
workerNodeKey,
task.name as string,
)
this.getWorkerInfo(workerNodeKey).stealing = false
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/pools/pool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ export interface PoolInfo {
readonly workerNodes: number
/** Pool idle worker nodes. */
readonly idleWorkerNodes: number
/** Pool stealing worker nodes. */
readonly stealingWorkerNodes?: number
/** Pool busy worker nodes. */
readonly busyWorkerNodes: number
readonly executedTasks: number
Expand Down
1 change: 1 addition & 0 deletions src/pools/worker-node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ export class WorkerNode<Worker extends IWorker, Data = unknown>
type: getWorkerType(worker) as WorkerType,
dynamic: false,
ready: false,
stealing: false,
}
}

Expand Down
5 changes: 5 additions & 0 deletions src/pools/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ export interface WorkerInfo {
* Ready flag.
*/
ready: boolean
/**
* Stealing flag.
* This flag is set to `true` when the worker node is stealing tasks from another worker node.
*/
stealing: boolean
/**
* Task function names.
*/
Expand Down
7 changes: 5 additions & 2 deletions tests/pools/abstract-pool.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,7 @@ Deno.test({
type: WorkerTypes.web,
dynamic: false,
ready: true,
stealing: false,
})
}
await pool.destroy()
Expand All @@ -1046,6 +1047,7 @@ Deno.test({
type: WorkerTypes.web,
dynamic: false,
ready: true,
stealing: false,
})
}
await pool.destroy()
Expand Down Expand Up @@ -1447,7 +1449,7 @@ Deno.test({
stub(
pool,
'hasBackPressure',
returnsNext(Array(5).fill(true)),
returnsNext(Array(7).fill(true)),
)
expect(pool.emitter.eventNames()).toStrictEqual([])
const promises = new Set()
Expand Down Expand Up @@ -1476,6 +1478,7 @@ Deno.test({
maxSize: expect.any(Number),
workerNodes: expect.any(Number),
idleWorkerNodes: expect.any(Number),
stealingWorkerNodes: expect.any(Number),
busyWorkerNodes: expect.any(Number),
executedTasks: expect.any(Number),
executingTasks: expect.any(Number),
Expand All @@ -1485,7 +1488,7 @@ Deno.test({
stolenTasks: expect.any(Number),
failedTasks: expect.any(Number),
})
assertSpyCalls(pool.hasBackPressure, 5)
assertSpyCalls(pool.hasBackPressure, 7)
pool.hasBackPressure.restore()
await pool.destroy()
},
Expand Down
1 change: 1 addition & 0 deletions tests/pools/worker-node.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ Deno.test({
type: WorkerTypes.web,
dynamic: false,
ready: false,
stealing: false,
})
expect(threadWorkerNode.usage).toStrictEqual({
tasks: {
Expand Down
19 changes: 7 additions & 12 deletions tests/worker/thread-worker.test.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ Deno.test('Thread worker test suite', async (t) => {
return 2
}
const worker = new ThreadWorker({ fn1, fn2 })
worker.port = {
postMessage: stub(() => {}),
}
expect(worker.removeTaskFunction(0, fn1)).toStrictEqual({
status: false,
error: new TypeError('name parameter is not a string'),
Expand All @@ -40,9 +43,6 @@ Deno.test('Thread worker test suite', async (t) => {
status: false,
error: new TypeError('name parameter is an empty string'),
})
worker.port = {
postMessage: stub(() => {}),
}
expect(worker.taskFunctions.get(DEFAULT_TASK_NAME)).toBeInstanceOf(
Function,
)
Expand Down Expand Up @@ -85,17 +85,12 @@ Deno.test('Thread worker test suite', async (t) => {
})

await t.step(
'Verify worker invokes the postMessage() method on port property',
'Verify that sendToMainWorker() method invokes the port property postMessage() method',
() => {
class SpyWorker extends ThreadWorker {
constructor(fn) {
super(fn)
this.port = {
postMessage: stub(() => {}),
}
}
const worker = new ThreadWorker(() => {})
worker.port = {
postMessage: stub(() => {}),
}
const worker = new SpyWorker(() => {})
worker.sendToMainWorker({ ok: 1 })
assertSpyCalls(worker.port.postMessage, 1)
worker.port.postMessage.restore()
Expand Down

0 comments on commit 5bfb3e8

Please sign in to comment.