From ac668d54af556b70391c23a7216e676d26375c67 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Tue, 26 Sep 2023 06:27:02 -0700
Subject: [PATCH] [LSC] Ignore incorrect type annotations related to jax.numpy
 APIs

PiperOrigin-RevId: 568520617
Change-Id: I372051fa57315aa4710e842a0fc4582c685a78c6
---
 acme/agents/jax/ail/losses.py          | 8 ++++----
 acme/agents/jax/ail/rewards.py         | 2 +-
 acme/agents/jax/mpo/categorical_mpo.py | 4 ++--
 acme/agents/jax/mpo/utils.py           | 5 +++--
 acme/agents/jax/r2d2/learning.py       | 5 +++--
 acme/agents/jax/rnd/learning.py        | 2 +-
 6 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/acme/agents/jax/ail/losses.py b/acme/agents/jax/ail/losses.py
index 83f88723ff..8d68dbd474 100644
--- a/acme/agents/jax/ail/losses.py
+++ b/acme/agents/jax/ail/losses.py
@@ -114,7 +114,7 @@ def loss_fn(
         'entropy_loss': entropy_loss,
         'classification_loss': classification_loss
     }
-    return total_loss, (metrics, discriminator_state)
+    return total_loss, (metrics, discriminator_state)  # pytype: disable=bad-return-type  # jnp-type

   return loss_fn

@@ -166,7 +166,7 @@ def loss_fn(
         'entropy_loss': entropy_loss,
         'classification_loss': classification_loss
     }
-    return total_loss, (metrics, discriminator_state)
+    return total_loss, (metrics, discriminator_state)  # pytype: disable=bad-return-type  # jnp-type

   return loss_fn

@@ -194,7 +194,7 @@ def _compute_gradient_penalty(gradient_penalty_data: types.Transition,
                                gradients.next_observation])
   gradient_norms = jnp.linalg.norm(gradients + 1e-8)
   k = gradient_penalty_target * jnp.ones_like(gradient_norms)
-  return jnp.mean(jnp.square(gradient_norms - k))
+  return jnp.mean(jnp.square(gradient_norms - k))  # pytype: disable=bad-return-type  # jnp-type


 def add_gradient_penalty(base_loss: Loss,
@@ -231,6 +231,6 @@ def apply_discriminator_fn(transitions: types.Transition) -> float:
     total_loss = partial_loss + gradient_penalty
     losses['total_loss'] = total_loss

-    return total_loss, (losses, discriminator_state)
+    return total_loss, (losses, discriminator_state)  # pytype: disable=bad-return-type  # jnp-type

   return loss_fn
diff --git a/acme/agents/jax/ail/rewards.py b/acme/agents/jax/ail/rewards.py
index ea737ad86d..2586734aab 100644
--- a/acme/agents/jax/ail/rewards.py
+++ b/acme/agents/jax/ail/rewards.py
@@ -71,6 +71,6 @@ def imitation_reward(logits: networks_lib.Logits) -> float:
     # pylint: disable=invalid-unary-operand-type
     rewards = jnp.clip(
         rewards, a_min=-max_reward_magnitude, a_max=max_reward_magnitude)
-    return rewards
+    return rewards  # pytype: disable=bad-return-type  # jnp-type

   return imitation_reward  # pytype: disable=bad-return-type  # jax-ndarray
diff --git a/acme/agents/jax/mpo/categorical_mpo.py b/acme/agents/jax/mpo/categorical_mpo.py
index 0faea932f7..8dfcd0742e 100644
--- a/acme/agents/jax/mpo/categorical_mpo.py
+++ b/acme/agents/jax/mpo/categorical_mpo.py
@@ -165,7 +165,7 @@ def __call__(
     loss = loss_policy + loss_kl + loss_dual

     # Create statistics.
-    stats = CategoricalMPOStats(
+    stats = CategoricalMPOStats(  # pytype: disable=wrong-arg-types  # jnp-type
         # Dual Variables.
         dual_alpha=jnp.mean(alpha),
         dual_temperature=jnp.mean(temperature),
@@ -183,7 +183,7 @@ def __call__(
         q_min=jnp.mean(jnp.min(q_values, axis=0)),
         q_max=jnp.mean(jnp.max(q_values, axis=0)),
         entropy_online=jnp.mean(online_action_distribution.entropy()),
-        entropy_target=jnp.mean(target_action_distribution.entropy())
+        entropy_target=jnp.mean(target_action_distribution.entropy()),
     )
     return loss, stats

diff --git a/acme/agents/jax/mpo/utils.py b/acme/agents/jax/mpo/utils.py
index 88fd73f065..ba4e60420d 100644
--- a/acme/agents/jax/mpo/utils.py
+++ b/acme/agents/jax/mpo/utils.py
@@ -99,10 +99,11 @@ def make_sequences_from_transitions(
                              transitions.next_observation)
   reward = duplicate(transitions.reward)

-  return adders.Step(
+  return adders.Step(  # pytype: disable=wrong-arg-types  # jnp-type
       observation=observation,
       action=duplicate(transitions.action),
       reward=reward,
       discount=duplicate(transitions.discount),
       start_of_episode=jnp.zeros_like(reward, dtype=jnp.bool_),
-      extras=jax.tree_map(duplicate, transitions.extras))
+      extras=jax.tree_map(duplicate, transitions.extras),
+  )
diff --git a/acme/agents/jax/r2d2/learning.py b/acme/agents/jax/r2d2/learning.py
index f974d76b01..1b6df2186e 100644
--- a/acme/agents/jax/r2d2/learning.py
+++ b/acme/agents/jax/r2d2/learning.py
@@ -228,12 +228,13 @@ def update_priorities(
     logging.info('Total number of params: %d',
                  sum(tree.flatten(sizes.values())))

-    state = TrainingState(
+    state = TrainingState(  # pytype: disable=wrong-arg-types  # jnp-type
         params=initial_params,
         target_params=initial_params,
         opt_state=opt_state,
         steps=jnp.array(0),
-        random_key=random_key)
+        random_key=random_key,
+    )

     # Replicate parameters.
     self._state = utils.replicate_in_all_devices(state)
diff --git a/acme/agents/jax/rnd/learning.py b/acme/agents/jax/rnd/learning.py
index 168c213289..dde00fff1e 100644
--- a/acme/agents/jax/rnd/learning.py
+++ b/acme/agents/jax/rnd/learning.py
@@ -107,7 +107,7 @@ def rnd_loss(
   predictor_output = networks.predictor.apply(predictor_params,
                                               transitions.observation,
                                               transitions.action)
-  return jnp.mean(jnp.square(target_output - predictor_output))
+  return jnp.mean(jnp.square(target_output - predictor_output))  # pytype: disable=bad-return-type  # jnp-type


 class RNDLearner(acme.Learner):
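
Note: the sketch below is illustrative only and is not part of the patch; the function name `mean_squared_error` is hypothetical. It shows the pattern this change relies on: a function annotated to return a plain Python `float` actually returns a 0-d `jax.Array` from `jnp.mean`, which pytype reports as `bad-return-type`, and the inline `# pytype: disable=...` comment suppresses that error on the single line it annotates (the trailing `# jnp-type` tag is just a marker used to group these suppressions).

import jax.numpy as jnp


def mean_squared_error(x: jnp.ndarray, y: jnp.ndarray) -> float:
  # jnp.mean returns a 0-d jax.Array rather than a Python float, so pytype
  # flags the annotated return type; the comment silences only this line.
  return jnp.mean(jnp.square(x - y))  # pytype: disable=bad-return-type  # jnp-type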