util.py
import os
import random
import sys

import numpy as np
import torch

def initialize_distributed(args):
    """Initialize torch.distributed from the environment variables set by the launcher."""
    args.rank = int(os.environ['RANK'])
    args.world_size = int(os.environ['WORLD_SIZE'])
    args.gpu = int(os.environ['LOCAL_RANK'])
    args.distributed = True
    # Bind this process to its local GPU before initializing the NCCL backend.
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)
    # Make sure every rank has finished initialization before returning.
    torch.distributed.barrier()

def set_random_seed(seed):
    """Set random seed for reproducibility."""
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        # torch.manual_seed also seeds the CUDA RNG on all visible devices.
        torch.manual_seed(seed)

def mkdir_ckpt_dirs(args):
    """Create checkpoint/sample directories and dump the run settings (rank 0 only)."""
    if torch.distributed.get_rank() == 0:
        if os.path.exists(args.save):
            print('save directory already exists:', args.save)
            sys.exit()
        os.makedirs(args.save)
        os.makedirs(os.path.join(args.save, 'ckpts'))
        os.makedirs(os.path.join(args.save, 'samples'))

        # Record all parsed arguments so the run can be reproduced later.
        with open(os.path.join(args.save, 'setting.txt'), 'w') as f:
            f.write('------------------- start -------------------\n')
            for arg, value in args.__dict__.items():
                f.write('{} : {}\n'.format(arg, value))
            f.write('------------------- end -------------------\n')

def multiplyList(myList):
    """Return the product of all elements of myList (multiplied one by one)."""
    result = 1
    for x in myList:
        result = result * x
    return result
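
# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original file): these helpers
# expect a script launched with torchrun, which sets the RANK, WORLD_SIZE and
# LOCAL_RANK environment variables that initialize_distributed reads, e.g.
#
#   torchrun --nproc_per_node=4 train.py
#
# A minimal, hypothetical entry point could then look like this (the flag
# names below are illustrative, not taken from the repository):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--dist_url', default='env://')    # hypothetical flag
#   parser.add_argument('--save', default='runs/exp1')     # hypothetical flag
#   parser.add_argument('--seed', type=int, default=1234)  # hypothetical flag
#   args = parser.parse_args()
#   initialize_distributed(args)
#   set_random_seed(args.seed)
#   mkdir_ckpt_dirs(args)
#   n_pixels = multiplyList([3, 32, 32])  # e.g. elements in one 3x32x32 image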