Hi, I have two questions: (1) Does the JAX compiler know not to carry out calculations that are multiplied by zero? (2) What introspection tools are there? Is there a way to see the low-level code generated by the compiler? Thanks in advance! :)
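For the second question, here is a minimal sketch of JAX's standard introspection entry points, assuming a recent JAX release. `jax.make_jaxpr` shows the traced program in JAX's own intermediate representation. `jax.jit(...).lower(...).as_text()` shows the program handed to XLA. `.compile().as_text()` shows the HLO after XLA's optimization passes for the current backend. The function `g` is a hypothetical example:

```python
import jax


def g(a, b):
    # Hypothetical example function: adds a zero-multiplied term to a.
    return a + 0.0 * b


# 1. The jaxpr: JAX's own intermediate representation, before XLA sees it.
jaxpr = jax.make_jaxpr(g)(1.0, 1.0)
print(jaxpr)

# 2. The lowered program (StableHLO text in recent JAX versions) that is passed to XLA.
lowered = jax.jit(g).lower(1.0, 1.0)
print(lowered.as_text())

# 3. The HLO after XLA's optimization passes, i.e. what actually runs on the backend.
compiled = lowered.compile()
print(compiled.as_text())
```

Comparing stages 2 and 3 is the usual way to check which operations XLA simplified, fused, or removed.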
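In general, no, it does not, because multiplication by zero cannot safely be ignored. Consider a function `f` whose result includes its input `b` multiplied by zero, and a variant `f_ignore_mult_by_zero` that returns what `f` would give if that multiplication were skipped.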
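Here is a minimal sketch of `f` and `f_ignore_mult_by_zero`, assuming `f` computes `a + 0.0 * b`. Both bodies are hypothetical; only the function names come from the calls that follow:

```python
import jax


@jax.jit
def f(a, b):
    # Hypothetical body: b is multiplied by zero and the product is added to a.
    return a + 0.0 * b


@jax.jit
def f_ignore_mult_by_zero(a, b):
    # Hypothetical body: what f would return if the 0.0 * b term were dropped.
    return a
```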
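When `b` is finite, the functions return the same results:

```python
f(1.0, 1.0)  # 1.0
f_ignore_mult_by_zero(1.0, 1.0)  # 1.0
```

However, when `b` is `NaN` or `inf`, the functions return different results:

```python
import numpy as np

f(1.0, np.inf)  # NaN
f_ignore_mult_by_zero(1.0, np.inf)  # 1.0
```

One of the key goals of the compiler is to avoid optimizations that change program outputs. Under IEEE 754 floating-point rules, `0 * inf` and `0 * NaN` are both `NaN`, so the compiler will not elide multiplication by zero.

If you want "ignore multiplication by zero" semantics, you can implement them explicitly with `lax.cond`, e.g.:

```python
from jax import lax


def f(a, b, k):
    # When k == 0, return 0.0 instead of k * b, so inf or NaN in b cannot reach the result.
    return a + lax.cond(k == 0, lambda: 0.0, lambda: k * b)
```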
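Here is a self-contained usage sketch of the `lax.cond` version, assuming a recent JAX release. It checks that an infinite `b` no longer turns the result into `NaN` when `k` is zero, both eagerly and under `jax.jit`:

```python
import jax
import numpy as np
from jax import lax


def f(a, b, k):
    # When k == 0, the cond returns 0.0, so inf or NaN in b cannot reach the output.
    return a + lax.cond(k == 0, lambda: 0.0, lambda: k * b)


print(f(1.0, np.inf, 0.0))           # 1.0 (b is ignored because k == 0)
print(jax.jit(f)(1.0, np.inf, 0.0))  # 1.0 (also holds when k is a traced value)
print(f(1.0, 2.0, 3.0))              # 7.0 (ordinary case: 1.0 + 3.0 * 2.0)
```

Because `lax.cond` takes the predicate as a runtime value, the same function works under `jax.jit`, where `k == 0` cannot be evaluated as a Python `if`.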