Merge pull request #3195 from google:haiku-migration-scan-over-layers
PiperOrigin-RevId: 550293709
Flax Authors committed Jul 23, 2023
2 parents 4f1884c + 06583b5 commit 9ba7cc0
Showing 2 changed files with 150 additions and 1 deletion.
150 changes: 149 additions & 1 deletion docs/guides/haiku_migration_guide.rst
@@ -713,4 +713,152 @@ Finally, let's quickly view how the ``RNN`` Module would be used in both Haiku and Flax

The only notable change with respect to the examples in the previous sections is that
this time around we used ``hk.without_apply_rng`` in Haiku so we didn't have to
pass the ``rng`` argument as ``None`` to the ``apply`` method.

Scan over layers
----------------
One very important application of ``scan`` is to apply a sequence of layers iteratively
over an input, passing the output of each layer as the input to the next. This
is very useful for reducing compilation time in big models. As an example, we will create
a simple ``Block`` Module and then use it inside an ``MLP`` Module that applies
the ``Block`` Module ``num_layers`` times.

In Haiku, we define the ``Block`` Module as usual, and then inside ``MLP`` we
use ``hk.experimental.layer_stack`` over a ``stack_block`` function to create a stack
of ``Block`` Modules. In Flax, the definition of ``Block`` is a little different:
``__call__`` accepts and returns a second dummy input/output, which in both cases is
``None``. In ``MLP``, we use ``nn.scan`` as in the previous example, but
by setting ``split_rngs={'params': True}`` and ``variable_axes={'params': 0}``
we tell ``nn.scan`` to create different parameters for each step and to slice the
``params`` collection along the first axis, effectively implementing a stack of
``Block`` Modules as in Haiku.


.. codediff::
  :title_left: Haiku
  :title_right: Flax
  :sync:

  class Block(hk.Module):
    def __init__(self, features: int, name=None):
      super().__init__(name=name)
      self.features = features

    def __call__(self, x, training: bool):
      x = hk.Linear(self.features)(x)
      x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
      x = jax.nn.relu(x)
      return x

  class MLP(hk.Module):
    def __init__(self, features: int, num_layers: int, name=None):
      super().__init__(name=name)
      self.features = features
      self.num_layers = num_layers

    def __call__(self, x, training: bool):
      # layer_stack creates num_layers copies of Block, each with its
      # own parameters, and applies them sequentially.
      @hk.experimental.layer_stack(self.num_layers)
      def stack_block(x):
        return Block(self.features)(x, training)

      return stack_block(x)

  ---

  class Block(nn.Module):
    features: int
    training: bool

    @nn.compact
    def __call__(self, x, _):
      x = nn.Dense(self.features)(x)
      x = nn.Dropout(0.5)(x, deterministic=not self.training)
      x = jax.nn.relu(x)
      # The second output is the scan's per-step output; we have none.
      return x, None

  class MLP(nn.Module):
    features: int
    num_layers: int

    @nn.compact
    def __call__(self, x, training: bool):
      ScanBlock = nn.scan(
        Block, variable_axes={'params': 0}, split_rngs={'params': True},
        length=self.num_layers)

      y, _ = ScanBlock(self.features, training)(x, None)
      return y

Notice how in Flax we pass ``None`` as the second argument to ``ScanBlock`` and ignore
its second output. These represent the per-step inputs and outputs of the scan; they
are ``None`` here because in this case we don't have any.
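
For intuition, here is a small self-contained sketch of what those per-step slots
are for (the ``EchoBlock`` module and the shapes below are our own illustration, not
part of the guide): each layer additionally emits its activations as a per-step
output, and ``nn.scan`` stacks them along a new leading axis.

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import flax.linen as nn

  class EchoBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x, _):
      x = nn.Dense(self.features)(x)
      return x, x  # carry for the next layer, plus a per-step output

  ScanEcho = nn.scan(
    EchoBlock, variable_axes={'params': 0}, split_rngs={'params': True},
    length=3)

  x = jnp.ones((1, 4))
  variables = ScanEcho(4).init(jax.random.PRNGKey(0), x, None)
  y, per_step = ScanEcho(4).apply(variables, x, None)
  print(y.shape)         # (1, 4): the final carry
  print(per_step.shape)  # (3, 1, 4): one stacked output per layer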

Initializing each model is the same as in previous examples. In this case,
we specify that we want ``5`` layers, each with ``64`` features.

.. codediff::
  :title_left: Haiku
  :title_right: Flax
  :sync:

  def forward(x, training: bool):
    return MLP(64, num_layers=5)(x, training)

  model = hk.transform(forward)

  sample_x = jax.numpy.ones((1, 64))
  params = model.init(
    PRNGKey(0),
    sample_x, training=False  # <== inputs
  )
  ...

  ---

  ...


  model = MLP(64, num_layers=5)

  sample_x = jax.numpy.ones((1, 64))
  variables = model.init(
    PRNGKey(0),
    sample_x, training=False  # <== inputs
  )
  params = variables['params']
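
Once initialized, both models can be run with ``apply`` as in the previous sections.
A minimal sketch, assuming the snippets above and eval mode (``training=False``): the
Flax side needs no dropout key, while Haiku still needs an rng because
``hk.next_rng_key()`` is called unconditionally inside ``Block``.

.. codediff::
  :title_left: Haiku
  :title_right: Flax
  :sync:

  y = model.apply(
    params,
    PRNGKey(1),  # consumed by hk.next_rng_key
    sample_x, training=False
  )

  ---

  y = model.apply(
    {'params': params},
    sample_x, training=False
  )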

When using scan over layers, one thing you should notice is that all layers
are fused into a single layer whose parameters have an extra "layer" dimension on
the first axis. In this case, the shape of every parameter will start with ``(5, ...)``,
since we are using ``5`` layers (see the sketch after the shape listings below for
how to recover an individual layer's parameters).

.. tab-set::

  .. tab-item:: Haiku
    :sync: Haiku

    .. code-block:: python

      ...
      {
        'mlp/__layer_stack_no_per_layer/block/linear': {
          'b': (5, 64),
          'w': (5, 64, 64)
        }
      }
      ...

  .. tab-item:: Flax
    :sync: Flax

    .. code-block:: python

      FrozenDict({
        ScanBlock_0: {
          Dense_0: {
            bias: (5, 64),
            kernel: (5, 64, 64),
          },
        },
      })
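
Shape trees like these can be printed with
``jax.tree_util.tree_map(lambda p: p.shape, params)``. To make the fused
representation concrete, here is a hypothetical sketch (the ``manual_stack`` helper
is our own, not part of either library) of what the Flax ``ScanBlock`` computes in
eval mode: slice layer ``i``'s parameters off the leading axis and apply the plain
``Block`` in a Python loop.

.. code-block:: python

  import jax

  def manual_stack(stacked_params, x, num_layers=5):
    # stacked_params is variables['params']['ScanBlock_0'] from above;
    # every leaf has a leading axis of size num_layers.
    block = Block(features=64, training=False)  # eval mode: no dropout key
    for i in range(num_layers):
      # Slice layer i's parameters; tree_map calls the lambda immediately,
      # so capturing the loop variable here is safe.
      layer_params = jax.tree_util.tree_map(lambda p: p[i], stacked_params)
      x, _ = block.apply({'params': layer_params}, x, None)
    return x

  y_manual = manual_stack(params['ScanBlock_0'], sample_x)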
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -9,6 +9,7 @@ recommonmark
ipython_genutils
sphinx-design
jupytext==1.13.8
dm-haiku

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
