inference_multi_gpu.py
import argparse
import multiprocessing
import os
import sys
from pathlib import Path

import torch


def get_args():
    """Parse command-line arguments for multi-GPU decoding."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--n_gpus', type=int, default=torch.cuda.device_count())
    parser.add_argument('--ckpt_dir', type=str)
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--train_dir', type=str)
    parser.add_argument('--decode_script', type=str)
    args = parser.parse_args()
    return args


def parse_ckpt_list(ckpt_dir):
    """Collect all .pt checkpoint files found directly under ckpt_dir."""
    ckpt_list = []
    for file in Path(ckpt_dir).iterdir():
        if file.suffix == '.pt':
            ckpt_list.append(file)
    return ckpt_list


def run_cmd(cmd):
    """Run a shell command in a worker process."""
    os.system(cmd)


def inference_multi_gpu(n_gpus, ckpt_list, args):
    """Decode checkpoints in batches of n_gpus, one checkpoint per GPU."""
    decode_script = args.decode_script
    python = sys.executable
    print(python)
    print(decode_script)
    # Process the checkpoints n_gpus at a time; each batch gets its own pool.
    for i in range(0, len(ckpt_list), n_gpus):
        cur_ckpt_list = ckpt_list[i:i + n_gpus]
        pool = multiprocessing.Pool(n_gpus)
        for gpu_id, ckpt_file in enumerate(cur_ckpt_list):
            ckpt_file = Path(ckpt_file).name
            # Pin the decode script to a single GPU via CUDA_VISIBLE_DEVICES.
            cmd = '''CUDA_VISIBLE_DEVICES={gpu_id} {python} -u {script} \
                --dataset {dataset} \
                --train_dir {train_dir} \
                --ckpt_file {ckpt_file} \
            '''.format(
                gpu_id=gpu_id,
                python=python,
                script=decode_script,
                dataset=args.dataset,
                train_dir=args.train_dir,
                ckpt_file=ckpt_file,
            )
            print(cmd)
            pool.apply_async(run_cmd, args=(cmd,))
        pool.close()
        pool.join()  # wait for the whole batch before launching the next one


def main():
    args = get_args()
    args.ckpt_list = parse_ckpt_list(args.ckpt_dir)
    print(args.ckpt_list)
    # Never request more GPUs than there are checkpoints to decode.
    n_gpus = min(args.n_gpus, len(args.ckpt_list))
    inference_multi_gpu(n_gpus, args.ckpt_list, args)


if __name__ == '__main__':
    main()
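
Usage note: the script is normally launched from the shell with the flags defined in get_args (--ckpt_dir, --dataset, --train_dir, --decode_script, and optionally --n_gpus). The snippet below is a minimal sketch of driving the same helpers from another Python program instead; the paths and names it uses (checkpoints/, decode.py, my_dataset, exp/run1) are placeholders, not values taken from this repository.

# Minimal programmatic-usage sketch; all paths/names below are placeholders.
from argparse import Namespace

from inference_multi_gpu import inference_multi_gpu, parse_ckpt_list

args = Namespace(
    ckpt_dir='checkpoints/',       # directory containing *.pt checkpoints
    dataset='my_dataset',          # forwarded to the decode script
    train_dir='exp/run1',          # forwarded to the decode script
    decode_script='decode.py',     # script launched once per checkpoint
)
ckpt_list = parse_ckpt_list(args.ckpt_dir)
n_gpus = min(2, len(ckpt_list))    # assume two visible GPUs
inference_multi_gpu(n_gpus, ckpt_list, args)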