Commit a6e00b9

now logging during HPO for both baseline and tuning steps

wangpatrick57 committed May 30, 2024
1 parent 474d7ee commit a6e00b9
Showing 5 changed files with 23 additions and 13 deletions.
8 changes: 5 additions & 3 deletions tune/protox/agent/build_trial.py
@@ -5,7 +5,7 @@
 import socket
 import xml.etree.ElementTree as ET
 from pathlib import Path
-from typing import Any, Callable, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import gymnasium as gym
 import numpy as np
@@ -381,6 +381,7 @@ def _build_agent(
     observation_space: StateSpace,
     action_space: HolonSpace,
     logger: Logger,
+    ray_trial_id: Optional[str],
 ) -> Wolp:
     action_dim = noise_action_dim = action_space.latent_dim()
     critic_action_dim = action_space.critic_dim()
@@ -498,6 +499,7 @@ def _build_agent(
             obs_shape=[gym.spaces.utils.flatdim(observation_space)],
             action_dim=critic_action_dim,
         ),
+        ray_trial_id=ray_trial_id,
         learning_starts=hpo_params["learning_starts"],
         batch_size=hpo_params["batch_size"],
         train_freq=(hpo_params["train_freq_frequency"], hpo_params["train_freq_unit"]),
@@ -510,7 +512,7 @@


 def build_trial(
-    dbgym_cfg: DBGymConfig, tuning_mode: TuningMode, seed: int, hpo_params: dict[str, Any]
+    dbgym_cfg: DBGymConfig, tuning_mode: TuningMode, seed: int, hpo_params: dict[str, Any], ray_trial_id: Optional[str]=None
 ) -> Tuple[Logger, TargetResetWrapper, AgentEnv, Wolp, str]:
     # The massive trial builder.

@@ -533,5 +535,5 @@
         logger,
     )

-    agent = _build_agent(seed, hpo_params, observation_space, holon_space, logger)
+    agent = _build_agent(seed, hpo_params, observation_space, holon_space, logger, ray_trial_id)
     return logger, target_reset, env, agent, signal
12 changes: 5 additions & 7 deletions tune/protox/agent/hpo.py
@@ -10,7 +10,7 @@
 import os
 import pandas as pd
 from datetime import datetime
-from typing import Any, Union
+from typing import Any, Optional, Union
 import random
 import click
 import ssd_checker
@@ -435,7 +435,7 @@ def __call__(self) -> bool:


 class TuneTrial:
-    def __init__(self, dbgym_cfg: DBGymConfig, tuning_mode: TuningMode, ray_trial_id: str | None=None) -> None:
+    def __init__(self, dbgym_cfg: DBGymConfig, tuning_mode: TuningMode, ray_trial_id: Optional[str]=None) -> None:
         """
         We use this object for HPO, tune, and replay. It behaves *slightly* differently
         depending on what it's used for, which is why we have the tuning_mode param.
@@ -470,6 +470,7 @@ def setup(self, hpo_params: dict[str, Any]) -> None:
             self.tuning_mode,
             seed=seed,
             hpo_params=hpo_params,
+            ray_trial_id=self.ray_trial_id,
         )
         self.logger.get_logger(None).info("%s", hpo_params)
         self.logger.get_logger(None).info(f"Seed: {seed}")
@@ -504,11 +505,8 @@ def step(self) -> dict[Any, Any]:
            )
            self.env_init = True

-           # During HPO, we need to make sure different trials don't create folders that override each other.
-           if self.tuning_mode == TuningMode.HPO:
-               self.logger.stash_results(infos, name_override=f"baseline_{self.ray_trial_id}")
-           else:
-               self.logger.stash_results(infos, name_override="baseline")
+           assert self.ray_trial_id != None if self.tuning_mode == TuningMode.HPO else True, "If we're doing HPO, we need to ensure that we're passing a non-None ray_trial_id to stash_results() to avoid conflicting folder names."
+           self.logger.stash_results(infos, name_override="baseline", ray_trial_id=self.ray_trial_id)
        else:
            self.agent.learn(self.env, total_timesteps=1, tuning_mode=self.tuning_mode)

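The `assert X if cond else True` construction above encodes the implication "HPO mode implies a non-None trial id". A minimal equivalent spelling, shown here only for clarity (this code is not part of the commit):

if self.tuning_mode == TuningMode.HPO:
    assert self.ray_trial_id is not None, (
        "During HPO, a non-None ray_trial_id must be passed to "
        "stash_results() to avoid conflicting folder names."
    )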
7 changes: 5 additions & 2 deletions tune/protox/agent/off_policy_algorithm.py
@@ -47,10 +47,12 @@ def __init__(
         gradient_steps: int = 1,
         action_noise: Optional[ActionNoise] = None,
         seed: Optional[int] = None,
+        ray_trial_id: Optional[str] = None,
     ):
         super().__init__(seed=seed)
         self.policy = policy
         self.replay_buffer = replay_buffer
+        self.ray_trial_id = ray_trial_id

         self.batch_size = batch_size
         self.learning_starts = learning_starts
@@ -186,8 +188,9 @@ def collect_rollouts(
                dones = terms or truncs
                # We only stash the results if we're not doing HPO, or else the results from concurrent HPO would get
                # stashed in the same directory and potentially cause a race condition.
-               if self.logger and not tuning_mode == TuningMode.HPO:
-                   self.logger.stash_results(infos)
+               if self.logger:
+                   assert self.ray_trial_id != None if tuning_mode == TuningMode.HPO else True, "If we're doing HPO, we need to ensure that we're passing a non-None ray_trial_id to stash_results() to avoid conflicting folder names."
+                   self.logger.stash_results(infos, ray_trial_id=self.ray_trial_id)

                self.num_timesteps += 1
                num_collected_steps += 1
2 changes: 2 additions & 0 deletions tune/protox/agent/wolp/wolp.py
@@ -53,6 +53,7 @@ def __init__(
         target_action_noise: Optional[ActionNoise] = None,
         seed: Optional[int] = None,
         neighbor_parameters: Dict[str, Any] = {},
+        ray_trial_id: Optional[str] = None,
     ):
         super().__init__(
             policy,
@@ -63,6 +64,7 @@
             gradient_steps,
             action_noise=action_noise,
             seed=seed,
+            ray_trial_id=ray_trial_id,
         )

         self.target_action_noise = target_action_noise
7 changes: 6 additions & 1 deletion tune/protox/env/logger.py
@@ -92,12 +92,17 @@ def get_logger(self, name: Optional[str]) -> logging.Logger:
         return logging.getLogger(name)

     def stash_results(
-        self, info_dict: dict[str, Any], name_override: Optional[str] = None
+        self, info_dict: dict[str, Any], name_override: Optional[str] = None, ray_trial_id: Optional[str] = None,
     ) -> None:
         """
         Stash data about this step of tuning so that it can be replayed.
         """
         dname = name_override if name_override else datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+        if ray_trial_id != None:
+            # Orthogonal to whether name_override is used, ray_trial_id disambiguates between folders created
+            # by different HPO trials so that the folders don't overwrite each other.
+            dname += f"_{ray_trial_id}"
+
         if info_dict["results_dpath"] is not None and Path(info_dict["results_dpath"]).exists():
             local["mv"][info_dict["results_dpath"], f"{self.tuning_steps_dpath}/{dname}"].run()
         else:
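Taken together, the directory naming now works as sketched below. This is a standalone rewrite of the logic above for illustration only; the trial ids are made up:

from datetime import datetime
from typing import Optional

def make_dname(name_override: Optional[str] = None,
               ray_trial_id: Optional[str] = None) -> str:
    # Mirrors stash_results(): the override (or a timestamp) is the base name;
    # a Ray trial id, if present, is appended to disambiguate concurrent trials.
    dname = name_override if name_override else datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if ray_trial_id is not None:
        dname += f"_{ray_trial_id}"
    return dname

make_dname("baseline", "8a2b4_00000")  # HPO baseline step -> "baseline_8a2b4_00000"
make_dname(None, "8a2b4_00000")        # HPO tuning step   -> e.g. "2024-05-30_12-00-00_8a2b4_00000"
make_dname("baseline", None)           # non-HPO           -> "baseline"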
