flax switch error #16250
Hello, I have a couple of big branches that I previously tried to switch between using `jax.lax.switch`. The switch function worked, but some modules failed with an error message saying that I should not mix JAX and Flax transformations, so I am now trying to use the Flax version (`nn.switch`). First, I prepare all possible versions with different indexes:
Next, I invoke them, and that DOES work:
The switch does not work:
Error:
I cannot make sense of it, as the index is already set by partial application (`main_fun_op` itself is complex). Thanks for any help!
It's unclear from your partial code, but it appears that `main_func_op` is a method of a class (because you're referring to it as `self.main_func_op`), in which case the `self` argument is implicit, so you probably need something like:

```python
nn.switch(indexx, self.main_fun_ops, curried)
```

This would be easier to answer if you could provide a minimal reproducible example. Also, since this is a Flax-specific question, you'll probably get better answers if you ask at http://github.com/google/flax/
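The point about the implicit `self` can be illustrated in plain Python. This is a minimal sketch, assuming (as the Flax `nn.switch` documentation describes) that each branch is called as `branch(module, *operands)`; `Model`, `call_with_module`, and `make_branch` are hypothetical names. Partially applying a bound method such as `self.main_fun_op` already fixes `self`, so a caller that also passes the module explicitly supplies it twice:

```python
import functools


class Model:
    """Plain-Python stand-in for a module class (hypothetical, no Flax needed)."""

    def main_fun_op(self, i, x):
        return f"branch {i} on {type(self).__name__} with x={x}"


def call_with_module(branch, module, x):
    """Hypothetical caller that, like nn.switch, passes the module explicitly."""
    return branch(module, x)


m = Model()

# Partially applying the *bound* method m.main_fun_op: `self` is already bound.
bound_branch = functools.partial(m.main_fun_op, 0)
print(bound_branch(42))  # self is implicit, so only x is needed

# Passing the module explicitly as well supplies `self` a second time.
try:
    call_with_module(bound_branch, m, 42)
except TypeError as err:
    print("TypeError:", err)


def make_branch(i):
    # Leave the module slot free: the caller supplies it as the first argument.
    return lambda mdl, x: mdl.main_fun_op(i, x)


print(call_with_module(make_branch(0), m, 42))
```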