diff --git a/tune/protox/agent/build_trial.py b/tune/protox/agent/build_trial.py index e46129df..7c84366f 100644 --- a/tune/protox/agent/build_trial.py +++ b/tune/protox/agent/build_trial.py @@ -126,8 +126,9 @@ def _modify_benchbase_config( def _gen_noise_scale( vae_config: dict[str, Any], hpo_params: dict[str, Any] -) -> Callable[[ProtoAction, torch.Tensor], ProtoAction]: - def f(p: ProtoAction, n: torch.Tensor) -> ProtoAction: +) -> Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction]: + def f(p: ProtoAction, n: Optional[torch.Tensor]) -> ProtoAction: + assert n is not None if hpo_params["scale_noise_perturb"]: return ProtoAction( torch.clamp( diff --git a/tune/protox/env/space/latent_space/lsc_index_space.py b/tune/protox/env/space/latent_space/lsc_index_space.py index e1425081..87290dcf 100644 --- a/tune/protox/env/space/latent_space/lsc_index_space.py +++ b/tune/protox/env/space/latent_space/lsc_index_space.py @@ -35,7 +35,7 @@ def __init__( latent_dim: int = 0, index_output_transform: Optional[Callable[[ProtoAction], ProtoAction]] = None, index_noise_scale: Optional[ - Callable[[ProtoAction, torch.Tensor], ProtoAction] + Callable[[ProtoAction, Optional[torch.Tensor]], ProtoAction] ] = None, logger: Optional[Logger] = None, lsc: Optional[LSC] = None,