From 10cbfe100c6556b233ddd71fd64a5bdd6c4e00ef Mon Sep 17 00:00:00 2001 From: Patrick Wang Date: Tue, 3 Sep 2024 00:36:45 +0000 Subject: [PATCH] fixed other mypy bugs --- tune/protox/agent/build_trial.py | 5 +++-- tune/protox/env/space/latent_space/lsc_index_space.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tune/protox/agent/build_trial.py b/tune/protox/agent/build_trial.py index e46129d..7c84366 100644 --- a/tune/protox/agent/build_trial.py +++ b/tune/protox/agent/build_trial.py @@ -126,8 +126,9 @@ def _modify_benchbase_config( def _gen_noise_scale( vae_config: dict[str, Any], hpo_params: dict[str, Any] -) -> Callable[[ProtoAction, torch.Tensor], ProtoAction]: - def f(p: ProtoAction, n: torch.Tensor) -> ProtoAction: +) -> Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction]: + def f(p: ProtoAction, n: Optional[torch.Tensor]) -> ProtoAction: + assert n is not None if hpo_params["scale_noise_perturb"]: return ProtoAction( torch.clamp( diff --git a/tune/protox/env/space/latent_space/lsc_index_space.py b/tune/protox/env/space/latent_space/lsc_index_space.py index e142508..87290dc 100644 --- a/tune/protox/env/space/latent_space/lsc_index_space.py +++ b/tune/protox/env/space/latent_space/lsc_index_space.py @@ -35,7 +35,7 @@ def __init__( latent_dim: int = 0, index_output_transform: Optional[Callable[[ProtoAction], ProtoAction]] = None, index_noise_scale: Optional[ - Callable[[ProtoAction, torch.Tensor], ProtoAction] + Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction] ] = None, logger: Optional[Logger] = None, lsc: Optional[LSC] = None,