Skip to content

Commit

Permalink
Merge pull request #3292 from chiamp:haiku_guide_fix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560218315
  • Loading branch information
Flax Authors committed Aug 25, 2023
2 parents 4879b4c + 031bf3e commit ffb68c3
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions docs/guides/haiku_migration_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ state. As before, in Flax you construct the Module directly.
To initialize both the parameters and state you just call the ``init`` method
as before. However, in Haiku you now get ``state`` as a second return value, and
in Flax you get a new ``batch_stats`` collection in the ``variables`` dictionary.
Note that since ``hk.BatchNorm`` only initializes batch statistics when
``is_training=True``, we must set ``training=True`` when initializing parameters
of a Haiku model with an ``hk.BatchNorm`` layer. In Flax, we can set
``training=False`` as usual.

.. codediff::
:title_left: Haiku
Expand All @@ -307,15 +311,15 @@ in Flax you get a new ``batch_stats`` collection in the ``variables`` dictionary
sample_x = jax.numpy.ones((1, 784))
params, state = model.init(
PRNGKey(0),
sample_x, training=True # <== inputs
sample_x, training=True # <== inputs #!
)
...

---

sample_x = jax.numpy.ones((1, 784))
variables = model.init(
PRNGKey(0),
PRNGKey(0), #!
sample_x, training=False # <== inputs
)
params, batch_stats = variables["params"], variables["batch_stats"]
Expand Down

0 comments on commit ffb68c3

Please sign in to comment.