Dear @tianjuxue,
I noticed a problematic use of `assert` in FEM's `solver.py` that can lead to failures. Here is the relevant code:
```python
def jax_solve(problem, A_fn, b, x0, precond):
    pc = get_jacobi_precond(jacobi_preconditioner(problem)) if precond else None
    x, info = jax.scipy.sparse.linalg.bicgstab(A_fn, b, x0=x0, M=pc, tol=1e-10, atol=1e-10, maxiter=10000)
    # Verify convergence
    err = np.linalg.norm(A_fn(x) - b)
    print(f"JAX scipy linear solve res = {err}")
    # HERE IS THE PROBLEMATIC ASSERT:
    assert err < 0.1, f"JAX linear solver failed to converge with err = {err}"
    return x
```
The `assert` above is Python control flow, so it needs a concrete value to work. In one of my use cases, `err` is instead a tracer:
```
Traced<ShapedArray(float64[])>with<BatchTrace(level=1/0)> with
  val = Array([0.], dtype=float64)
  batch_dim = 0
```
and the `assert` breaks. I am not sure why `err` is not a concrete value in my case, but I do feel that bare `assert`s on array values don't fit JAX's functional-purity philosophy. A feasible alternative could be `jax.lax.cond`, or the assertions from the Chex library.
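A minimal sketch of the failure and of one transform-safe alternative. Assumptions: the 2×2 matrix `A`, `A_fn`, the batch `bs` and the helpers `_solve`, `solve_with_assert` and `solve_with_check` are hypothetical stand-ins for the FEM operator and `jax_solve`, and the check uses `jax.experimental.checkify` rather than `lax.cond` or Chex. The `BatchTrace` in the repr suggests `jax_solve` is being called under `jax.vmap` (or a transform built on it, such as `jax.jacfwd`). There, `err < 0.1` is a batched tracer and `assert` calls `bool()` on it, which raises. `lax.cond` alone can only pick between two traced branches (e.g. return NaNs on failure), so it cannot raise. `checkify.check` instead stages the condition and hands back an error value that can be thrown outside the transformation.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify
from jax.scipy.sparse.linalg import bicgstab

# Match the float64 dtype seen in the report.
jax.config.update("jax_enable_x64", True)

# Hypothetical 2x2 SPD operator standing in for the FEM operator.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])


def A_fn(v):
    return A @ v


def _solve(b):
    # Same kind of call as jax_solve, without the preconditioner.
    x, _ = bicgstab(A_fn, b, x0=jnp.zeros_like(b), tol=1e-8, maxiter=100)
    err = jnp.linalg.norm(A_fn(x) - b)
    return x, err


def solve_with_assert(b):
    x, err = _solve(b)
    # Python assert calls bool() on err, which fails when err is a tracer.
    assert err < 0.1, "linear solver failed to converge"
    return x


def solve_with_check(b):
    x, err = _solve(b)
    # Staged check: recorded in the returned error value instead of raising here.
    checkify.check(err < 0.1, "linear solver failed to converge")
    return x


# A batch of right-hand sides, solved with vmap as in the report.
bs = jnp.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])

try:
    jax.vmap(solve_with_assert)(bs)
except jax.errors.ConcretizationTypeError as e:
    print("assert under vmap raised:", type(e).__name__)

checked_solve = checkify.checkify(jax.vmap(solve_with_check))
error, xs = checked_solve(bs)
error.throw()  # raises only if the check failed for some batch entry
print(xs)
```

The same `checkify` wrapper also works under `jax.jit`. The plain `print` in `jax_solve` has a related limitation: under a transform it prints a tracer repr like the one above rather than the residual, and `jax.debug.print` is the transform-safe counterpart.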