discrete_main.py

import os
import random
import shutil
import string
from glob import glob

import gymnasium as gym
import neptune
from dotenv import load_dotenv
from gymnasium.wrappers.record_video import RecordVideo
from gymnasium.wrappers.time_limit import TimeLimit
from neptune.integrations.sacred import NeptuneObserver
from sacred import Experiment
from tqdm import tqdm

from discrete_agent import DISCRETE_AGENTS
from discrete_agent.discrete_agent import DiscreteAgent
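
# Load Neptune credentials (NEPTUNE_PROJECT, NEPTUNE_API_TOKEN) from a local .env file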
load_dotenv()

ex = Experiment()

# Add Neptune observer for storing logs
run = neptune.init_run(
    project=os.environ["NEPTUNE_PROJECT"],
    api_token=os.environ["NEPTUNE_API_TOKEN"],
    source_files=["**/*.py"],
)
ex.observers.append(NeptuneObserver(run))


@ex.config
def config():
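    # Default hyperparameters; Sacred lets any of these be overridden from the
    # command line, e.g. `python discrete_main.py with agent_id=<other-agent>`.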
    env_name = "CliffWalking-v0"
    env_config = {}
    agent_id = "ctdl"
    train_steps = 100_000
    gamma = 0.85
    record_video_every = 100


@ex.main
def main(
    env_name: str,
    env_config: dict,
    agent_id: str,
    train_steps: int,
    gamma: float,
    record_video_every: int,
):
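    # Scratch directory for this run's videos and weights; its contents are
    # uploaded to Neptune at the end of the run, then the directory is deleted.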
    NAME = get_random_id()
    os.mkdir(NAME)
    print("Run ID:", NAME)
    print("DO NOT DELETE THIS DIRECTORY!")

    env = gym.make(env_name, render_mode="rgb_array", **env_config)
    env = TimeLimit(env, max_episode_steps=500)
    env = RecordVideo(
        env,
        f"{NAME}/videos",
        disable_logger=True,
        episode_trigger=lambda t: t % record_video_every == 0,
    )
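
    # Instantiate the requested agent; agents are configured with the sizes of
    # the state and action spaces when those spaces expose `.n`.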
    agent = DISCRETE_AGENTS[agent_id](env.observation_space, env.action_space)
    config = {"gamma": gamma}
    if hasattr(env.observation_space, "n"):
        config["n_states"] = env.observation_space.n
    if hasattr(env.action_space, "n"):
        config["n_actions"] = env.action_space.n
    agent.setup(config)
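
    # RecordVideo encodes clips at the env's advertised `render_fps`; set it
    # explicitly so slow-rendering toy-text envs still yield watchable videos.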
    env.metadata["render_fps"] = 30
    train_discrete_agent(env, agent, gamma, train_steps, run)
    env.close()

    # Upload videos, sorting on the numeric episode suffix in the filename: a
    # plain lexicographic sort misorders them (episode-1000 before episode-200).
    videos = sorted(glob(f"{NAME}/videos/*.mp4"),
                    key=lambda p: int(p.rsplit("-", 1)[-1].split(".")[0]))
    for i, video in enumerate(videos):
        run[f"video/episode-{i * record_video_every}.mp4"].upload(video, wait=True)

    # Save weights
    os.mkdir(f"{NAME}/weights")
    saved = agent.save(f"{NAME}/weights")
    if saved:
        shutil.make_archive(f"{NAME}/weights", "zip", f"{NAME}/weights")
        run["weights"].upload(f"{NAME}/weights.zip", wait=True)

    # Clean up
    shutil.rmtree(NAME)


def train_discrete_agent(
    env: gym.Env,
    agent: DiscreteAgent,
    gamma: float,
    total_steps: int,
    run: neptune.Run,
):
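    # Run `total_steps` environment steps, updating the agent after each
    # transition and logging per-episode returns to Neptune.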
    state, _ = env.reset()
    ep_timesteps = 0
    ep_return = 0
    ep_discounted_return = 0
    for step in tqdm(range(total_steps)):
        action = agent.act(state, train=True)
        next_state, reward, terminated, truncated, _ = env.step(action)
        agent.update_policy(state, action, reward, next_state, terminated)
        state = next_state

        # Update statistics
        ep_return += reward
        ep_discounted_return += (gamma ** ep_timesteps) * reward
        ep_timesteps += 1

        # Reset environment
        if terminated or truncated:
            run["train/discounted_return"].append(step=step, value=ep_discounted_return)
            run["train/undiscounted_return"].append(step=step, value=ep_return)
            run["train/episode_timesteps"].append(step=step, value=ep_timesteps)
            ep_return = 0
            ep_discounted_return = 0
            ep_timesteps = 0
            state, _ = env.reset()

    agent.log(run)


def get_random_id():
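    # 8-character alphanumeric ID used to name the run's scratch directory.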
    return "".join(random.choices(string.ascii_uppercase + string.digits, k=8))


if __name__ == "__main__":
    ex.run_commandline()