train_dqn.py
import tensorflow as tf
from tf_agents.utils.common import function, Checkpointer

from src.ai.dqn import (
    CarRacingEnv,
    compute_avg_return,
    get_replay_buffer,
    collect_step
)
from src.ai import get_ann, get_agent

if __name__ == "__main__":
    batch_size = 1
    # env = CarRacingEnv.tf_batched_environment(batch_size)
    env = CarRacingEnv.tf_environment()

    # Build the Q-network and wrap it in a DQN agent.
    model = get_ann(5, 9)
    agent = get_agent(model, env.time_step_spec(), env.action_spec())

    num_iterations = 10_000
    collect_steps_per_iteration = 12

    # Seed the replay buffer with a few steps from the current policy.
    replay_buffer = get_replay_buffer(agent, batch_size=env.batch_size)
    for _ in range(10):
        collect_step(env, agent.policy, replay_buffer)

    # Checkpoint the agent, its policy, and the global training step.
    checkpointer = Checkpointer(
        ckpt_dir='pwr_shaped',
        max_to_keep=50,
        agent=agent,
        policy=agent.policy,
        global_step=agent.train_step_counter
    )
    checkpointer.initialize_or_restore()

    # Read training batches of adjacent step pairs from the replay buffer.
    dataset = replay_buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=64,
        num_steps=2,
        single_deterministic_pass=False
    ).prefetch(3)
    iterator = iter(dataset)

    env.reset()
    agent.train = function(agent.train)  # wrap in tf.function for graph-mode speed

    for _ in range(5):
        with tf.device('/GPU:0'):
            for _ in range(num_iterations):
                # Collect new experience with the current policy.
                for __ in range(collect_steps_per_iteration):
                    collect_step(env, agent.policy, replay_buffer)

                # Sample a batch from the buffer and take one training step.
                experience, unused_info = next(iterator)
                train_loss = agent.train(experience).loss

                step = agent.train_step_counter.numpy()
                checkpointer.save(step)
                if step % 10 == 0:
                    print(f"Step = {step}, Loss = {train_loss}, "
                          f"Average Return = {compute_avg_return(env, agent.policy)}")