From e0c0835d6ff01a652e855f60ae6936c4358913a6 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Tue, 17 Sep 2024 00:30:39 +0200 Subject: [PATCH] * get result_type to handle non-arrays --- jax/_src/scipy/sparse/linalg.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/jax/_src/scipy/sparse/linalg.py b/jax/_src/scipy/sparse/linalg.py index 3fbcdfeed620..db87e8982d7d 100644 --- a/jax/_src/scipy/sparse/linalg.py +++ b/jax/_src/scipy/sparse/linalg.py @@ -288,10 +288,11 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, else: # real-valued positive-definite linear operators are symmetric. def real_valued(x): - return not issubclass(x.dtype.type, np.complexfloating) + return not issubclass(jnp.result_type(x).type, np.complexfloating) if callable(A) and x0 is not None: # we use output dtype as the proxy for dtype. - symmetric = all(map(real_valued, tree_leaves(jax.eval_shape(A, x0)))) + symmetric = all(map(real_valued, + tree_leaves(jax.eval_shape(A, x0)))) else: try: # Prefer to use the dtype of the operator, if available.