main.py
""" Deep Q-Learning for OpenAI Gym environment
"""
import os
import sys
import gym
import argparse
import numpy as np
import pandas as pd
import tensorflow as tf
from A2C.a2c import A2C
# from A3C.a3c import A3C
# from DDQN.ddqn import DDQN
# from DDPG.ddpg import DDPG
from single_cell_env import opticalTweezers
from keras.backend.tensorflow_backend import set_session
from keras.utils import to_categorical
from utils.atari_environment import AtariEnvironment
from utils.continuous_environments import Environment
from utils.networks import get_session
gym.logger.set_level(40)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

def parse_args(args):
    """ Parse arguments from command line input
    """
    parser = argparse.ArgumentParser(description='Training parameters')
    #
    parser.add_argument('--type', type=str, default='DDQN', help="Algorithm to train from {A2C, A3C, DDQN, DDPG}")
    parser.add_argument('--is_atari', dest='is_atari', action='store_true', help="Atari Environment")
    parser.add_argument('--with_PER', dest='with_per', action='store_true', help="Use Prioritized Experience Replay (DDQN + PER)")
    parser.add_argument('--dueling', dest='dueling', action='store_true', help="Use a Dueling Architecture (DDQN)")
    #
    parser.add_argument('--nb_episodes', type=int, default=5000, help="Number of training episodes")
    parser.add_argument('--batch_size', type=int, default=64, help="Batch size (experience replay)")
    parser.add_argument('--consecutive_frames', type=int, default=4, help="Number of consecutive frames (action repeat)")
    parser.add_argument('--training_interval', type=int, default=30, help="Network training frequency")
    parser.add_argument('--n_threads', type=int, default=8, help="Number of threads (A3C)")
    #
    parser.add_argument('--gather_stats', dest='gather_stats', action='store_true', help="Compute Average reward per episode (slower)")
    parser.add_argument('--render', dest='render', action='store_true', help="Render environment while training")
    parser.add_argument('--env', type=str, default='BreakoutNoFrameskip-v4', help="OpenAI Gym Environment")
    parser.add_argument('--gpu', type=int, default=0, help='GPU ID')
    #
    parser.set_defaults(render=False)
    return parser.parse_args(args)
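
# Illustrative sketch (not part of the original script): parse_args() returns an
# argparse.Namespace whose fields drive main() below, e.g.:
#   args = parse_args(['--type', 'A2C', '--env', 'cell'])
#   args.type                 # 'A2C'
#   args.env                  # 'cell'
#   args.consecutive_frames   # 4 (default)
#   args.nb_episodes          # 5000 (default)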


def main(args=None):
    # Parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # Check if a GPU ID was set
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)  # environ values must be strings
    set_session(get_session())
    summary_writer = tf.summary.FileWriter(args.type + "/tensorboard_" + args.env)

    # Environment Initialization
    if args.is_atari:
        # Atari Environment Wrapper
        env = AtariEnvironment(args)
        state_dim = env.get_state_size()
        action_dim = env.get_action_size()
    elif args.type == "DDPG":
        # Continuous Environments Wrapper
        env = Environment(gym.make(args.env), args.consecutive_frames)
        env.reset()
        state_dim = env.get_state_size()
        action_space = gym.make(args.env).action_space
        action_dim = action_space.high.shape[0]
        act_range = action_space.high
    else:
        if args.env == 'cell':
            # Custom optical tweezers single-cell environment
            env = Environment(opticalTweezers(), args.consecutive_frames)
            # env = opticalTweezers(consecutive_frames=args.consecutive_frames)
            env.reset()
            state_dim = (6,)
            action_dim = 4  # note: the reshape code has to change for a 2D agent (should be 4)
        else:
            # Standard Environments
            env = Environment(gym.make(args.env), args.consecutive_frames)
            env.reset()
            state_dim = env.get_state_size()
            print(state_dim)
            action_dim = gym.make(args.env).action_space.n
            print(action_dim)

    # Pick algorithm to train
    if args.type == "DDQN":
        algo = DDQN(action_dim, state_dim, args)
    elif args.type == "A2C":
        algo = A2C(action_dim, state_dim, args.consecutive_frames)
    elif args.type == "A3C":
        algo = A3C(action_dim, state_dim, args.consecutive_frames, is_atari=args.is_atari)
    elif args.type == "DDPG":
        algo = DDPG(action_dim, state_dim, act_range, args.consecutive_frames)

    # Train
    stats = algo.train(env, args, summary_writer)

    # Export results to CSV
    if args.gather_stats:
        df = pd.DataFrame(np.array(stats))
        df.to_csv(args.type + "/logs.csv", header=['Episode', 'Mean', 'Stddev'], float_format='%10.5f')

    # Display agent
    old_state, time = env.reset(), 0
    # all_old_states = [old_state for i in range(args.consecutive_frames)]
    while True:
        env.render()
        a = algo.policy_action(old_state)
        old_state, r, done, _ = env.step(a)
        time += 1
        if done:
            old_state = env.reset()  # restart from a fresh state rather than acting on the stale terminal one


if __name__ == "__main__":
    main()
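
# Example invocations (illustrative; algorithms other than A2C require their
# imports at the top of this file to be uncommented first):
#   python main.py --type A2C --env cell --nb_episodes 5000 --consecutive_frames 4
#   python main.py --type A2C --env CartPole-v1 --render --gather_stats
#   python main.py --type DDQN --env BreakoutNoFrameskip-v4 --is_atari --dueling --with_PER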