pool.py
import torch.multiprocessing as multiprocessing
import time
import os
import psutil
import gc

# Ensure that sharing all the models with all the workers is not too many FDs
multiprocessing.set_sharing_strategy('file_system')


def worker(to_worker, core, all_models):
    # Pin this worker to a single core, then process queued tasks forever.
    psutil.Process().cpu_affinity([core])
    while True:
        f, args = to_worker.get()
        # Positions 0 and 3 carry model indices; resolve them back to the
        # shared model objects before running the task.
        args[0] = all_models[args[0]]
        args[3] = all_models[args[3]]
        f(args)


class Pool(object):
    def __init__(self, cores, maxsize, all_models):
        """ Initialize a pool with cores processes
        """
        self._to_worker = multiprocessing.Queue(maxsize)
        self._processes = []
        self._all_models = all_models
        affinity = list(psutil.Process().cpu_affinity())
        for i in range(cores):
            core = affinity[i % len(affinity)]
            p = multiprocessing.Process(target=worker, args=(self._to_worker, core, all_models), daemon=True)
            p.start()
            self._processes.append(p)

    def map(self, func, args):
        """ Dispatch func over args to the worker processes.
            Results are not collected; an empty list is returned.
        """
        count = len(self._processes)
        # Push functions and arguments to the worker processes
        gc.disable()
        for i, arg in enumerate(args):
            # Replace models by their index
            if func != 'set_all_models':
                arg[0] = self._all_models.index(arg[0])
                arg[3] = self._all_models.index(arg[3])
            self._to_worker.put((func, arg))
        gc.enable()
        return []
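

# --- Usage sketch (added for illustration; not part of the original file) ---
# A minimal, hypothetical example of how this Pool might be driven. The task
# function, the nn.Linear "models", and the cores/maxsize values below are
# assumptions, not defined anywhere in this module. It also assumes a Linux
# host with the 'fork' start method, so the task function defined here is
# visible to the already-forked daemon workers. Each argument list carries a
# model at positions 0 and 3, which map() swaps for their index so that only
# small integers cross the queue; the worker resolves them back.
if __name__ == '__main__':
    import torch
    import torch.nn as nn

    def run_pair(args):
        # Executed inside a worker with the real model objects restored.
        model_a, x, _, model_b = args
        print(model_a(model_b(x)).shape)

    all_models = [nn.Linear(4, 4), nn.Linear(4, 4)]          # hypothetical shared models
    pool = Pool(cores=2, maxsize=8, all_models=all_models)   # hypothetical sizes
    batch = torch.zeros(1, 4)
    pool.map(run_pair, [[all_models[0], batch, None, all_models[1]]])
    time.sleep(2)  # map() is fire-and-forget; give the daemon workers time to drain the queue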