From 43074eb1d8faad37a7eac9972a1ece3d8ce31bdd Mon Sep 17 00:00:00 2001 From: Daniel Duckworth Date: Wed, 30 Oct 2024 10:57:39 -0700 Subject: [PATCH] Add support for storing arbitrary PyTrees with `Module.perturb()` PiperOrigin-RevId: 691480632 --- flax/linen/module.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/flax/linen/module.py b/flax/linen/module.py index 9de568b874..f8a57b9546 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -2702,11 +2702,13 @@ def perturb( if not self.scope.has_variable(collection, name): self.scope.reserve(name, collection) self._state.children[name] = collection - self.scope.put_variable(collection, name, jnp.zeros_like(value)) # type: ignore + zeros = jax.tree.map(jnp.zeros_like, value) + self.scope.put_variable(collection, name, zeros) # type: ignore if collection in self.scope.root._variables: if self.scope.has_variable(collection, name): - value += self.scope.get_variable(collection, name) # type: ignore + old_value = self.scope.get_variable(collection, name) + value = jax.tree.map(jnp.add, value, old_value) # type: ignore else: raise ValueError(f"Perturbation collection {collection} present, but " f"missing perturbation variable {name}")