Skip to content

Commit

Permalink
* touch up
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Nov 14, 2024
1 parent 4a11a58 commit 036e7a4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
13 changes: 6 additions & 7 deletions jax/_src/scipy/sparse/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,7 @@ def _isolve(_isolve_solve, A, b, x0=None, *, tol=1e-5, atol=0.0,
return x, info


def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None,
assume_ipart_is_zero: bool = False):
def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, assume_ipart_is_zero: bool = False):
"""Use Conjugate Gradient iteration to solve ``Ax = b``.
The numerics of JAX's ``cg`` should exact match SciPy's ``cg`` (up to
Expand Down Expand Up @@ -275,7 +274,9 @@ def cg(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None,
to reach a given error tolerance.
assume_ipart_is_zero : bool, optional
Whether the linear operator dtype can be assumed to have imag(A)=0.
Defaults to False for complex-value systems.
Defaults to False for complex-value systems. When True, then it can be
assumed that for complex operators A, A @ x and A.T @ x are equivalent,
leading to more efficient reverse-mode autodiff.
See also
--------
Expand All @@ -293,8 +294,7 @@ def real_valued(x):
try:
# Prefer to use the dtype of the operator, if available.
if callable(A) and x0 is not None:
symmetric = all(map(real_valued,
tree_leaves(jax.eval_shape(A, x0))))
symmetric = all(map(real_valued, tree_leaves(jax.eval_shape(A, x0))))
else:
symmetric = real_valued(A)
except TypeError:
Expand Down Expand Up @@ -721,8 +721,7 @@ def _solve(A, b):
return x, info


def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None,
symmetric=False):
def bicgstab(A, b, x0=None, *, tol=1e-5, atol=0.0, maxiter=None, M=None, symmetric=False):
"""Use Bi-Conjugate Gradient Stable iteration to solve ``Ax = b``.
The numerics of JAX's ``bicgstab`` should exact match SciPy's
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_scipy_sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_cg_against_scipy(self, shape, dtype, preconditioner, assume_ipart_is_ze
rng = jtu.rand_default(self.rng())
A = rand_sym_pos_def(rng, shape, dtype)
if assume_ipart_is_zero:
A = A - jnp.imag(A)
A = A.real.astype(A.dtype)
b = rng(shape[:1], dtype)
M = self._fetch_preconditioner(preconditioner, A, rng=rng)

Expand Down Expand Up @@ -230,7 +230,7 @@ def test_bicgstab_against_scipy(self, shape, dtype, preconditioner, symmetric):
rng = jtu.rand_default(self.rng())
A = rng(shape, dtype)
if symmetric:
A = A @ A.T.conj()
A = A + A.T
b = rng(shape[:1], dtype)
M = self._fetch_preconditioner(preconditioner, A, rng=rng)

Expand Down

0 comments on commit 036e7a4

Please sign in to comment.