-
Notifications
You must be signed in to change notification settings - Fork 1
/
distribute_benchmark.py
64 lines (50 loc) · 1.9 KB
/
distribute_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import sys
import time
import functools
import collections
import contextlib
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(r) for r in range(int(sys.argv[2]))])
import jax
import jax.numpy as jnp
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np
import h5py
from itertools import product
from tensorflow_probability.substrates import jax as tfp
from utils import array_maker
from runner import run_wrapper
from models import *
# Command line: argv[1] is the model name (a key of ``model_dict``), argv[2]
# is the number of GPUs.  argv[2] and every following argument are also
# forwarded, as ints, to ``array_maker``.
N_GPU = int(sys.argv[2])
N_array, M_array, inds, vals = array_maker(*[int(s) for s in sys.argv[2:]])
model_type = sys.argv[1]


def shard(data_array):
    """Split each array across the visible devices and place it there.

    Every element of ``data_array`` is read into a NumPy array and reshaped
    to ``(jax.device_count(), -1, *trailing_dims)``; an identity ``jax.pmap``
    is then applied to the reshaped array.

    Args:
        data_array: iterable of array-likes exposing ``.shape`` (here, HDF5
            datasets).

    Returns:
        A tuple with one pmapped array per input, in the same order.

    Raises:
        ValueError: from ``reshape`` when an array cannot be split evenly
            across ``jax.device_count()`` devices.
    """
    n_dev = jax.device_count()
    split_data = [np.array(d).reshape((n_dev, -1, *d.shape[1:])) for d in data_array]
    to_devices = jax.pmap(lambda x: x)
    return tuple(to_devices(s) for s in split_data)


def _wait_until_ready(tree):
    """Block until every JAX array found in ``tree`` has been computed."""
    def _block(leaf):
        # Leaves without ``block_until_ready`` are passed through untouched.
        return leaf.block_until_ready() if hasattr(leaf, 'block_until_ready') else leaf
    return jax.tree_util.tree_map(_block, tree)


def _replace_dataset(group, name, data):
    """Store ``data`` as ``group[name]``, overwriting a previous result."""
    # Delete-then-create: the old value must be replaced, not merely removed,
    # otherwise every second run would leave the file without a result.
    if name in group:
        del group[name]
    group.create_dataset(name, data=data)


for packed in inds:
    # Each entry of ``inds`` is an "i,j" string indexing N_array and M_array.
    n_idx, m_idx = [int(p) for p in packed.split(',')]
    N = N_array[n_idx]
    M = M_array[m_idx]
    dt_file = 'data/%s/%s_%s.hdf5'%(model_type,N,M)

    # The context manager closes (and flushes) the file even if the run fails.
    with h5py.File(dt_file, 'a') as dt:
        sharded_data = shard([dt['data'][k] for k in dt['data_keys']])
        # Last ``n_comp[0]`` sharded arrays.
        # NOTE(review): if n_comp[0] is 0 this slice is the whole tuple - confirm
        # that 0 never occurs.
        pass_data = sharded_data[-dt['n_comp'][0]:]

        run = run_wrapper(model_dict[model_type])
        start_time = time.perf_counter()
        states, trace = run(random.PRNGKey(0), sharded_data, pass_data)
        # JAX dispatches asynchronously; wait for the results so the timer
        # covers the computation itself and not only its dispatch.
        _wait_until_ready((states, trace))
        end_time = time.perf_counter()
        print('%s,%s\t%s,%s'%(str(N),str(M),str(start_time),str(end_time)))

        _replace_dataset(dt, 'runtime_%i_GPU'%N_GPU, float(end_time - start_time))

        # Error of the mean over axis 0 of each state against the stored
        # parameter value, one dataset per parameter.
        errors = dt.require_group('errors')
        for p_idx, k in enumerate(dt['param_keys']):
            # NOTE(review): with h5py >= 3 string datasets yield ``bytes``, so
            # ``%s`` would render the key as "b'...'" - verify the stored names.
            _replace_dataset(
                errors,
                '%s_%i_GPU'%(k,N_GPU),
                states[p_idx].mean(0) - dt['params'][k][()],
            )