diff --git a/docs/guides/haiku_migration_guide.rst b/docs/guides/haiku_migration_guide.rst
index b48180fb02..42dfd5e989 100644
--- a/docs/guides/haiku_migration_guide.rst
+++ b/docs/guides/haiku_migration_guide.rst
@@ -713,4 +713,152 @@ Finally, let's quickly view how the ``RNN`` Module would be used in both Haiku a
 
 The only notable change with respect to the examples in the previous sections is
 that this time around we used ``hk.without_apply_rng`` in Haiku so we didn't have to
-pass the ``rng`` argument as ``None`` to the ``apply`` method.
\ No newline at end of file
+pass the ``rng`` argument as ``None`` to the ``apply`` method.
+
+Scan over layers
+----------------
+One very important application of ``scan`` is to apply a sequence of layers iteratively
+over an input, passing the output of each layer as the input to the next layer. This
+is very useful for reducing compilation time for big models. As an example, we will
+create a simple ``Block`` Module and then use it inside an ``MLP`` Module that applies
+the ``Block`` Module ``num_layers`` times.
+
+In Haiku, we define the ``Block`` Module as usual, and then inside ``MLP`` we use
+``hk.experimental.layer_stack`` over a ``stack_block`` function to create a stack
+of ``Block`` Modules. In Flax, the definition of ``Block`` is a little different:
+``__call__`` accepts and returns a second dummy input/output that in both cases is
+``None``. In ``MLP``, we use ``nn.scan`` as in the previous example, but
+by setting ``split_rngs={'params': True}`` and ``variable_axes={'params': 0}``
+we are telling ``nn.scan`` to create different parameters for each step and to slice
+the ``params`` collection along the first axis, effectively implementing a stack of
+``Block`` Modules as in Haiku.
+
+
+.. codediff::
+  :title_left: Haiku
+  :title_right: Flax
+  :sync:
+
+  class Block(hk.Module):
+    def __init__(self, features: int, name=None):
+      super().__init__(name=name)
+      self.features = features
+
+    def __call__(self, x, training: bool):
+      x = hk.Linear(self.features)(x)
+      x = hk.dropout(hk.next_rng_key(), 0.5 if training else 0, x)
+      x = jax.nn.relu(x)
+      return x
+
+  class MLP(hk.Module):
+    def __init__(self, features: int, num_layers: int, name=None):
+      super().__init__(name=name)
+      self.features = features
+      self.num_layers = num_layers
+
+    def __call__(self, x, training: bool):
+      @hk.experimental.layer_stack(self.num_layers)
+      def stack_block(x):
+        return Block(self.features)(x, training)
+
+      return stack_block(x)
+
+  ---
+
+  class Block(nn.Module):
+    features: int
+    training: bool
+
+    @nn.compact
+    def __call__(self, x, _):
+      x = nn.Dense(self.features)(x)
+      x = nn.Dropout(0.5)(x, deterministic=not self.training)
+      x = jax.nn.relu(x)
+      return x, None
+
+  class MLP(nn.Module):
+    features: int
+    num_layers: int
+
+    @nn.compact
+    def __call__(self, x, training: bool):
+      ScanBlock = nn.scan(
+        Block, variable_axes={'params': 0}, split_rngs={'params': True},
+        length=self.num_layers)
+
+      y, _ = ScanBlock(self.features, training)(x, None)
+      return y
+
+Notice how in Flax we pass ``None`` as the second argument to ``ScanBlock`` and ignore
+its second output. These represent the per-step inputs/outputs, which are ``None``
+because in this case we don't have any.
+
+Initializing each model is the same as in previous examples. In this case,
+we specify that we want to use ``5`` layers, each with ``64`` features.
+
+.. codediff::
+  :title_left: Haiku
+  :title_right: Flax
+  :sync:
+
+  def forward(x, training: bool):
+    return MLP(64, num_layers=5)(x, training)
+
+  model = hk.transform(forward)
+
+  sample_x = jax.numpy.ones((1, 64))
+  params = model.init(
+    PRNGKey(0),
+    sample_x, training=False  # <== inputs
+  )
+  ...
+
+  ---
+
+  ...
+
+
+  model = MLP(64, num_layers=5)
+
+  sample_x = jax.numpy.ones((1, 64))
+  variables = model.init(
+    PRNGKey(0),
+    sample_x, training=False  # <== inputs
+  )
+  params = variables['params']
+
+When using scan over layers, the one thing you should notice is that all layers
+are fused into a single layer whose parameters have an extra "layer" dimension on
+the first axis. In this case, the shape of all parameters will start with ``(5, ...)``
+as we are using ``5`` layers.
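+
+As a quick sanity check of those shapes, you can map every leaf of the parameter
+pytree to its shape with ``jax.tree_util.tree_map``. This is just a small inspection
+sketch layered on top of the example above (it assumes the ``params`` returned by
+``init``); the same call works for the Haiku ``params`` dict and for the Flax
+``variables['params']`` collection:
+
+.. code-block:: python
+
+  import jax
+
+  # Replace every parameter array with its shape tuple and print the
+  # resulting nested structure; `params` comes from `init` above.
+  print(jax.tree_util.tree_map(lambda p: p.shape, params))
+
+The printed structure should resemble the listings below.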
+
+.. tab-set::
+
+  .. tab-item:: Haiku
+    :sync: Haiku
+
+    .. code-block:: python
+
+      ...
+      {
+          'mlp/__layer_stack_no_per_layer/block/linear': {
+              'b': (5, 64),
+              'w': (5, 64, 64)
+          }
+      }
+      ...
+
+  .. tab-item:: Flax
+    :sync: Flax
+
+    .. code-block:: python
+
+      FrozenDict({
+          ScanBlock_0: {
+              Dense_0: {
+                  bias: (5, 64),
+                  kernel: (5, 64, 64),
+              },
+          },
+      })
\ No newline at end of file
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 1ec5310527..f548acdf0c 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -9,6 +9,7 @@ recommonmark
 ipython_genutils
 sphinx-design
 jupytext==1.13.8
+dm-haiku
 
 # Need to pin docutils to 0.16 to make bulleted lists appear correctly on
 # ReadTheDocs: https://stackoverflow.com/a/68008428