Skip to content

Commit

Permalink
renamed Logger -> ArtifactManager
Browse files Browse the repository at this point in the history
  • Loading branch information
wangpatrick57 committed Sep 6, 2024
1 parent a6fa72e commit 8cb673b
Show file tree
Hide file tree
Showing 18 changed files with 52 additions and 46 deletions.
8 changes: 4 additions & 4 deletions tune/protox/agent/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from misc.utils import TuningMode
from tune.protox.agent.agent_env import AgentEnv
from tune.protox.agent.noise import ActionNoise
from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager


class BaseAlgorithm(ABC):
Expand All @@ -28,10 +28,10 @@ def __init__(self, seed: Optional[int] = None):
# For logging (and TD3 delayed updates)
self._n_updates = 0 # type: int
# The logger object
self._logger: Optional[Logger] = None
self._logger: Optional[ArtifactManager] = None
self.timeout_checker = None

def set_logger(self, logger: Optional[Logger]) -> None:
def set_logger(self, logger: Optional[ArtifactManager]) -> None:
"""
Setter for for logger object.
Expand All @@ -40,7 +40,7 @@ def set_logger(self, logger: Optional[Logger]) -> None:
self._logger = logger

@property
def logger(self) -> Logger:
def logger(self) -> ArtifactManager:
"""Getter for the logger object."""
assert self._logger is not None
return self._logger
Expand Down
14 changes: 7 additions & 7 deletions tune/protox/agent/build_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
create_vae_model,
fetch_vae_parameters_from_workload,
)
from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.lsc.lsc import LSC
from tune.protox.env.lsc.lsc_wrapper import LSCWrapper
from tune.protox.env.mqo.mqo_wrapper import MQOWrapper
Expand Down Expand Up @@ -146,8 +146,8 @@ def _build_utilities(
tuning_mode: TuningMode,
pgport: int,
hpo_params: dict[str, Any],
) -> tuple[Logger, RewardUtility, PostgresConn, Workload]:
logger = Logger(
) -> tuple[ArtifactManager, RewardUtility, PostgresConn, Workload]:
logger = ArtifactManager(
dbgym_cfg,
hpo_params["trace"],
)
Expand Down Expand Up @@ -203,7 +203,7 @@ def _build_actions(
seed: int,
hpo_params: dict[str, Any],
workload: Workload,
logger: Logger,
logger: ArtifactManager,
) -> tuple[HolonSpace, LSC]:
sysknobs = LatentKnobSpace(
logger=logger,
Expand Down Expand Up @@ -336,7 +336,7 @@ def _build_env(
lsc: LSC,
workload: Workload,
reward_utility: RewardUtility,
logger: Logger,
logger: ArtifactManager,
) -> tuple[TargetResetWrapper, AgentEnv]:

env = gym.make(
Expand Down Expand Up @@ -404,7 +404,7 @@ def _build_agent(
hpo_params: dict[str, Any],
observation_space: StateSpace,
action_space: HolonSpace,
logger: Logger,
logger: ArtifactManager,
ray_trial_id: Optional[str],
) -> Wolp:
action_dim = noise_action_dim = action_space.latent_dim()
Expand Down Expand Up @@ -539,7 +539,7 @@ def build_trial(
seed: int,
hpo_params: dict[str, Any],
ray_trial_id: Optional[str] = None,
) -> tuple[Logger, TargetResetWrapper, AgentEnv, Wolp, str]:
) -> tuple[ArtifactManager, TargetResetWrapper, AgentEnv, Wolp, str]:
# The massive trial builder.

port, signal = _get_signal(hpo_params["pgconn_info"]["pgbin_path"])
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/agent/wolp/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tune.protox.agent.noise import ActionNoise
from tune.protox.agent.policies import Actor, BaseModel, ContinuousCritic
from tune.protox.agent.utils import polyak_update
from tune.protox.env.logger import Logger, time_record
from tune.protox.env.logger import ArtifactManager, time_record
from tune.protox.env.space.holon_space import HolonSpace
from tune.protox.env.types import (
DEFAULT_NEIGHBOR_PARAMETERS,
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
policy_l2_reg: float = 0.0,
tau: float = 0.005,
gamma: float = 0.99,
logger: Optional[Logger] = None,
logger: Optional[ArtifactManager] = None,
):
super().__init__(observation_space, action_space)
self.actor = actor
Expand Down
10 changes: 8 additions & 2 deletions tune/protox/env/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> T:
# If there is no logger, just return.
return ret

assert isinstance(first_arg.logger, Logger)
assert isinstance(first_arg.logger, ArtifactManager)
if first_arg.logger is not None:
cls_name = type(first_arg).__name__
first_arg.logger.record(f"{cls_name}_{key}", time.time() - start)
Expand All @@ -54,7 +54,13 @@ def default(self, obj: Any) -> Any:
return super(Encoder, self).default(obj)


class Logger(object):
class ArtifactManager(object):
"""
This class manages the following artifacts of Proto-X: info for replaying and TensorBoard output.
Importantly, this class should *not* be used for general-purpose logging. You should directly
use the logging library to do that.
"""
def __init__(
self,
dbgym_cfg: DBGymConfig,
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/lsc/lsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import torch

from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.types import ProtoAction

T = TypeVar("T", torch.Tensor, np.typing.NDArray[np.float32])
Expand All @@ -15,7 +15,7 @@ def __init__(
horizon: int,
lsc_parameters: dict[str, Any],
vae_config: dict[str, Any],
logger: Optional[Logger],
logger: Optional[ArtifactManager],
):
self.frozen = False
self.horizon = horizon
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/lsc/lsc_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import gymnasium as gym

from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.lsc.lsc import LSC
from tune.protox.env.target_reset.target_reset_wrapper import TargetResetWrapper


class LSCWrapper(gym.Wrapper[Any, Any, Any, Any]):
def __init__(self, lsc: LSC, env: gym.Env[Any, Any], logger: Optional[Logger]):
def __init__(self, lsc: LSC, env: gym.Env[Any, Any], logger: Optional[ArtifactManager]):
assert not isinstance(env, TargetResetWrapper)
super().__init__(env)
self.lsc = lsc
Expand Down
6 changes: 3 additions & 3 deletions tune/protox/env/mqo/mqo_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import torch

from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.pg_env import PostgresEnv
from tune.protox.env.space.holon_space import HolonSpace
from tune.protox.env.space.primitive import SettingType, is_binary_enum, is_knob_enum
Expand Down Expand Up @@ -89,7 +89,7 @@ def _regress_query_knobs(
qknobs: QuerySpaceKnobAction,
sysknobs: Union[KnobSpaceAction, KnobSpaceContainer],
ams: QueryTableAccessMap,
logger: Optional[Logger] = None,
logger: Optional[ArtifactManager] = None,
) -> QuerySpaceKnobAction:
global_qknobs = {}
for knob, _ in qknobs.items():
Expand Down Expand Up @@ -138,7 +138,7 @@ def __init__(
query_timeout: int,
benchbase_config: dict[str, Any],
env: gym.Env[Any, Any],
logger: Optional[Logger],
logger: Optional[ArtifactManager],
):
assert isinstance(env, PostgresEnv) or isinstance(
env.unwrapped, PostgresEnv
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/pg_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from plumbum import local

from misc.utils import DBGymConfig, TuningMode
from tune.protox.env.logger import Logger, time_record
from tune.protox.env.logger import ArtifactManager, time_record
from tune.protox.env.space.holon_space import HolonSpace
from tune.protox.env.space.state.space import StateSpace
from tune.protox.env.space.utils import fetch_server_indexes, fetch_server_knobs
Expand Down Expand Up @@ -38,7 +38,7 @@ def __init__(
pg_conn: PostgresConn,
query_timeout: int,
benchbase_config: dict[str, Any],
logger: Optional[Logger] = None,
logger: Optional[ArtifactManager] = None,
):
super().__init__()

Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/space/holon_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gymnasium import spaces
from psycopg import Connection

from tune.protox.env.logger import Logger, time_record
from tune.protox.env.logger import ArtifactManager, time_record
from tune.protox.env.space.latent_space import (
LatentIndexSpace,
LatentKnobSpace,
Expand Down Expand Up @@ -62,7 +62,7 @@ def __init__(
index_space: LatentIndexSpace,
query_space: LatentQuerySpace,
seed: int,
logger: Optional[Logger],
logger: Optional[ArtifactManager],
):
spaces: Iterable[gym.spaces.Space[Any]] = [knob_space, index_space, query_space]
super().__init__(spaces, seed=seed)
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/space/latent_space/latent_index_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy.typing import NDArray

from tune.protox.embedding.vae import VAE
from tune.protox.env.logger import Logger, time_record
from tune.protox.env.logger import ArtifactManager, time_record
from tune.protox.env.space.primitive.index import IndexAction
from tune.protox.env.space.primitive_space import IndexSpace
from tune.protox.env.space.utils import check_subspace, fetch_server_indexes
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(
index_noise_scale: Optional[
Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction]
] = None,
logger: Optional[Logger] = None,
logger: Optional[ArtifactManager] = None,
) -> None:

super().__init__(
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/space/latent_space/latent_knob_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from psycopg import Connection

from tune.protox.env.logger import Logger, time_record
from tune.protox.env.logger import ArtifactManager, time_record
from tune.protox.env.space.primitive import KnobClass, SettingType, is_knob_enum
from tune.protox.env.space.primitive.knob import resolve_enum_value
from tune.protox.env.space.primitive.latent_knob import (
Expand All @@ -28,7 +28,7 @@

class LatentKnobSpace(KnobSpace):
def __init__(
self, logger: Optional[Logger] = None, *args: Any, **kwargs: Any
self, logger: Optional[ArtifactManager] = None, *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.final_dim = gym.spaces.utils.flatdim(self)
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/space/latent_space/latent_query_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import psycopg

from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.space.latent_space.latent_knob_space import LatentKnobSpace
from tune.protox.env.space.primitive_space import QuerySpace
from tune.protox.env.types import (
Expand All @@ -15,7 +15,7 @@

class LatentQuerySpace(LatentKnobSpace, QuerySpace):
def __init__(
self, logger: Optional[Logger] = None, *args: Any, **kwargs: Any
self, logger: Optional[ArtifactManager] = None, *args: Any, **kwargs: Any
) -> None:
# Only manually initialize against QuerySpace.
QuerySpace.__init__(self, *args, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/space/latent_space/lsc_index_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch

from tune.protox.embedding.vae import VAE
from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.lsc.lsc import LSC
from tune.protox.env.space.latent_space.latent_index_space import LatentIndexSpace
from tune.protox.env.space.primitive.index import IndexAction
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
index_noise_scale: Optional[
Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction]
] = None,
logger: Optional[Logger] = None,
logger: Optional[ArtifactManager] = None,
lsc: Optional[LSC] = None,
) -> None:

Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/target_reset/target_reset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import gymnasium as gym

from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.pg_env import PostgresEnv
from tune.protox.env.types import EnvInfoDict, HolonStateContainer, TargetResetConfig
from tune.protox.env.util.reward import RewardUtility
Expand All @@ -16,7 +16,7 @@ def __init__(
maximize_state: bool,
reward_utility: RewardUtility,
start_reset: bool,
logger: Optional[Logger],
logger: Optional[ArtifactManager],
):
super().__init__(env)
self.maximize_state = maximize_state
Expand Down
8 changes: 4 additions & 4 deletions tune/protox/env/util/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from psycopg import Connection
from psycopg.errors import QueryCanceled

from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager
from tune.protox.env.space.primitive.knob import CategoricalKnob, Knob
from tune.protox.env.space.state.space import StateSpace
from tune.protox.env.types import (
Expand All @@ -31,7 +31,7 @@ def _force_statement_timeout(


def _time_query(
logger: Optional[Logger],
logger: Optional[ArtifactManager],
prefix: str,
connection: psycopg.Connection[Any],
query: str,
Expand Down Expand Up @@ -71,7 +71,7 @@ def _time_query(


def _acquire_metrics_around_query(
logger: Optional[Logger],
logger: Optional[ArtifactManager],
prefix: str,
connection: psycopg.Connection[Any],
query: str,
Expand Down Expand Up @@ -110,7 +110,7 @@ def execute_variations(
runs: list[QueryRun],
query: str,
query_timeout: float = 0,
logger: Optional[Logger] = None,
logger: Optional[ArtifactManager] = None,
sysknobs: Optional[KnobSpaceAction] = None,
observation_space: Optional[StateSpace] = None,
) -> BestQueryRun:
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/util/pg_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from psycopg.errors import ProgramLimitExceeded, QueryCanceled

from misc.utils import DBGymConfig, link_result, open_and_save, parent_dpath_of_path
from tune.protox.env.logger import Logger, time_record
from tune.protox.env.logger import ArtifactManager, time_record
from util.pg import (
DBGYM_POSTGRES_DBNAME,
DBGYM_POSTGRES_PASS,
Expand All @@ -40,7 +40,7 @@ def __init__(
connect_timeout: int,
enable_boot: bool,
boot_config_fpath: Path,
logger: Logger,
logger: ArtifactManager,
) -> None:

self.dbgym_cfg = dbgym_cfg
Expand Down
4 changes: 2 additions & 2 deletions tune/protox/env/util/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

import pandas as pd

from tune.protox.env.logger import Logger
from tune.protox.env.logger import ArtifactManager

# Initial penalty to apply to create the "worst" perf from the baseline.
INITIAL_PENALTY_MULTIPLIER = 4.0


class RewardUtility(object):
def __init__(
self, target: str, metric: str, reward_scaler: float, logger: Logger
self, target: str, metric: str, reward_scaler: float, logger: ArtifactManager
) -> None:
self.reward_scaler = reward_scaler
self.target = target
Expand Down
Loading

0 comments on commit 8cb673b

Please sign in to comment.