Skip to content

Commit

Permalink
Improve error reporting for complex eigh on TPU
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 5, 2021
1 parent 3575bc7 commit 7e1439b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
10 changes: 10 additions & 0 deletions jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,14 @@ def _eigh_cpu_gpu_translation_rule(syevd_impl, c, operand, lower):
_nan_like(c, w))
return xops.Tuple(c, [v, w])

def _eigh_tpu_translation_rule(c, operand, lower):
# Fail gracefully for complex dtype (unsupported on TPU).
shape = c.get_shape(operand)
dtype = shape.element_type().type
if np.issubdtype(dtype, np.complexfloating):
raise NotImplementedError("eigh is not implemented on TPU for complex inputs.")
return eigh_translation_rule(c, operand, lower)

def eigh_jvp_rule(primals, tangents, lower):
# Derivative for eigh in the simplest case of distinct eigenvalues.
# This is classic nondegenerate perurbation theory, but also see
Expand Down Expand Up @@ -542,6 +550,8 @@ def eigh_batching_rule(batched_args, batch_dims, lower):
xla.backend_specific_translations['gpu'][eigh_p] = partial(
_eigh_cpu_gpu_translation_rule, rocsolver.syevd)

xla.backend_specific_translations['tpu'][eigh_p] = _eigh_tpu_translation_rule


triangular_solve_dtype_rule = partial(
naryop_dtype_rule, _input_dtype, (_float | _complex, _float | _complex),
Expand Down
19 changes: 19 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
import jax
import jax.lib
from jax import jit, grad, jvp, vmap
from jax.interpreters import xla
from jax import lax
from jax import numpy as jnp
from jax import scipy as jsp
from jax import test_util as jtu
from jax._src.lax import linalg as lax_linalg

from jax.config import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -357,6 +359,23 @@ def norm(x):
self._CompileAndCheck(partial(jnp.linalg.eigh, UPLO=uplo), args_maker,
rtol=1e-3)

def testEighComplexTPUFails(self):
# The eigh TPU translation rule raises NotImplementedError for complex
# input. This test is designed to fail if TPU ever starts supporting
# complex input, so that we know to remove that check.
if jtu.device_under_test() != 'tpu':
self.skipTest("Test requires TPU")

with self.assertRaisesRegex(NotImplementedError, "eigh is not implemented on TPU for complex inputs."):
jnp.linalg.eigh(jnp.ones((4, 4), dtype='complex64'))

tpu_rule = xla.backend_specific_translations['tpu'].pop(lax_linalg.eigh_p)
try:
with self.assertRaisesRegex(RuntimeError, "Invalid argument: Type of the input matrix must be float: got c64.*"):
jnp.linalg.eigh(jnp.ones((4, 4), dtype='complex64'))
finally:
xla.backend_specific_translations['tpu'][lax_linalg.eigh_p] = tpu_rule

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
Expand Down

0 comments on commit 7e1439b

Please sign in to comment.