Skip to content

Commit

Permalink
Merge pull request #3277 from IvyZX:pen
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565205192
  • Loading branch information
Flax Authors committed Sep 14, 2023
2 parents 302a7fa + cbed59a commit d9f83bf
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions docs/guides/haiku_migration_guide.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

Migrating from Haiku to Flax
==========
============================

This guide will walk through the process of migrating Haiku models to Flax,
and highlight the differences between the two libraries.
Expand Down Expand Up @@ -865,4 +865,46 @@ as we are using ``5`` layers.
kernel: (5, 64, 64),
},
},
})
})
Top-level Haiku functions vs top-level Flax modules
-----------------------------------

In Haiku, it is possible to write the entire model as a single function by using the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and states. It very common to write the top-level "Module" as a function instead:

The Flax team recommends a more Module-centric approach that uses `__call__` to define the forward function. The corresponding accessor will be `nn.module.param` and `nn.module.variable` (go to `Handling State <#handling-state>`__ for an explanaion on collections).

.. codediff::
:title_left: Haiku
:title_right: Flax
:sync:

def forward(x):


counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones)
multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
output = x + multiplier * counter
hk.set_state("counter", counter + 1)

return output

model = hk.transform_with_state(forward)

params, state = model.init(PRNGKey(0), jax.numpy.ones((1, 64)))

---

class FooModule(nn.Module):
@nn.compact
def __call__(self, x):
counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32))
multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype)
output = x + multiplier * counter.value
if not self.is_initializing(): # otherwise model.init() also increases it
counter.value += 1
return output

model = FooModule()
variables = model.init(PRNGKey(0), jax.numpy.ones((1, 64)))
params, counter = variables['params'], variables['counter']

0 comments on commit d9f83bf

Please sign in to comment.