Using scan with multiple layers, each with a different initialization: working example #14715
jakubMitura14 asked this question in Q&A
Hey @jakubMitura14, I added some
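```python
import flax.linen as nn
import jax
import jax.numpy as jnp


class Layer(nn.Module):
    # Module attribute: the same value is used by every scanned layer.
    dummy: int

    @nn.compact
    def __call__(self, c, x, t):
        jax.debug.print("x: {x}", x=x)
        jax.debug.print("t: {t}", t=t)
        x = nn.Dense(len(x))(x)
        x = jax.nn.softmax(jnp.exp(t) * x)
        # Add this layer's output sum to the running carry.
        c = c + jnp.sum(jnp.ravel(x))
        jax.debug.print("c: {c}\n", c=c.flatten())
        return c, x


class Model(nn.Module):
    @nn.compact
    def __call__(self, x, t):
        LayerScanned = nn.scan(
            Layer,
            # Stack "params" along axis 0: one parameter set per layer.
            # No variable collection is named "dummy", so that entry matches nothing.
            variable_axes={"params": 0, "dummy": 0},
            # True gives each layer its own RNG, i.e. a different initialization.
            split_rngs={"params": True},
            length=5,
            # Scan over axis 0 of x; pass t unchanged to every layer.
            in_axes=(0, nn.broadcast),
            out_axes=0,
        )
        # Scalar carry, so the final carry is a single cumulative sum.
        carry = jnp.zeros(())
        # dummy=10 is shared by all 5 scanned layers.
        carry, y = LayerScanned(10)(carry, x, t)
        jax.debug.print("final y: {y}", y=y)
        jax.debug.print("final c: {c}", c=carry)
        return y, carry


x = jax.random.uniform(jax.random.PRNGKey(0), (5, 2))
model = Model()
print("\ninit")
params = model.init(jax.random.PRNGKey(0), x, t=1.0)
print("\napply")
y, c = model.apply(params, x, t=0.1)
```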
This post is also published in the Flax forum (https://github.com/google/flax/discussions/2912), so this one can be removed by an administrator.
I used the example from google/flax#2127 as a basis.
I want to have multiple layers of the same type, but with a different value of a dummy variable for each layer.
Then I want to invoke the layers sequentially, summing their outputs into a single number.
The main problem is that when I pass an array to LayerScanned, the same array is broadcast to every layer instead of each layer getting a different integer.
For some reason I also get an array as the final result instead of a cumulative sum (although this is a minor problem).
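The shape of the final carry is always the shape of the initial carry, so starting from `jnp.zeros_like(x)` yields an array rather than a single number. The sketch below shows this with plain `jax.lax.scan`; the `step` function is a hypothetical stand-in for the layer.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # `carry` is the running total; `x` is one row of `xs`.
    y = jax.nn.softmax(x)
    # The returned carry goes to the next step; `y` is stacked into `ys`.
    return carry + jnp.sum(y), y


xs = jnp.arange(10.0).reshape(5, 2)

# Scalar initial carry: the result is a single cumulative sum.
total, ys = jax.lax.scan(step, jnp.zeros(()), xs)

# Array initial carry (as with zeros_like): the sum is broadcast into every entry.
total_arr, _ = jax.lax.scan(step, jnp.zeros_like(xs), xs)
```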