-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
52 lines (43 loc) · 1.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
from agent import optimization_agent
from util.config import get_config
from util import logger
import numpy as np
# Turn off TensorFlow eager execution at import time, before main() builds
# the agent.
# NOTE(review): presumably the agent code relies on TF1-style graphs and
# sessions -- confirm against the agent package.
tf.compat.v1.disable_eager_execution()
def main(args):
    """Drive the optimization agent until the step budget is exceeded.

    Builds an ``optimization_agent`` from *args*, calls its ``update_step()``
    repeatedly until the reported ``totalsteps`` exceeds
    ``args.max_timesteps``, then calls ``end()`` on the agent. When
    ``args.test`` is truthy, the reward statistics of the last update result
    are logged.

    Args:
        args: Configuration object. Attributes read here: ``write_log``,
            ``output_dir``, ``task``, ``time_id``, ``max_timesteps`` and
            ``test``.

    Returns:
        When ``args.test`` is truthy, the ``raw_rewards`` entry of the last
        update result rendered by ``np.array2string`` with ``","`` as the
        separator; otherwise ``None``.
    """
    if args.write_log:
        logger.set_file_handler(
            path=args.output_dir,
            # NOTE(review): the join assumes args.task is a sequence of task
            # names; a plain string would be joined character by character.
            prefix="mujoco_" + "_".join(args.task),
            time_str=args.time_id,
        )

    learner_agent = optimization_agent.optimization_agent(args)

    # The budget is checked after each update, so at least one update runs.
    while True:
        results = learner_agent.update_step()
        totalsteps = results["totalsteps"]
        logger.info("%d total steps have happened" % totalsteps)
        if totalsteps > args.max_timesteps:
            break
    learner_agent.end()

    if not args.test:
        return None

    logger.info(
        "Test performance ({} rollouts): {} (std: {})".format(
            args.test, results["avg_reward"], results["std_reward"]
        )
    )
    logger.info(
        "max: {}, min: {}, median: {}".format(
            results["max_reward"], results["min_reward"], results["median_reward"]
        )
    )

    # Render the raw rewards once and reuse the string for both the log line
    # and the return value.
    raw_rewards_str = np.array2string(results["raw_rewards"], separator=",")
    logger.info("raw_rewards: {}".format(raw_rewards_str))
    return raw_rewards_str
if __name__ == "__main__":
    # Script entry point: load the run configuration and pass it straight to
    # main(), so no module-level `args` shadows main()'s parameter.
    main(get_config())