Originally asked by Nal Kalchbrenner: "Is there a nn.MaskedConv in 2d or something similar in Flax? (ideally supporting dilation)"
@levskaya points out that:
Masked convolutions are a fair bit of work to implement, and I don't think anyone's done this yet in any JAX setting that I'm aware of.
Curious: what kind of masking do you want? "Masked conv" is a bit of an overloaded term: causal constraints for generation, general sparsity, or something more like what are called partial convs?
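To make the "causal constraints for generation" reading concrete, here is a minimal sketch of a PixelCNN-style spatial mask in JAX, assuming raster-scan (row-major) generation order. `make_causal_mask` and the "A"/"B" naming are illustrative assumptions following the convention of the original PixelCNN paper, not a Flax API. The per-color-channel masking that PixelCNN also uses is left out.

```python
import jax.numpy as jnp


def make_causal_mask(kh: int, kw: int, mask_type: str = "B") -> jnp.ndarray:
    """Return a (kh, kw) 0/1 mask for a raster-scan causal conv (hypothetical helper).

    Taps in rows above the center row, and taps left of the center in the
    center row, are kept. Type "B" also keeps the center tap; type "A"
    (typically used for the first layer, so a pixel never sees itself) drops it.
    """
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    cy, cx = kh // 2, kw // 2
    rows = jnp.arange(kh)[:, None]  # shape (kh, 1)
    cols = jnp.arange(kw)[None, :]  # shape (1, kw)
    above = rows < cy                                # every tap in earlier rows
    left_in_center_row = (rows == cy) & (cols < cx)  # earlier taps in this row
    center = (rows == cy) & (cols == cx)
    keep = above | left_in_center_row                # broadcasts to (kh, kw)
    if mask_type == "B":
        keep = keep | center
    return keep.astype(jnp.float32)
```

For a 3x3 kernel, type "A" keeps the whole top row plus the single tap left of center; type "B" additionally keeps the center tap.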
@j-towns actually implemented PixelCNN++-style causal conv layers (e.g. ConvDown and ConvDownRight in https://github.com/google/flax/blob/master/examples/pixelcnn/pixelcnn.py#L231), but they're pretty specific to that model rather than a fully parametrized, general masked conv.
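One way to get a general masked conv that also supports dilation, sketched under assumptions: multiply the kernel by a fixed mask and pass the result to `jax.lax.conv_general_dilated`, whose `rhs_dilation` argument spaces out the kernel taps. `masked_conv2d` is a hypothetical helper, not part of Flax. The NHWC/HWIO layouts and "SAME" padding are choices made for this example; with an odd kernel and stride 1, "SAME" padding keeps the center tap aligned with the output pixel, so a causal mask stays causal at any dilation.

```python
import jax
import jax.numpy as jnp


def masked_conv2d(x, kernel, mask, dilation=(1, 1)):
    """Cross-correlate x with kernel * mask (hypothetical helper, not a Flax API).

    x:        (N, H, W, C_in) input in NHWC layout.
    kernel:   (kh, kw, C_in, C_out) weights in HWIO layout; kh and kw should be
              odd so the center tap lines up with the output pixel.
    mask:     (kh, kw) spatial 0/1 mask, or any shape broadcastable to kernel.
    dilation: spacing between kernel taps (rhs_dilation in lax terms).
    """
    if mask.ndim == 2:
        mask = mask[:, :, None, None]  # broadcast over input/output channels
    return jax.lax.conv_general_dilated(
        x,
        kernel * mask,                 # zero out masked taps before convolving
        window_strides=(1, 1),
        padding="SAME",                # padding is computed for the dilated kernel
        rhs_dilation=dilation,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


# Usage: a raster-scan causal mask (center excluded) with dilation 2.
mask_a = jnp.array([[1.0, 1.0, 1.0],
                    [1.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0]])
x = jnp.ones((1, 8, 8, 3))
kernel = jax.random.normal(jax.random.PRNGKey(0), (3, 3, 3, 16))
y = masked_conv2d(x, kernel, mask_a, dilation=(2, 2))
print(y.shape)  # (1, 8, 8, 16)
```

Because the mask is applied to the weights, any fixed 0/1 pattern works, though masked taps still cost FLOPs. In a Flax module, the kernel would typically be created with `self.param` and passed to a function like this inside `__call__`. This sketch does not cover partial convs, which mask and renormalize based on the input rather than the weights.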