Commit

Merge pull request #3158 from google:haiku-migration-guide-transforms
PiperOrigin-RevId: 542999636
Flax Authors committed Jun 24, 2023
2 parents f447e63 + d5a25b7 commit 115b8a5
Showing 1 changed file with 141 additions and 0 deletions.
docs/guides/haiku_migration_guide.rst
@@ -571,3 +571,144 @@

explicitly passed, even though the module does not use any stochastic
operations during ``apply``. In Flax this is not necessary. The Haiku ``rng``
is set to ``None`` here, but you could also use ``hk.without_apply_rng`` on the
``apply`` function to remove the ``rng`` argument.


Lifted Transforms
-----------------

Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms,
that wrap JAX transformations in such a way that they can be used with Modules and sometimes
provide additional functionality. In this section we will take a look at how to use the
lifted version of ``scan`` in both Flax and Haiku to implement a simple RNN layer.

To begin, we will first define an ``RNNCell`` module that contains the logic for a single
step of the RNN. We will also define an ``initial_state`` method that is used to initialize
the state (a.k.a. ``carry``) of the RNN. As with ``jax.lax.scan``, the ``RNNCell.__call__``
method is a function that takes the carry and input, and returns the new
carry and output. In this case, the carry and the output are the same.
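
This ``(carry, x) -> (new_carry, y)`` contract is exactly the one ``jax.lax.scan`` expects.
As a minimal standalone sketch (a running sum, unrelated to the modules below):

.. code-block:: python

  import jax
  import jax.numpy as jnp

  def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # carry and output are the same, as in RNNCell

  final_carry, ys = jax.lax.scan(step, jnp.float32(0.0), jnp.arange(4.0))
  # final_carry == 6.0, ys == [0., 1., 3., 6.]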

.. codediff::
:title_left: Haiku
:title_right: Flax
:sync:

class RNNCell(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size

def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = hk.Linear(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x

def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))

---

class RNNCell(nn.Module):
hidden_size: int


@nn.compact
def __call__(self, carry, x):
x = jnp.concatenate([carry, x], axis=-1)
x = nn.Dense(self.hidden_size)(x)
x = jax.nn.relu(x)
return x, x

def initial_state(self, batch_size: int):
return jnp.zeros((batch_size, self.hidden_size))
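
As a quick sanity check (a sketch, not part of the guide; the batch, feature, and hidden
sizes are just illustrative), the Flax ``RNNCell`` above can be initialized and applied on
its own for a single step:

.. code-block:: python

  import jax.numpy as jnp
  from jax.random import PRNGKey

  cell = RNNCell(hidden_size=16)
  carry = jnp.zeros((3, 16))   # what ``initial_state(batch_size=3)`` returns
  x = jnp.ones((3, 32))        # one time step with feature size 32

  variables = cell.init(PRNGKey(0), carry, x)
  new_carry, y = cell.apply(variables, carry, x)
  assert new_carry.shape == y.shape == (3, 16)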

Next, we will define an ``RNN`` Module that contains the logic for the entire RNN.
In Haiku, we first initialize the ``RNNCell``, then use it to construct the ``carry``,
and finally use ``hk.scan`` to run the ``RNNCell`` over the input sequence. In Flax this is
done a bit differently: we use ``nn.scan`` to define a new temporary type that wraps
``RNNCell``. During this process we instruct ``nn.scan`` to broadcast
the ``params`` collection (all steps share the same parameters) and to not split the
``params`` rng stream (so all steps initialize with the same parameters), and finally
we specify that we want scan to run over the second axis of the input and stack
the outputs along the second axis as well. We then use this temporary type immediately
to create an instance of the lifted ``RNNCell``, use it to create the ``carry``, and
run the ``__call__`` method, which will ``scan`` over the sequence.

.. codediff::
:title_left: Haiku
:title_right: Flax
:sync:

class RNN(hk.Module):
def __init__(self, hidden_size: int, name=None):
super().__init__(name=name)
self.hidden_size = hidden_size

def __call__(self, x):
cell = RNNCell(self.hidden_size)
carry = cell.initial_state(x.shape[0])
carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
return jnp.swapaxes(y, 0, 1)

---

class RNN(nn.Module):
hidden_size: int

@nn.compact
def __call__(self, x):
rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
in_axes=1, out_axes=1)(self.hidden_size)
carry = rnn.initial_state(x.shape[0])
carry, y = rnn(carry, x)
return y
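
As a side note on style (a sketch, not from the guide): ``nn.scan`` returns a new
``Module`` subclass, so the lifted cell can also be defined once at module scope and
reused, instead of being constructed inline on every call:

.. code-block:: python

  # Hypothetical name; behaves the same as the inline ``nn.scan(...)`` above.
  ScanRNNCell = nn.scan(
      RNNCell,
      variable_broadcast='params',   # share the same parameters across steps
      split_rngs={'params': False},  # use one 'params' rng for all steps
      in_axes=1,                     # scan over the second axis of the input
      out_axes=1,                    # stack outputs along the second axis
  )

  # Inside ``RNN.__call__`` this then becomes simply:
  #     rnn = ScanRNNCell(self.hidden_size)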

In general, the main difference between lifted transforms in Flax and Haiku is that
in Haiku the lifted transforms don't operate over the state; that is, Haiku handles the
``params`` and ``state`` in such a way that they keep the same shape inside and outside of the
transform. In Flax, lifted transforms can operate over both variable collections and rng
streams, and the user must define how each collection is treated by the transform
according to the transform's semantics.
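
To see this concretely, here is a small sketch (using the Flax ``RNN`` defined above and a
hypothetical ``(batch, time, features)`` input) that inspects the parameter shapes: because
the ``params`` collection is broadcast, the ``Dense`` kernel keeps the same shape it would
have outside the ``scan``, with no leading time axis.

.. code-block:: python

  import jax
  import jax.numpy as jnp
  from jax.random import PRNGKey

  model = RNN(64)
  variables = model.init(PRNGKey(0), jnp.ones((3, 12, 32)))
  print(jax.tree_util.tree_map(jnp.shape, variables['params']))
  # The single Dense kernel has shape (32 + 64, 64) = (96, 64) and the bias (64,);
  # there is no axis of length 12.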

Finally, let's take a quick look at how the ``RNN`` Module would be used in both Haiku and Flax.

.. codediff::
:title_left: Haiku
:title_right: Flax
:sync:

def forward(x):
return RNN(64)(x)

model = hk.without_apply_rng(hk.transform(forward))

params = model.init(
PRNGKey(0),
x=jax.numpy.ones((3, 12, 32)),
)

y = model.apply(
params,
x=jax.numpy.ones((3, 12, 32)),
)

---

...


model = RNN(64)

variables = model.init(
PRNGKey(0),
x=jax.numpy.ones((3, 12, 32)),
)
params = variables['params']
y = model.apply(
{'params': params},
x=jax.numpy.ones((3, 12, 32)),
)

The only notable change with respect to the examples in the previous sections is that
this time around we used ``hk.without_apply_rng`` in Haiku so that we didn't have to pass
the ``rng`` argument as ``None`` to the ``apply`` method.
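
For completeness, a small sketch of the alternative mentioned above: with a plain
``hk.transform`` (no ``hk.without_apply_rng``), ``apply`` takes an ``rng`` argument, which
can simply be ``None`` here since the model draws no random numbers at apply time. This
reuses the ``forward`` function from the Haiku example.

.. code-block:: python

  model = hk.transform(forward)

  params = model.init(
    PRNGKey(0),
    x=jax.numpy.ones((3, 12, 32)),
  )

  y = model.apply(
    params,
    None,  # rng: not needed, no stochastic operations run during apply
    x=jax.numpy.ones((3, 12, 32)),
  )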
