-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_test.py
67 lines (55 loc) · 2.04 KB
/
run_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
from argparse import Namespace, ArgumentParser
from subprocess import run, TimeoutExpired, CalledProcessError
from sys import stdout
from time import strftime
from numpy.random import Generator, PCG64
from tqdm import tqdm
from gencog.graph import GraphGenerator, print_relay
from gencog.spec import OpRegistry
# Parsed command-line options; populated by parse_args() and read by main().
args = Namespace()


def parse_args(argv=None):
    """Parse the command-line options into the module-level ``args``.

    :param argv: Argument strings to parse, excluding the program name.
        ``None`` (the default) parses ``sys.argv[1:]``, which is what the
        script entry point relies on.
    :return: The parsed namespace. The same object is also stored in the
        global ``args``, so existing callers that ignore the return value
        keep working.
    :raises SystemExit: If the arguments are invalid or ``--root`` is missing
        (argparse's standard behaviour).
    """
    global args
    p = ArgumentParser()
    # main() joins this path unconditionally when building PYTHONPATH, so an
    # omitted value would only surface later as a TypeError on None. Have
    # argparse reject it up front with a proper usage message instead.
    p.add_argument('-r', '--root', type=str, required=True,
                   help='Root directory of TVM source code.')
    p.add_argument('-s', '--seed', type=int, default=42,
                   help='Random seed of graph generator.')
    p.add_argument('-o', '--output', type=str, default='out',
                   help='Output directory.')
    args = p.parse_args(argv)
    return args
def main():
    """Generate random graphs forever and run each one through ``_run_ps.py``.

    Each case's Relay text (from ``print_relay``) is written to
    ``<args.output>/run-<timestamp>/<case_id>/code.txt``. Then
    ``python3 _run_ps.py -d=<case dir> -e -s=<seed>`` is run with a 60-second
    timeout, with ``PYTHONPATH`` pointing at ``<args.root>/python``.

    A case directory is kept only when the subprocess exits with a non-zero
    status. Successful and timed-out cases are deleted. The loop runs until
    the process is interrupted.

    Reads the module-level ``args`` (root, seed, output), so ``parse_args``
    must run first.
    """
    from shutil import rmtree
    from subprocess import DEVNULL

    # Initialization
    rng = Generator(PCG64(seed=args.seed))
    gen = GraphGenerator(OpRegistry.ops(), rng)
    path = os.path.join(args.output, strftime('run-%Y%m%d-%H%M%S'))
    env = os.environ.copy()
    # NOTE: this replaces (does not extend) any inherited PYTHONPATH.
    env['PYTHONPATH'] = os.path.join(args.root, 'python')
    # makedirs also creates args.output itself when it does not exist yet;
    # a plain os.mkdir would fail with FileNotFoundError in that case.
    os.makedirs(path, exist_ok=True)

    # Generation loop
    progress = tqdm(file=stdout)
    try:
        while True:
            # Generate graph
            graph = gen.generate()
            code = print_relay(graph)

            # Write code to case directory
            case_id = str(progress.n)
            case_path = os.path.join(path, case_id)
            os.mkdir(case_path)
            with open(os.path.join(case_path, 'code.txt'), 'w') as f:
                f.write(code)

            # Run subprocess. The per-case seed is drawn after generation,
            # which keeps the RNG stream identical to the original order.
            cmd = ['python3', '_run_ps.py', f'-d={case_path}', '-e',
                   f'-s={rng.integers(2 ** 63)}']
            keep_dir = False
            try:
                # DEVNULL avoids opening (and leaking) a new file handle on
                # every iteration of this unbounded loop.
                run(cmd, env=env, check=True, timeout=60, stderr=DEVNULL)
            except CalledProcessError:
                # tqdm.write prints without mangling the progress bar.
                tqdm.write(f'Error detected in case {case_id}.')
                keep_dir = True
            except TimeoutExpired:
                tqdm.write(f'Case {case_id} timed out.')

            if not keep_dir:
                # rmtree also removes anything the (possibly killed)
                # subprocess left behind; os.rmdir would raise OSError on a
                # non-empty directory and abort the whole run.
                rmtree(case_path)
            progress.update()
    finally:
        progress.close()
def _entry():
    """Script entry point: fill the global ``args``, then start the loop."""
    # Order matters: main() reads the module-level ``args`` set here.
    parse_args()
    main()


if __name__ == '__main__':
    _entry()