diff --git a/acme/agents/jax/mpo/config.py b/acme/agents/jax/mpo/config.py index 3e1d70ab31..596cf29b3b 100644 --- a/acme/agents/jax/mpo/config.py +++ b/acme/agents/jax/mpo/config.py @@ -32,8 +32,9 @@ class MPOConfig: discrete_policy: bool = False # Specification of the type of experience the learner will consume. - experience_type: mpo_types.ExperienceType = mpo_types.FromTransitions( - n_step=5) + experience_type: mpo_types.ExperienceType = dataclasses.field( + default_factory=lambda: mpo_types.FromTransitions(n_step=5) + ) num_stacked_observations: int = 1 # Optional data-augmentation transformation for observations. observation_transform: Optional[Callable[[types.NestedTensor],