Skip to content

Commit

Permalink
* get result_type to handle non-arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 16, 2024
1 parent f0b8e27 commit e0c0835
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,11 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None,
else:
# real-valued positive-definite linear operators are symmetric.
def real_valued(x):
return not issubclass(x.dtype.type, np.complexfloating)
return not issubclass(jnp.result_type(x).type, np.complexfloating)
if callable(A) and x0 is not None:
# we use output dtype as the proxy for dtype.
symmetric = all(map(real_valued, tree_leaves(jax.eval_shape(A, x0))))
symmetric = all(map(real_valued,
tree_leaves(jax.eval_shape(A, x0))))
else:
try:
# Prefer to use the dtype of the operator, if available.
Expand Down

0 comments on commit e0c0835

Please sign in to comment.