diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 3fbcdfeed620..db87e8982d7d 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -288,10 +288,11 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, else: # real-valued positive-definite linear operators are symmetric. def real_valued(x): - return not issubclass(x.dtype.type, np.complexfloating) + return not issubclass(jnp.result_type(x).type, np.complexfloating) if callable(A) and x0 is not None: # we use output dtype as the proxy for dtype. - symmetric = all(map(real_valued, tree_leaves(jax.eval_shape(A, x0)))) + symmetric = all(map(real_valued, + tree_leaves(jax.eval_shape(A, x0)))) else: try: # Prefer to use the dtype of the operator, if available.