Hello! I was exploring the new Linen API and found an error.

```python
import functools
from typing import Any, Callable, Sequence, Optional

import numpy as np
import jax
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn


class ExplicitMLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        # we automatically know what to do with lists, dicts of submodules
        self.layers = [nn.Dense(self.features[0])]
        self.layers += [nn.Dense(feat) for feat in self.features[1:]]
        # instead of self.layers = [nn.Dense(feat) for feat in self.features]
        # for single submodules, we would just write:
        # self.layer1 = nn.Dense(feat1)

    def __call__(self, inputs):
        x = inputs
        for i, lyr in enumerate(self.layers):
            x = lyr(x)
            if i != len(self.layers) - 1:
                x = nn.relu(x)
        return x


key1, key2 = random.split(random.PRNGKey(0), 2)
x = random.uniform(key1, (4, 4))

model = ExplicitMLP(features=[3, 4, 5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
```

The error message:
Answered by avital, Dec 24, 2020
Answer selected by Daulbaev
So, for now, in your case you could do:

```python
layers = [nn.Dense(self.features[0])]
layers += [nn.Dense(feat) for feat in self.features[1:]]
self.layers = layers
```
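To see why building the list locally helps, here is a toy sketch of the mechanism (this is *not* Flax's actual implementation; `Module`, `Dense`, and `registered` are stand-ins): a framework that registers submodules inside `__setattr__` only observes attribute *assignments*. Mutating a plain local list fires no hook, so assigning the finished list once lets the framework see every layer in a single, consistent registration step:

```python
class Module:
    """Toy base class: names sub-objects at the moment they are assigned,
    loosely mimicking how a framework can hook attribute assignment."""

    def __init__(self):
        # Bypass our own __setattr__ while setting up the registry itself.
        object.__setattr__(self, "registered", {})

    def __setattr__(self, name, value):
        # Register each element of a list (or the single value) under an
        # attribute-derived name, then store the attribute normally.
        items = value if isinstance(value, list) else [value]
        for i, item in enumerate(items):
            self.registered[f"{name}_{i}"] = item
        object.__setattr__(self, name, value)


class Dense:
    """Stand-in for a layer; only carries its feature count."""

    def __init__(self, features):
        self.features = features


m = Module()
layers = [Dense(3)]                 # plain local list: no hook fires here
layers += [Dense(4), Dense(5)]      # ...nor here
m.layers = layers                   # single assignment: hook sees all three
print(sorted(m.registered))         # → ['layers_0', 'layers_1', 'layers_2']
```

The design point: all mutation happens on an ordinary Python list that the framework never observes, and the framework's `__setattr__` hook runs exactly once, on the final value.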
Hi @Daulbaev -- there are two issues here:

1. In `setup`, we attach submodules directly during assignment (via `__setattr__`). So Linen doesn't "see" the modification done with `+=`. There have been a few proposals for making this clearer, including: (a) making assigned values read-only after assignment, (b) requiring the use of `ModuleList`, where we could then override `__iadd__`, (c) automatically converting assigned values to `ModuleList`s with that functionality.
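Proposal (b) could look roughly like the following. This is a hypothetical sketch (the `ModuleList` name and behavior here are assumed for illustration, not an actual Flax API): because `+=` on an object dispatches to `__iadd__`, a dedicated container can observe in-place additions that a plain attribute hook would miss.

```python
class ModuleList:
    """Hypothetical container sketch: unlike a plain list attribute, it can
    react to `+=` because `__iadd__` runs inside the container itself."""

    def __init__(self, modules=()):
        self._modules = list(modules)
        self.events = []  # record of observed mutations, for illustration

    def __iadd__(self, new_modules):
        new_modules = list(new_modules)
        self._modules.extend(new_modules)
        self.events.append(f"added {len(new_modules)} module(s)")
        return self  # `ml += [...]` re-binds the same, updated object

    def __iter__(self):
        return iter(self._modules)

    def __len__(self):
        return len(self._modules)


ml = ModuleList(["dense_3"])
ml += ["dense_4", "dense_5"]   # intercepted by __iadd__, unlike list +=
print(len(ml), ml.events)      # → 3 ['added 2 module(s)']
```

With such a container, the framework could register the newly added submodules at the moment of the `+=`, instead of only at initial assignment.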