-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
103 lines (78 loc) · 3.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import tqdm
import re
import matplotlib.pyplot as plt
import numpy as np
import time
import json
import warnings
from datetime import timedelta
import logging
logging.getLogger('tensorflow').disabled = True
import tensorflow as tf
import argparse
#tf.get_logger().setLevel('ERROR')
from gym_minigrid.wrappers import *
import numpy as np
from collections import deque
import PIL
import random
import matplotlib.pyplot as plt
import flloat
from flloat.parser.ltlf import LTLfParser
from models.run import *
# --- parsing and configuration ---


def _str2bool(value):
    """Convert a command-line string such as ``"True"``, ``"false"`` or ``"0"`` to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) makes
            argparse read the real command line (``sys.argv[1:]``).

    Returns:
        argparse.Namespace with attributes ``episodes``, ``env``, ``algo``,
        ``gui``, ``model_name``, ``rand_seed`` and ``BATCH_SIZE``.
    """
    desc = "Tensorflow 1.x Deep Reinforcement Learning using Restraining Bolts"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--episodes', type=int, default=10000, help='The number of episodes to run')
    parser.add_argument('--env', type=str, default='MiniGrid-Unlock-v0', help='choose gym environment')
    parser.add_argument('--algo', type=str, default='dqn', help='Deep RL algorithm')
    # With `type=bool`, "--gui False" would parse as True. _str2bool fixes that;
    # nargs='?' + const=True also allows a bare `--gui` flag.
    parser.add_argument('--gui', type=_str2bool, nargs='?', const=True, default=False,
                        help='enable gui (not recommended for training)')
    parser.add_argument('--model_name', type=str, default=None, help='path to model if starting from checkpoint')
    parser.add_argument('--rand_seed', type=int, default=42, help='tf random seed')
    parser.add_argument('--BATCH_SIZE', type=int, default=32, help='batch size (only supported for algo==a2c)')
    return parser.parse_args(argv)
def main(args):
    """Create the output folders, a TF 1.x session and the gym env, then train.

    Args:
        args: ``argparse.Namespace`` produced by ``parse_args()``. This
            function reads ``algo``, ``model_name``, ``env``, ``episodes``,
            ``rand_seed``, ``gui`` and ``BATCH_SIZE``.

    Raises:
        ValueError: if ``args.algo`` is not one of the supported algorithms.
    """
    # Validate first so a bad --algo fails before any folder, TF session or
    # environment is created. Raise rather than `assert`, which `python -O` strips.
    supported_algorithms = ['dqn', 'ddqn', 'a2c', 'pompdp']
    if args.algo not in supported_algorithms:
        raise ValueError(
            "Unsupported Algorithm! Please choose a supported one: {}".format(
                ", ".join(supported_algorithms)))

    # --- saving paths ---
    output_dir = "logs"
    if args.model_name is None:
        # Fresh run: name the folder after env/algo/episodes plus a timestamp.
        t = time.strftime('%Y-%m-%d_%H_%M_%S_%z')
        model_name = "env_{}_algo_{}_ep_{}_{}".format(args.env, args.algo, args.episodes, t)
        print("[*] created model folder: {}".format(model_name))
        model_dir = os.path.join(output_dir, model_name)
    else:
        # Resuming: --model_name is used directly as the model folder path.
        model_name = args.model_name
        print("[*] proceeding to load model: {}".format(model_name))
        model_dir = model_name
    image_dir = os.path.join(model_dir, 'images')
    checkpoints_dir = os.path.join(model_dir, 'checkpoints')
    for path in (output_dir, model_dir, image_dir, checkpoints_dir):
        # makedirs also creates missing parents (e.g. a nested --model_name);
        # exist_ok avoids the exists()-then-mkdir() race.
        os.makedirs(path, exist_ok=True)

    # --- tf session definitions ---
    tf.reset_default_graph()
    tf.random.set_random_seed(args.rand_seed)  # seeds TF only; numpy/random are not seeded here
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config)

    env = None
    try:
        # --- load env ---
        print("[*] attempting to load {} env".format(args.env))
        # NOTE(review): `gym` is not imported explicitly in this file; it
        # presumably arrives via `from gym_minigrid.wrappers import *` — confirm.
        env = gym.make(args.env)
        print("[*] success")

        # --- main loop ---
        if args.algo in ('dqn', 'ddqn'):
            run(sess=sess, env=env, algo=args.algo, checkpoints_dir=checkpoints_dir,
                n_episodes=args.episodes, gui=args.gui)
        else:
            # BATCH_SIZE is only forwarded to the run_a2c path.
            run_a2c(sess=sess, env=env, algo=args.algo, checkpoints_dir=checkpoints_dir,
                    n_episodes=args.episodes, gui=args.gui, BATCH_SIZE=args.BATCH_SIZE)
    finally:
        # Release the environment and the TF session even if training raises
        # or is interrupted.
        if env is not None:
            env.close()
        sess.close()
if __name__ == "__main__":
    # argparse already exits with a usage message on bad input or --help,
    # so parse_args() always returns a Namespace here and no None check is
    # needed (nor the site-only `exit()` helper).
    args = parse_args()
    main(args)