Convert for loop stacking, potentially with remat_scan
#3336
Hi all,
Is there a way to convert the for-loop stacking in the model below to scan, potentially with remat_scan? Each MLP has a different width, and I also need access to the intermediate values f.
Thanks
import flax.linen as nn
import jax.numpy as jnp
from jax import random
import jax
from typing import List


def test_without_scan():
    class MLP(nn.Module):
        width: int

        @nn.compact
        def __call__(self, x):
            x = nn.Dense(self.width)(x)
            x = nn.Dense(self.width)(x)
            x = nn.Dense(self.width)(x)
            return nn.Dense(1)(x)

    class BigModel(nn.Module):
        width_list: List

        @nn.compact
        def __call__(self, x):
            # x is two dimensional input
            # y is one dimensional observation
            # We want to stack two Dense networks with access to the intermediate values f.
            # The network for the first variable has 8 nodes and 3 layers.
            # The network for the second variable has 4 nodes and 3 layers.
            h = 0
            for i, width in enumerate(self.width_list):
                f = MLP(width, name=f'MLP_{i}')(x[:, i].reshape(-1, 1))
                self.sow("f", f"f_{i}", f)
                h += f
            return h

    x = jnp.ones((5, 2))
    model = BigModel(width_list=[8, 4])
    variables = model.init(random.PRNGKey(0), x)
    print(variables)


test_without_scan()
Replies: 1 comment
From my understanding, scan only works when the target arg is the same module at every step; i.e. you want to apply the same module/layer iteratively over an input. In the documentation, it says "you can do this [scan] when you have a sequence of identical layers that you want to apply iteratively to an input". In your example, since you are using MLP layers with different configurations, I don't think there's a more efficient implementation that can be written. With remat, you are reducing memory usage at the cost of more computation.
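For reference, here is a minimal sketch of the pattern scan does support, assuming every stacked block shares the same width (the Block, ScannedStack, and RematScanStack names are made up for illustration, not a drop-in replacement for the model above). The second module shows nn.remat_scan, which additionally rematerializes activations so you trade extra compute for lower memory:

import flax.linen as nn
import jax.numpy as jnp
from jax import random


class Block(nn.Module):
    # One scan step; every step must have the SAME configuration (same width).
    width: int

    @nn.compact
    def __call__(self, carry, _):
        h = nn.Dense(self.width)(carry)
        # Return the activation as the new carry and also as a per-step output,
        # so the stacked intermediates stay accessible (similar to sow above).
        return h, h


class ScannedStack(nn.Module):
    width: int
    n_layers: int

    @nn.compact
    def __call__(self, x):
        ScanBlock = nn.scan(
            Block,
            variable_axes={'params': 0},   # stack each step's params along axis 0
            variable_broadcast=False,
            split_rngs={'params': True},   # separate init RNG per step
            length=self.n_layers,
        )
        x = nn.Dense(self.width)(x)        # project the input to the shared width
        carry, intermediates = ScanBlock(self.width)(x, None)
        return carry, intermediates


class RematScanStack(nn.Module):
    # remat_scan variant: 3 x 3 = 9 identical Dense layers with O(sqrt(N))
    # activation memory for the backward pass, at the cost of recomputation.
    width: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.width)(x)  # keep the scan carry shape fixed
        DenseStack = nn.remat_scan(nn.Dense, lengths=(3, 3))
        return DenseStack(self.width)(x)


x = jnp.ones((5, 2))
variables = ScannedStack(width=8, n_layers=3).init(random.PRNGKey(0), x)
variables_remat = RematScanStack(width=8).init(random.PRNGKey(1), x)

This only pays off because every step reuses the same Dense(width) shape; with per-variable widths like [8, 4], the plain for loop in your example is already the idiomatic approach.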