From d5a25b7754799c30954d2f456e4f62c55d49fd30 Mon Sep 17 00:00:00 2001
From: Cristian Garcia
Date: Wed, 21 Jun 2023 22:50:06 +0000
Subject: [PATCH] add lifted transforms to haiku migration guide

---
 docs/guides/haiku_migration_guide.rst | 141 ++++++++++++++++++++++++++
 1 file changed, 141 insertions(+)

diff --git a/docs/guides/haiku_migration_guide.rst b/docs/guides/haiku_migration_guide.rst
index e4629a6601..7e6e5cef59 100644
--- a/docs/guides/haiku_migration_guide.rst
+++ b/docs/guides/haiku_migration_guide.rst
@@ -571,3 +571,144 @@

explicitly passed, even though the module does not use any stochastic
operations during ``apply``. In Flax this is not necessary. The Haiku
``rng`` is set to ``None`` here, but you could also use ``hk.without_apply_rng``
on the ``apply`` function to remove the ``rng`` argument.


Lifted Transforms
-----------------

Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms,
that wrap JAX transformations in such a way that they can be used with Modules and sometimes
provide additional functionality. In this section we will take a look at how to use the
lifted version of ``scan`` in both Flax and Haiku to implement a simple RNN layer.

To begin, we will define an ``RNNCell`` module that contains the logic for a single
step of the RNN. We will also define an ``initial_state`` method that will be used to initialize
the state (a.k.a. ``carry``) of the RNN. As with ``jax.lax.scan``, the ``RNNCell.__call__``
method is a function that takes the carry and input, and returns the new
carry and output. In this case, the carry and the output are the same.

.. codediff::
  :title_left: Haiku
  :title_right: Flax
  :sync:

  class RNNCell(hk.Module):
    def __init__(self, hidden_size: int, name=None):
      super().__init__(name=name)
      self.hidden_size = hidden_size

    def __call__(self, carry, x):
      x = jnp.concatenate([carry, x], axis=-1)
      x = hk.Linear(self.hidden_size)(x)
      x = jax.nn.relu(x)
      return x, x

    def initial_state(self, batch_size: int):
      return jnp.zeros((batch_size, self.hidden_size))

  ---

  class RNNCell(nn.Module):
    hidden_size: int


    @nn.compact
    def __call__(self, carry, x):
      x = jnp.concatenate([carry, x], axis=-1)
      x = nn.Dense(self.hidden_size)(x)
      x = jax.nn.relu(x)
      return x, x

    def initial_state(self, batch_size: int):
      return jnp.zeros((batch_size, self.hidden_size))
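Before wiring the cell into a full RNN, here is a minimal sketch of how the Flax ``RNNCell``
above could be exercised for a single step on its own (the Haiku version would first need to be
wrapped in ``hk.transform``, so we use the Flax definition for brevity). The batch size of 3,
feature size of 32, and hidden size of 64 are arbitrary values chosen for illustration and are
not part of the side-by-side comparison.

.. code-block:: python

  import jax
  import jax.numpy as jnp

  cell = RNNCell(hidden_size=64)
  # The same zeros that RNNCell.initial_state(batch_size=3) would return.
  carry = jnp.zeros((3, 64))
  x = jnp.ones((3, 32))  # a single step of input

  variables = cell.init(jax.random.PRNGKey(0), carry, x)
  new_carry, y = cell.apply(variables, carry, x)  # both have shape (3, 64)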
Next, we will define an ``RNN`` Module that contains the logic for the entire RNN.
In Haiku, we first initialize the ``RNNCell``, use it to construct the ``carry``,
and then use ``hk.scan`` to run the ``RNNCell`` over the input sequence. In Flax this is
done a bit differently: we use ``nn.scan`` to define a new temporary type that wraps
``RNNCell``. During this process we instruct ``nn.scan`` to broadcast
the ``params`` collection (all steps share the same parameters) and to not split the
``params`` rng stream (so all steps initialize with the same parameters), and finally
we specify that we want scan to run over the second axis of the input and stack
the outputs along the second axis as well. We then use this temporary type immediately
to create an instance of the lifted ``RNNCell``, use it to create the ``carry``, and
run the ``__call__`` method, which will ``scan`` over the sequence.

.. codediff::
  :title_left: Haiku
  :title_right: Flax
  :sync:

  class RNN(hk.Module):
    def __init__(self, hidden_size: int, name=None):
      super().__init__(name=name)
      self.hidden_size = hidden_size

    def __call__(self, x):
      cell = RNNCell(self.hidden_size)
      carry = cell.initial_state(x.shape[0])
      carry, y = hk.scan(cell, carry, jnp.swapaxes(x, 1, 0))
      return jnp.swapaxes(y, 0, 1)

  ---

  class RNN(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, x):
      rnn = nn.scan(RNNCell, variable_broadcast='params', split_rngs={'params': False},
                    in_axes=1, out_axes=1)(self.hidden_size)
      carry = rnn.initial_state(x.shape[0])
      carry, y = rnn(carry, x)
      return y

In general, the main difference between the lifted transforms in Flax and Haiku is that
in Haiku the lifted transforms don't operate over the state; that is, Haiku handles the
``params`` and ``state`` in such a way that they keep the same shape inside and outside of the
transform. In Flax, the lifted transforms can operate over both variable collections and rng
streams, and the user must define how each collection and rng stream is treated by the
transform according to the transform's semantics (a short sketch at the end of this section
illustrates this with a ``dropout`` rng stream).

Finally, let's quickly look at how the ``RNN`` Module would be used in both Haiku and Flax.

.. codediff::
  :title_left: Haiku
  :title_right: Flax
  :sync:

  def forward(x):
    return RNN(64)(x)

  model = hk.without_apply_rng(hk.transform(forward))

  params = model.init(
    PRNGKey(0),
    x=jax.numpy.ones((3, 12, 32)),
  )

  y = model.apply(
    params,
    x=jax.numpy.ones((3, 12, 32)),
  )

  ---

  ...


  model = RNN(64)

  variables = model.init(
    PRNGKey(0),
    x=jax.numpy.ones((3, 12, 32)),
  )
  params = variables['params']
  y = model.apply(
    {'params': params},
    x=jax.numpy.ones((3, 12, 32)),
  )

The only notable change with respect to the examples in the previous sections is that
this time around we used ``hk.without_apply_rng`` in Haiku so that we don't have to pass
the ``rng`` argument as ``None`` to the ``apply`` method.
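To illustrate the point above about Flax letting you decide how each variable collection and
rng stream is treated by a lifted transform, here is a minimal sketch that adds dropout to the
cell: ``params`` are broadcast and their init rng is shared across steps, while the ``dropout``
rng is split so each step drops different units. ``DropoutRNNCell`` and ``DropoutRNN`` are
hypothetical names and the 0.5 dropout rate is an arbitrary choice; they are not part of the
comparison above.

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import flax.linen as nn
  from jax.random import PRNGKey

  class DropoutRNNCell(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, carry, x):
      x = jnp.concatenate([carry, x], axis=-1)
      x = nn.Dense(self.hidden_size)(x)
      x = nn.Dropout(rate=0.5, deterministic=False)(x)  # uses the 'dropout' rng stream
      x = jax.nn.relu(x)
      return x, x

  class DropoutRNN(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, x):
      rnn = nn.scan(
        DropoutRNNCell,
        variable_broadcast='params',      # one set of params shared by every step
        split_rngs={'params': False,      # same init rng for every step
                    'dropout': True},     # a fresh dropout rng for each step
        in_axes=1, out_axes=1,
      )(self.hidden_size)
      carry = jnp.zeros((x.shape[0], self.hidden_size))
      carry, y = rnn(carry, x)
      return y

  model = DropoutRNN(64)
  x = jax.numpy.ones((3, 12, 32))
  variables = model.init({'params': PRNGKey(0), 'dropout': PRNGKey(1)}, x)
  y = model.apply(variables, x, rngs={'dropout': PRNGKey(2)})

Note how ``split_rngs`` is the single place where you declare, per rng stream, whether the
lifted ``scan`` should reuse the same key at every step or split a fresh one for each step.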