Hello everyone
I'm glad to see this page, thanks! I'm new to Flax, so I'm trying to convert from Flax to PyTorch.
For the following class:
```python
from typing import Sequence

import flax.linen as nn


class FlaxMLP(nn.Module):
    """Implements an MLP in Flax."""
    features: Sequence[int]
    layer_norm: bool

    @nn.compact
    def __call__(self, inputs):
        x = inputs
        for i, feat in enumerate(self.features):
            x = nn.Dense(feat, dtype=DTYPE)(x)  # DTYPE is defined elsewhere
            if i != len(self.features) - 1:
                x = nn.tanh(x)
                if self.layer_norm:
                    x = nn.LayerNorm()(x)
        return x
```
I tried converting it to PyTorch like this:
```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, features, layer_norm=False):
        super(MLP, self).__init__()
        self.features = features
        self.layer_norm = layer_norm
        # Note: here features[0] is the input width; in the Flax version the
        # input width is inferred and features lists only the output widths.
        self.dense_layers = nn.ModuleList([
            nn.Linear(features[i], features[i + 1]) for i in range(len(features) - 1)
        ])
        if self.layer_norm:
            # One LayerNorm per hidden layer, stored under a separate name so
            # the boolean flag in self.layer_norm is not overwritten, and each
            # norm matches the width of the layer it follows.
            self.norms = nn.ModuleList([
                nn.LayerNorm(features[i + 1]) for i in range(len(features) - 2)
            ])

    def forward(self, inputs):
        x = inputs
        for i, layer in enumerate(self.dense_layers):
            x = layer(x)
            if i != len(self.features) - 2:  # activation for all layers except the last
                x = torch.tanh(x)
                if self.layer_norm:
                    x = self.norms[i](x)
        return x
```
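One thing worth double-checking in a review: the two classes interpret `features` differently. In the Flax version every entry is an output width (the input width is inferred from the data), while the PyTorch version consumes the first entry as the input width. A small sketch with hypothetical sizes, showing the `(in_features, out_features)` pairs the `ModuleList` comprehension produces:

```python
# Hypothetical layer sizes: features[0] acts as the input width in the
# PyTorch version, so this builds a 4 -> 32 -> 32 -> 2 network.
features = [4, 32, 32, 2]

# The same pairing used by the nn.Linear list comprehension above.
pairs = [(features[i], features[i + 1]) for i in range(len(features) - 1)]
print(pairs)  # [(4, 32), (32, 32), (32, 2)]
```

So to reproduce `FlaxMLP(features=[32, 32, 2], ...)` you would pass `[input_dim, 32, 32, 2]` to the PyTorch class.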
Can someone please review this and give me some feedback? Thank you.