Skip to content

Commit

Permalink
Patch #21 and #23 (#26)
Browse files Browse the repository at this point in the history
Small fixes for #21 and #23.

- Fix to how we select the `normalize_state` HPO parameter. We always
want to normalize the representations to `[-1, 1]` (since the metrics
can be arbitrarily large) for neural network stability (#21).
- Added an error check such that when we clip the gradient, it'll throw
if it's already bad for clearer error checking (#21).
- For a bit more "backwards" compatibility, added back the
`structure_normalize` option where we use the latent projection directly
and not rely on the observation based normalization.
- When a query times out, there might be additional DDL afterwards
(TPC-H Q15 is `[create view, select, drop view]`). We need to actually
execute the `drop view` cleanup if executing the select exceeds the
workload timeout (#23).

---------

Co-authored-by: Patrick Wang <wang.patrick57@gmail.com>
  • Loading branch information
17zhangw and wangpatrick57 committed Apr 4, 2024
1 parent bdd99a1 commit da7c5b0
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 22 deletions.
8 changes: 8 additions & 0 deletions tune/protox/agent/build_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,14 @@ def _build_obs_space(
return LSCStructureStateSpace(
lsc=lsc,
action_space=action_space,
normalize=False,
seed=seed,
)
elif hpoed_params["metric_state"] == "structure_normalize":
return LSCStructureStateSpace(
lsc=lsc,
action_space=action_space,
normalize=True,
seed=seed,
)
else:
Expand Down
8 changes: 3 additions & 5 deletions tune/protox/agent/hpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,11 @@ def build_space(
"reward_scaler": tune.choice([1, 2, 5, 10]),
"workload_timeout_penalty": tune.choice([1, 2, 4]),
# State.
"metric_state": tune.choice(["metric", "structure"]),
"metric_state": tune.choice(["metric", "structure", "structure_normalize"]),
"maximize_state": not benchmark_config.get("oltp_workload", False),
# Whether to normalize state or not.
"normalize_state": tune.sample_from(
lambda spc: False
if spc["config"]["metric_state"] == "structure_normalize"
else bool(np.random.choice([False, True]))
lambda spc: False if spc["config"]["metric_state"] == "structure_normalize" else True
),
# Whether to normalize reward or not.
"normalize_reward": tune.choice([False, True]),
Expand Down Expand Up @@ -567,4 +565,4 @@ def _tune_hpo(dbgym_cfg: DBGymConfig, hpo_args: AgentHPOArgs) -> None:
for i in range(len(results)):
if results[i].error:
print(f"Trial {results[i]} FAILED")
assert False, print("Encountered exceptions!")
assert False, print("Encountered exceptions!")
4 changes: 2 additions & 2 deletions tune/protox/agent/wolp/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def train_critic(
self.critic_optimizer.zero_grad()
assert not th.isnan(critic_loss).any()
critic_loss.backward() # type: ignore
th.nn.utils.clip_grad_norm_(list(self.critic.parameters()), self.grad_clip) # type: ignore
th.nn.utils.clip_grad_norm_(list(self.critic.parameters()), self.grad_clip, error_if_nonfinite=True) # type: ignore
self.critic.check_grad()
self.critic_optimizer.step()
return critic_loss
Expand Down Expand Up @@ -282,7 +282,7 @@ def train_actor(self, replay_data: ReplayBufferSamples) -> Any:
self.actor_optimizer.zero_grad()
assert not th.isnan(actor_loss).any()
actor_loss.backward() # type: ignore
th.nn.utils.clip_grad_norm_(list(self.actor.parameters()), self.grad_clip) # type: ignore
th.nn.utils.clip_grad_norm_(list(self.actor.parameters()), self.grad_clip, error_if_nonfinite=True) # type: ignore
self.actor.check_grad()
self.actor_optimizer.step()
return actor_loss
Expand Down
3 changes: 3 additions & 0 deletions tune/protox/env/space/latent_space/latent_index_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def critic_dim(self) -> int:

return index_dim

def uses_embed(self) -> bool:
return True

def transform_noise(
self, subproto: ProtoAction, noise: Optional[torch.Tensor] = None
) -> ProtoAction:
Expand Down
3 changes: 3 additions & 0 deletions tune/protox/env/space/latent_space/latent_knob_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def latent_dim(self) -> int:
def critic_dim(self) -> int:
return self.latent_dim()

def uses_embed(self) -> bool:
return False

def transform_noise(
self, subproto: ProtoAction, noise: Optional[torch.Tensor] = None
) -> ProtoAction:
Expand Down
3 changes: 3 additions & 0 deletions tune/protox/env/space/latent_space/latent_query_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def __init__(
self.logger = logger
self.name = "query"

def uses_embed(self) -> bool:
return False

def generate_state_container(
self,
prev_state: Optional[QuerySpaceContainer],
Expand Down
3 changes: 3 additions & 0 deletions tune/protox/env/space/latent_space/lsc_index_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def __init__(
assert lsc is not None
self.lsc = lsc

def uses_embed(self) -> bool:
return True

def pad_center_latent(
self, subproto: ProtoAction, lscs: torch.Tensor
) -> ProtoAction:
Expand Down
3 changes: 2 additions & 1 deletion tune/protox/env/space/state/lsc_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ def __init__(
self,
lsc: LSC,
action_space: HolonSpace,
normalize: bool,
seed: int,
) -> None:
spaces = {"lsc": Box(low=-1, high=1.0)}
super().__init__(action_space, spaces, seed)
super().__init__(action_space, spaces, normalize, seed)
self.lsc = lsc

def construct_offline(
Expand Down
41 changes: 28 additions & 13 deletions tune/protox/env/space/state/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,25 @@ def __init__(
self,
action_space: HolonSpace,
spaces: Mapping[str, spaces.Space[Any]],
normalize: bool,
seed: int,
) -> None:
self.action_space = action_space
self.normalize = normalize

if self.normalize:
self.internal_spaces: Dict[str, gym.spaces.Space[Any]] = {
k: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(s.critic_dim(),))
for k, s in action_space.get_spaces()
}
else:
self.internal_spaces = {
k: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(s.critic_dim(),))
if s.uses_embed() else s
for k, s in action_space.get_spaces()

}

self.internal_spaces: Dict[str, gym.spaces.Space[Any]] = {
k: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(s.critic_dim(),))
for k, s in action_space.get_spaces()
}
self.internal_spaces.update(spaces)
super().__init__(self.internal_spaces, seed)

Expand Down Expand Up @@ -63,21 +74,25 @@ def construct_offline(
assert isinstance(knobs, LatentKnobSpace)
assert check_subspace(knobs, knob_state)

knob_state = np.array(
knobs.to_latent([knob_state]),
dtype=np.float32,
)[0]

if self.normalize:
knob_state = np.array(
knobs.to_latent([knob_state]),
dtype=np.float32,
)[0]
assert self.internal_spaces["knobs"].contains(knob_state)

if ql_state is not None:
query = self.action_space.get_query_space()
assert isinstance(query, LatentQuerySpace)
assert check_subspace(query, ql_state)
query_state = np.array(
query.to_latent([ql_state]),
dtype=np.float32,
)[0]

if self.normalize:
query_state = np.array(
query.to_latent([ql_state]),
dtype=np.float32,
)[0]
else:
query_state = ql_state

# Handle indexes.
indexes_ = [v[1] for v in splits if isinstance(v[0], LatentIndexSpace)]
Expand Down
10 changes: 9 additions & 1 deletion tune/protox/env/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,9 +417,10 @@ def _execute_workload(
# Skip any query in blocklist.
continue

for sql_type, query in queries:
for qidx, (sql_type, query) in enumerate(queries):
assert sql_type != QueryType.UNKNOWN
if sql_type != QueryType.SELECT:
# This is a sanity check because any OLTP workload should be run through benchbase, and any OLAP workload should not have INS_UPD_DEL queries.
assert sql_type != QueryType.INS_UPD_DEL
pgconn.conn().execute(query)
continue
Expand Down Expand Up @@ -533,6 +534,13 @@ def _execute_workload(
time_left -= qid_runtime / 1e6
workload_time += qid_runtime / 1e6
if time_left < 0:
# We need to undo any potential statements after the timed out query.
for st, rq in queries[qidx+1:]:
if st != QueryType.SELECT:
# This is a sanity check because any OLTP workload should be run through benchbase, and any OLAP workload should not have INS_UPD_DEL queries. If we do have INS_UPD_DEL queries, our "undo" logic will likely have to change.
assert st != QueryType.INS_UPD_DEL
pgconn.conn().execute(rq)

stop_running = True
break

Expand Down

0 comments on commit da7c5b0

Please sign in to comment.