Skip to content

Commit

Permalink
Merge pull request #3162 from chiamp:relax_tol
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 543010413
  • Loading branch information
Flax Authors committed Jun 24, 2023
2 parents 115b8a5 + 3e470bf commit 25ab83c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions docs/guides/convert_pytorch_to_flax.rst
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the tr
t_out = t_fc(t_x)
t_out = t_out.detach().cpu().numpy()

np.testing.assert_almost_equal(j_out, t_out)
np.testing.assert_almost_equal(j_out, t_out, decimal=6)


Convolutions
Expand Down Expand Up @@ -205,7 +205,7 @@ while Flax multiplies the estimated statistic with ``momentum`` and the new obse
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

np.testing.assert_almost_equal(j_out, t_out)
np.testing.assert_almost_equal(j_out, t_out, decimal=6)



Expand Down Expand Up @@ -253,7 +253,7 @@ operation. ``nn.pool()`` is the core function behind |nn.avg_pool()|_ and |nn.ma
# [N, C, H, W] -> [N, H, W, C]
t_out = np.transpose(t_out.detach().cpu().numpy(), (0, 2, 3, 1))

np.testing.assert_almost_equal(j_out, t_out)
np.testing.assert_almost_equal(j_out, t_out, decimal=6)



Expand Down

0 comments on commit 25ab83c

Please sign in to comment.