diff --git a/acme/agents/jax/dqn/config.py b/acme/agents/jax/dqn/config.py
index 3200f02367..2ec19f56c8 100644
--- a/acme/agents/jax/dqn/config.py
+++ b/acme/agents/jax/dqn/config.py
@@ -79,8 +79,9 @@ class DQNConfig:
   num_sgd_steps_per_step: int = 1
 
 
-def logspace_epsilons(num_epsilons: int, epsilon: float = 0.017
-                      ) -> Sequence[float]:
+def logspace_epsilons(
+    num_epsilons: int, epsilon: float = 0.017
+) -> Union[Sequence[float], jnp.ndarray]:
   """`num_epsilons` of logspace-distributed values, with median `epsilon`."""
   if num_epsilons <= 1:
     return (epsilon,)
diff --git a/acme/agents/jax/r2d2/learning.py b/acme/agents/jax/r2d2/learning.py
index 8ac7b475ab..f974d76b01 100644
--- a/acme/agents/jax/r2d2/learning.py
+++ b/acme/agents/jax/r2d2/learning.py
@@ -78,8 +78,8 @@ def loss(
         params: networks_lib.Params,
         target_params: networks_lib.Params,
         key_grad: networks_lib.PRNGKey,
-        sample: reverb.ReplaySample
-    ) -> Tuple[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]:
+        sample: reverb.ReplaySample,
+    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
       """Computes mean transformed N-step loss for a batch of sequences."""
 
       # Get core state & warm it up on observations for a burn-in period.
diff --git a/acme/jax/losses/mpo.py b/acme/jax/losses/mpo.py
index 6902c00cf2..d2958f75a6 100644
--- a/acme/jax/losses/mpo.py
+++ b/acme/jax/losses/mpo.py
@@ -50,24 +50,24 @@ class MPOParams(NamedTuple):
 class MPOStats(NamedTuple):
   """NamedTuple to store loss statistics."""
 
-  dual_alpha_mean: float
-  dual_alpha_stddev: float
-  dual_temperature: float
+  dual_alpha_mean: Union[float, jnp.ndarray]
+  dual_alpha_stddev: Union[float, jnp.ndarray]
+  dual_temperature: Union[float, jnp.ndarray]
 
-  loss_policy: float
-  loss_alpha: float
-  loss_temperature: float
-  kl_q_rel: float
+  loss_policy: Union[float, jnp.ndarray]
+  loss_alpha: Union[float, jnp.ndarray]
+  loss_temperature: Union[float, jnp.ndarray]
+  kl_q_rel: Union[float, jnp.ndarray]
 
-  kl_mean_rel: float
-  kl_stddev_rel: float
+  kl_mean_rel: Union[float, jnp.ndarray]
+  kl_stddev_rel: Union[float, jnp.ndarray]
 
-  q_min: float
-  q_max: float
+  q_min: Union[float, jnp.ndarray]
+  q_max: Union[float, jnp.ndarray]
 
-  pi_stddev_min: float
-  pi_stddev_max: float
-  pi_stddev_cond: float
+  pi_stddev_min: Union[float, jnp.ndarray]
+  pi_stddev_max: Union[float, jnp.ndarray]
+  pi_stddev_cond: Union[float, jnp.ndarray]
 
   penalty_kl_q_rel: Optional[float] = None
 
diff --git a/acme/jax/networks/atari.py b/acme/jax/networks/atari.py
index 37ca991879..7df76fbf3f 100644
--- a/acme/jax/networks/atari.py
+++ b/acme/jax/networks/atari.py
@@ -22,7 +22,7 @@
 - X?: X is optional (e.g. optional batch/sequence dimension).
 """
 
-from typing import Optional, Tuple, Sequence
+from typing import Any, Optional, Sequence, Tuple
 
 from acme.jax.networks import base
 from acme.jax.networks import duelling
@@ -120,9 +120,9 @@ def __init__(self, num_actions: int):
     self._head = policy_value.PolicyValueHead(num_actions)
     self._num_actions = num_actions
 
-  def __call__(self, inputs: observation_action_reward.OAR,
-               state: hk.LSTMState) -> base.LSTMOutputs:
-
+  def __call__(
+      self, inputs: observation_action_reward.OAR, state: hk.LSTMState
+  ) -> Any:
     embeddings = self._embed(inputs)  # [B?, D+A+1]
     embeddings, new_state = self._core(embeddings, state)
     logits, value = self._head(embeddings)  # logits: [B?, A], value: [B?, 1]
@@ -133,8 +133,9 @@ def initial_state(self, batch_size: Optional[int],
                     **unused_kwargs) -> hk.LSTMState:
     return self._core.initial_state(batch_size)
 
-  def unroll(self, inputs: observation_action_reward.OAR,
-             state: hk.LSTMState) -> base.LSTMOutputs:
+  def unroll(
+      self, inputs: observation_action_reward.OAR, state: hk.LSTMState
+  ) -> Any:
     """Efficient unroll that applies embeddings, MLP, & convnet in one pass."""
     embeddings = self._embed(inputs)
     embeddings, new_states = hk.static_unroll(self._core, embeddings, state)