diff --git a/acme/agents/jax/pwil/rewarder.py b/acme/agents/jax/pwil/rewarder.py index b94a41bdfc..e12a4edabc 100644 --- a/acme/agents/jax/pwil/rewarder.py +++ b/acme/agents/jax/pwil/rewarder.py @@ -63,8 +63,10 @@ def __init__(self, self._std = (self._std < 1e-6) + self._std self.expert_atoms = self._vectorized_demonstrations / self._std - self._compute_norm = jax.jit(lambda a, b: jnp.linalg.norm(a - b, axis=1), - device=jax.devices('cpu')[0]) + self._compute_norm = jax.jit( + lambda a, b: jnp.linalg.norm(a - b, axis=1), + device=jax.local_devices(backend='cpu')[0], + ) def _vectorize(self, demonstrations_it: Iterator[types.Transition]) -> np.ndarray: