Dilated convolution compatibility with PyTorch weights #614
-
Hey guys, I've been trying to get PyTorch's dilated convolution equivalent in Flax (working with PyTorch weights), however I'm getting small differences in the results which I can't figure out. It could be a missing transposition I'm not seeing. I would highly appreciate it if someone could have a look over this small example and help me out. Thanks! https://colab.research.google.com/drive/1ArSB9dByJsx1cfQ6I_hJEZh37skXPKoy?usp=sharing
-
In XLA, and by extension JAX and Flax, dilation is treated as a dilation 'rate', not an 'amount' of interior padding on the kernel. As a result the default value is 1. If I am reading the colab correctly, the intended DILATION=2 is being passed as _dilation = 2 - 1 = 1, which would be an undilated convolution. Hopefully this helps resolve the issue.
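For reference, here is a minimal sketch (not taken from the colab; the shapes, the DILATION constant, and the layer setup are illustrative assumptions) showing a PyTorch `Conv2d` and a `flax.linen.Conv` producing matching outputs when the dilation rate is passed as-is and the kernel is transposed from OIHW to HWIO:

```python
# Minimal sketch: matching a dilated PyTorch Conv2d with flax.linen.Conv.
# Shapes and DILATION are illustrative assumptions, not values from the colab.
import numpy as np
import torch
import jax.numpy as jnp
import flax.linen as nn

DILATION = 2  # the dilation *rate*: pass it as-is, not DILATION - 1

# PyTorch reference: NCHW input, OIHW kernel
torch_conv = torch.nn.Conv2d(4, 8, kernel_size=3, dilation=DILATION,
                             padding=DILATION, bias=False)
x_nchw = torch.randn(1, 4, 16, 16)
y_torch = torch_conv(x_nchw).detach().numpy()            # (1, 8, 16, 16)

# Flax equivalent: NHWC input, HWIO kernel, kernel_dilation defaults to 1
flax_conv = nn.Conv(features=8, kernel_size=(3, 3),
                    kernel_dilation=(DILATION, DILATION),
                    padding='SAME', use_bias=False)

# Convert the PyTorch weights: OIHW -> HWIO
kernel_hwio = jnp.asarray(torch_conv.weight.detach().numpy().transpose(2, 3, 1, 0))
params = {'params': {'kernel': kernel_hwio}}

x_nhwc = jnp.asarray(x_nchw.numpy().transpose(0, 2, 3, 1))
y_flax = flax_conv.apply(params, x_nhwc)                  # (1, 16, 16, 8)

# Compare after moving the Flax output back to NCHW
np.testing.assert_allclose(y_torch,
                           np.asarray(y_flax).transpose(0, 3, 1, 2),
                           atol=1e-4)
```

With `padding=DILATION` on the PyTorch side and `padding='SAME'` on the Flax side, both should pad symmetrically for a stride-1, kernel-3 convolution, so any remaining mismatch would point to the dilation value or the weight transposition.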