Does JAX get rid of unused branches? #15661
Hello, I have a model with multiple hyperparameters. Some parts of the code should be executed and some should not; whether a given part will be executed is known at compile time. For example:
As in the toy example above, the information is stored in an immutable frozen dict and is known at compile time. So in this scenario, what is best to use: `if`, `lax.select`, or something else? Is there a way to tell jit that this part of the code is a dead end for this configuration?
Replies: 1 comment 2 replies
In general, the semantics of ... From your question it's not clear whether any of this is relevant, because you have a static, trace-time condition (using a Python `if`).
Yes, if you have static conditions you should use a Python `if` statement, and then the compiler won't even see the other branch.
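To illustrate the point about static conditions, here is a minimal sketch, not the original poster's code. The `model` function, the `use_extra_branch` flag, and the plain-dict `CONFIG` (standing in for the frozen dict from the question) are all hypothetical names chosen for this example. It inspects the traced jaxpr to check that the untaken branch never reaches the compiler.

```python
import jax
import jax.numpy as jnp

# Hypothetical configuration; a plain dict stands in for the frozen dict
# mentioned in the question.
CONFIG = {"use_extra_branch": False}


def model(x, use_extra_branch):
    y = jnp.sin(x)
    # Ordinary Python `if`: evaluated once, while JAX traces the function.
    # When the flag is False, the body below is never traced.
    if use_extra_branch:
        y = y + jnp.exp(x)
    return y


def primitive_names(fn, *args):
    """Return the names of the primitives recorded when tracing fn."""
    closed_jaxpr = jax.make_jaxpr(fn)(*args)
    return [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]


x = jnp.arange(3.0)

# Trace once per configuration and compare what the compiler would receive.
prims_off = primitive_names(lambda v: model(v, False), x)
prims_on = primitive_names(lambda v: model(v, True), x)
print("flag off:", prims_off)
print("flag on: ", prims_on)

# Under jit, mark the flag as static so it stays a concrete Python bool
# during tracing; each distinct value triggers its own compilation.
fast_model = jax.jit(model, static_argnames=("use_extra_branch",))
out = fast_model(x, use_extra_branch=CONFIG["use_extra_branch"])
```

With the flag off, the `exp` primitive should not appear in the traced program at all, so there is nothing for XLA to eliminate. If the flag is instead closed over from a global config rather than passed as an argument, it is already a plain Python value at trace time and `static_argnames` is not needed.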