From cbed59a6064e87afc5edaa357c1d944b7d39ebd6 Mon Sep 17 00:00:00 2001 From: ivyzheng Date: Fri, 25 Aug 2023 16:24:10 -0700 Subject: [PATCH] add method migration to Haiku guide --- docs/guides/haiku_migration_guide.rst | 46 +++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/docs/guides/haiku_migration_guide.rst b/docs/guides/haiku_migration_guide.rst index 4b90d33194..c0c6097a70 100644 --- a/docs/guides/haiku_migration_guide.rst +++ b/docs/guides/haiku_migration_guide.rst @@ -1,6 +1,6 @@ Migrating from Haiku to Flax -========== +============================ This guide will walk through the process of migrating Haiku models to Flax, and highlight the differences between the two libraries. @@ -865,4 +865,46 @@ as we are using ``5`` layers. kernel: (5, 64, 64), }, }, - }) \ No newline at end of file + }) + +Top-level Haiku functions vs top-level Flax modules +--------------------------------------------------- + +In Haiku, it is possible to write the entire model as a single function by using the raw ``hk.{get,set}_{parameter,state}`` to define/access model parameters and states. It is very common to write the top-level "Module" as a function instead: + +The Flax team recommends a more Module-centric approach that uses ``__call__`` to define the forward function. The corresponding accessors are ``nn.Module.param`` and ``nn.Module.variable`` (go to `Handling State <#handling-state>`__ for an explanation of collections). + +.. 
codediff:: + :title_left: Haiku + :title_right: Flax + :sync: + + def forward(x): + + + counter = hk.get_state('counter', shape=[], dtype=jnp.int32, init=jnp.ones) + multiplier = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones) + output = x + multiplier * counter + hk.set_state("counter", counter + 1) + + return output + + model = hk.transform_with_state(forward) + + params, state = model.init(PRNGKey(0), jax.numpy.ones((1, 64))) + + --- + + class FooModule(nn.Module): + @nn.compact + def __call__(self, x): + counter = self.variable('counter', 'count', lambda: jnp.ones((), jnp.int32)) + multiplier = self.param('multiplier', nn.initializers.ones_init(), [1,], x.dtype) + output = x + multiplier * counter.value + if not self.is_initializing(): # otherwise model.init() also increases it + counter.value += 1 + return output + + model = FooModule() + variables = model.init(PRNGKey(0), jax.numpy.ones((1, 64))) + params, counter = variables['params'], variables['counter'] \ No newline at end of file