Replies: 2 comments
-
Hey @sendeniz, the problem is that you are not using `@nn.compact` (or `setup()`) for the method that creates the `RnnCell` submodules. Here is the code:
import jax
import jax.numpy as jnp
import flax.linen as nn
class RnnCell(nn.Module):
    """A single vanilla-RNN step: new_h = act(Dense(x) + Dense(h)).

    Attributes:
        hidden_size: Width of the hidden state and of both Dense projections.
        activation: Name of the nonlinearity; one of "tanh", "relu", "sigmoid".
    """

    hidden_size: int
    activation: str = "tanh"

    def setup(self):
        # Reject unknown activations before the Dense submodules are constructed.
        supported = ("tanh", "relu", "sigmoid")
        if self.activation not in supported:
            raise ValueError("Invalid nonlinearity selected for RNN. Please use tanh, relu, or sigmoid.")
        # These attribute names become the parameter-tree keys; keep them stable.
        self.input2hidden = nn.Dense(self.hidden_size)
        self.hidden2hidden = nn.Dense(self.hidden_size)

    def __call__(self, inputs, hidden_state=None):
        """Advance the cell by one timestep.

        Args:
            inputs: array of shape [batchsize, input_size].
            hidden_state: array of shape [batchsize, hidden_size], or None to
                start from an all-zero state.

        Returns:
            The activated new hidden state, shape [batchsize, hidden_size].
        """
        prev_state = (
            jnp.zeros((inputs.shape[0], self.hidden_size))
            if hidden_state is None
            else hidden_state
        )
        # Feed the previous hidden state back in next to the current input;
        # this recurrence is what gives the cell its memory.
        pre_activation = self.input2hidden(inputs) + self.hidden2hidden(prev_state)
        # setup() has already validated self.activation, so the lookup is safe.
        nonlinearity = {
            "tanh": nn.tanh,
            "relu": nn.relu,
            "sigmoid": nn.sigmoid,
        }[self.activation]
        return nonlinearity(pre_activation)
class SimpleRNN(nn.Module):
    """Stacked vanilla RNN followed by a dense read-out of the last timestep.

    NOTE: the time loop below is a plain Python loop, so it is fully unrolled
    when traced (seq_len * num_layers cell calls). For long sequences consider
    nn.scan / nn.RNN instead; that would likely change the parameter tree.

    Attributes:
        input_size: Per-timestep feature size. Not read by this module (Dense
            layers infer their input size); kept for API compatibility.
        hidden_size: Width of every recurrent layer.
        num_layers: Number of stacked RnnCell layers (must be >= 1).
        output_size: Width of the final dense projection.
        activation: Nonlinearity for every cell: "tanh", "relu" or "sigmoid".
    """

    input_size: int
    hidden_size: int
    num_layers: int
    output_size: int
    activation: str = 'relu'

    @nn.compact
    def __call__(self, inputs, hidden_state=None):
        """Run the stacked RNN over a sequence and project the final output.

        Args:
            inputs: array of shape [batchsize, sequence length, inputsize].
            hidden_state: optional initial state of shape
                [batchsize, num_layers, hidden_size]; zeros if None.

        Returns:
            Array of shape [batchsize, outputsize].

        Raises:
            ValueError: on an unsupported activation, num_layers < 1, or an
                empty sequence.
        """
        if self.activation not in ('tanh', 'relu', 'sigmoid'):
            raise ValueError("Invalid activation. Please use tanh, relu, or sigmoid activation.")
        if self.num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {self.num_layers}.")
        if inputs.shape[1] == 0:
            raise ValueError("inputs must contain at least one timestep.")

        # Creation order fixes the auto-generated submodule names
        # (RnnCell_0 ... RnnCell_{n-1}, then Dense_0), i.e. the parameter tree.
        rnn_cells = [RnnCell(self.hidden_size, self.activation) for _ in range(self.num_layers)]
        fc = nn.Dense(self.output_size)

        # Initialize hidden state at the first timestep if None
        if hidden_state is None:
            hidden_state = jnp.zeros((inputs.shape[0], self.num_layers, self.hidden_size))
        hidden = hidden_state

        for t in range(inputs.shape[1]):
            # Layer 0 consumes the input at time t; each deeper layer consumes
            # the freshly updated output of the layer below at the same step.
            layer_input = inputs[:, t, :]
            for layer, cell in enumerate(rnn_cells):
                layer_input = cell(layer_input, hidden[:, layer, :])
                hidden = hidden.at[:, layer, :].set(layer_input)

        # layer_input is now the top layer's output at the last timestep,
        # shape [batchsize, hidden_size]. Do not squeeze(): that would drop the
        # batch axis when batchsize == 1 (or hidden axis when hidden_size == 1).
        return fc(layer_input)
def test_rnn():
    """Smoke-test SimpleRNN: init, run one forward pass, and check the output shape.

    The input batch has shape [64, 784, 784] (784 timesteps of 784 features),
    built by concatenating a [64, 784, 1] and a [64, 784, 783] random array.

    Returns:
        (x, xshape): the input batch and the shape of the model output.

    Raises:
        AssertionError: if the output shape is not [batch, output_size].
    """
    model = SimpleRNN(input_size=28*28, hidden_size=128, num_layers=3, output_size=10)
    x = jax.random.normal(jax.random.PRNGKey(0), (64, 784, 1))
    vals = jax.random.normal(jax.random.PRNGKey(1), (64, 784, 783))
    x = jnp.concatenate([x, vals], axis=-1)
    # NOTE: SimpleRNN unrolls its time loop, so init/apply trace 784 * 3 cell
    # calls each; expect this to be slow.
    variables = model.init(jax.random.PRNGKey(0), x)
    out = model.apply(variables, x)
    xshape = out.shape
    # Actually verify the size, so a "passed" message printed after this
    # call means something.
    expected = (x.shape[0], model.output_size)
    assert xshape == expected, f"SimpleRNN output shape {xshape}, expected {expected}"
    return x, xshape
# Run the smoke test when this cell/module executes. The message below is
# printed only if test_rnn() returns without raising.
testx, xdims = test_rnn()
print("Simple RNN size test: passed.")
Beta Was this translation helpful? Give feedback.
0 replies
-
I am working on something similar and running into performance issues. How would one use `nn.scan` or the `RNN` module (`nn.RNN`) in this case?
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
Dear community,
I have been trying to port a simple RNN that I previously wrote in plain PyTorch to Flax. However, I am running into issues with the submodule definition of the RnnCell. You can find a copy of my code, which runs in a notebook with a simple test case.
When running this code I get the following error:
`AssignSubModuleError: Submodule RnnCell must be defined in setup() or in a method wrapped in @compact`
However, when I define the RnnCell submodule in `setup()` or use the `@compact` decorator to fix it, I get `CallCompactUnboundModuleError: Can't call compact methods on unbound modules`.
Please let me know; I'd be happy to learn.
Vanilla RNN in base Flax
Beta Was this translation helpful? Give feedback.
All reactions