Skip to content

Commit

Permalink
Add support for storing arbitrary PyTrees with Module.perturb()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691480632
  • Loading branch information
duckworthd authored and Flax Authors committed Oct 30, 2024
1 parent e4dad9c commit 43074eb
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2702,11 +2702,13 @@ def perturb(
if not self.scope.has_variable(collection, name):
self.scope.reserve(name, collection)
self._state.children[name] = collection
self.scope.put_variable(collection, name, jnp.zeros_like(value)) # type: ignore
zeros = jax.tree.map(jnp.zeros_like, value)
self.scope.put_variable(collection, name, zeros) # type: ignore

if collection in self.scope.root._variables:
if self.scope.has_variable(collection, name):
value += self.scope.get_variable(collection, name) # type: ignore
old_value = self.scope.get_variable(collection, name)
value = jax.tree.map(jnp.add, value, old_value) # type: ignore
else:
raise ValueError(f"Perturbation collection {collection} present, but "
f"missing perturbation variable {name}")
Expand Down

0 comments on commit 43074eb

Please sign in to comment.