Replies: 1 comment
-
Through some trial and error, I may have found a solution. My new Flax code:

    import jax
    import jax.numpy as jnp
    from flax import linen as nn
    from einops import rearrange


    def make_initializer(out_channels, in_channels, kernel_size, groups):
        # Matches PyTorch's default Conv2d init: U(-sqrt(k), sqrt(k)) with
        # k = groups / (in_channels * prod(kernel_size)).
        # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        k = groups / (in_channels * jnp.prod(jnp.array(kernel_size)))
        scale = jnp.sqrt(k)

        def init_fn(key, shape, dtype):
            return jax.random.uniform(key, shape, minval=-scale, maxval=scale, dtype=dtype)

        return init_fn


    # nn.Conv subclass that swaps in the PyTorch-style uniform init for both kernel and bias.
    class CustomConv1d(nn.Conv):
        @nn.compact
        def __call__(self, x):
            # note: we just ignore whatever self.kernel_init is
            kernel_init = make_initializer(
                self.features, x.shape[-1], self.kernel_size, self.feature_group_count
            )
            if self.use_bias:
                # note: we just ignore whatever self.bias_init is
                bias_init = make_initializer(
                    self.features, x.shape[-1], self.kernel_size, self.feature_group_count
                )
            else:
                bias_init = None
            return nn.Conv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                input_dilation=self.input_dilation,
                kernel_dilation=self.kernel_dilation,
                feature_group_count=self.feature_group_count,
                use_bias=self.use_bias,
                mask=self.mask,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                kernel_init=kernel_init,
                bias_init=bias_init,
            )(x)


    class LeakyReLU(nn.Module):
        negative_slope: float = 0.01

        @nn.compact
        def __call__(self, x):
            return nn.leaky_relu(x, negative_slope=self.negative_slope)


    def WNConv2d(scale_init, *args, **kwargs):
        conv = nn.WeightNorm(CustomConv1d(*args, **kwargs), scale_init=scale_init)
        return conv


    # Discriminator block: folds the time axis by `period` and applies a stack of
    # weight-normalized strided convs, returning the intermediate feature maps.
    class MPD(nn.Module):
        period: int

        def pad_to_period(self, x):
            t = x.shape[-1]
            x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, self.period - t % self.period)), mode='reflect')
            return x

        @nn.compact
        def __call__(self, x):
            convs = [
                WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=32, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
                WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=128, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
                WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=512, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
                WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=1024, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
                WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=1024, kernel_size=(5, 1), strides=(1, 1), padding=((2, 2), (0, 0))),
                WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=1, kernel_size=(3, 1), strides=(1, 1), padding=((1, 1), (0, 0))),
            ]
            fmap = []
            x = self.pad_to_period(x)
            x = rearrange(x, "b c (l p) -> b l p c", p=self.period)
            for i, layer in enumerate(convs):
                x = layer(x)
                if i != (len(convs) - 1):
                    x = LeakyReLU(negative_slope=0.1)(x)
                fmap.append(x)
            return fmap


    def summary_stats(name, x):
        print(f'Stats for {name}:')
        print('shape:', list(x.shape))
        print(f'mean: {jnp.mean(x):,.5f} min: {jnp.min(x):,.5f} max: {jnp.max(x):,.5f} std: {jnp.std(x):,.5f}')


    key = jax.random.PRNGKey(1)
    B, C, T = 1, 1, 44100
    x = jnp.zeros((B, C, T))
    period = 2
    model = MPD(period)
    fmaps, variables = model.init_with_output({"params": key}, x)

    # Print summary stats for each feature map
    for i, fmap in enumerate(fmaps):
        summary_stats(f"fmap {i}", fmap)
        print()

    # Spot-check that the expected parameter tree exists
    params = variables["params"]
    for i in range(6):
        params[f"WeightNorm_{i}"][f"CustomConv1d_{i}/Conv_0/kernel/scale"]
        params[f"CustomConv1d_{i}"]["Conv_0"]["bias"]
        params[f"CustomConv1d_{i}"]["Conv_0"]["kernel"]

    print(model.tabulate({"params": key}, x, console_kwargs={"width": 400}))

New output:
And another randomly sampled PyTorch output:
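As a rough sanity check (not part of the solution above; the layer sizes are arbitrary and it assumes make_initializer from the snippet is in scope), one could compare the uniform bound of PyTorch's default Conv2d init against what make_initializer produces:

    import numpy as np
    import torch
    import jax
    import jax.numpy as jnp

    # Illustrative only: mirror the first MPD layer (1 input channel, 32 output
    # channels, a (5, 1) kernel, groups=1).
    in_ch, out_ch, ksize, groups = 1, 32, (5, 1), 1

    # PyTorch default: weights ~ U(-sqrt(k), sqrt(k)) with k = groups / (in_ch * prod(ksize))
    conv = torch.nn.Conv2d(in_ch, out_ch, ksize, groups=groups)
    print("torch weight min/max:", conv.weight.min().item(), conv.weight.max().item())

    # Flax side: sample from the custom initializer with the matching fan-in
    init_fn = make_initializer(out_ch, in_ch, ksize, groups)
    w = init_fn(jax.random.PRNGKey(0), (ksize[0], ksize[1], in_ch, out_ch), jnp.float32)
    print("flax kernel min/max: ", float(w.min()), float(w.max()))

    print("expected bound:      ", np.sqrt(groups / (in_ch * np.prod(ksize))))

Both min/max should sit just inside the expected bound of roughly 0.447 for this layer.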
-
I'm trying to port some PyTorch code to Flax. The model consists of Conv2d layers wrapped with weight_norm, with LeakyReLU activations after every layer except the last. I've confirmed that the parameter counts and input/output shapes match between PyTorch and Flax, yet the mean/min/max/std of the outputs seem off. Can someone help me identify what went wrong in the port? I think the issue is related to weight initialization (see #4091).
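For reference, PyTorch's Conv2d draws both its weights and biases from U(-sqrt(k), sqrt(k)) with k = groups / (in_channels * prod(kernel_size)), whereas Flax's nn.Conv defaults to a lecun_normal kernel and a zero bias. A minimal sketch of overriding the Flax kernel default (not the actual code from this post; the layer sizes are arbitrary, and the bias would need similar treatment) looks like:

    from flax import linen as nn
    from jax.nn.initializers import variance_scaling

    # variance_scaling(1/3, "fan_in", "uniform") samples U(-limit, limit) with
    # limit = sqrt(3 * scale / fan_in) = sqrt(1 / fan_in), i.e. PyTorch's default
    # Conv weight distribution for groups=1.
    torch_like_kernel_init = variance_scaling(1.0 / 3.0, "fan_in", "uniform")

    conv = nn.Conv(features=32, kernel_size=(5, 1), kernel_init=torch_like_kernel_init)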
Here's the PyTorch code:
and PyTorch output:
Here's the Flax code:
and the Flax output:
To me, the most glaring differences in the outputs are the max: values, even when changing JAX seeds. Again, here's the PyTorch output:

and the Flax output:
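For completeness, a per-layer summary on the PyTorch side that mirrors the Flax summary_stats above could look roughly like this (a sketch, not the exact script used here; fmaps stands for whatever list of feature maps the PyTorch model returns):

    import torch

    def torch_summary_stats(name, x):
        # Same statistics as the Flax summary_stats; note that torch.Tensor.std()
        # applies Bessel's correction by default, unlike jnp.std.
        print(f"Stats for {name}:")
        print("shape:", list(x.shape))
        print(f"mean: {x.mean().item():,.5f} min: {x.min().item():,.5f} "
              f"max: {x.max().item():,.5f} std: {x.std().item():,.5f}")

    # fmaps is assumed to be the list of feature maps returned by the PyTorch model:
    # for i, fmap in enumerate(fmaps):
    #     torch_summary_stats(f"fmap {i}", fmap)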