diff --git a/axlearn/common/transducer_test.py b/axlearn/common/transducer_test.py index d56a7c28..987722e8 100644 --- a/axlearn/common/transducer_test.py +++ b/axlearn/common/transducer_test.py @@ -4,6 +4,7 @@ # pylint: disable=duplicate-code,invalid-name import jax +import jaxlib import numpy as np import tensorflow as tf from absl import logging @@ -28,7 +29,7 @@ log_prob_suffix_alignments, log_probs_from_blank_and_tokens, ) -from axlearn.common.utils import NestedTensor, Tensor +from axlearn.common.utils import NestedTensor, Tensor, runtime_checks def numpy_log_prob_prefix_alignments( @@ -299,18 +300,19 @@ def test_apply_paddings_check(self): ) log_prob_blank, log_prob_y = jnp.log(prob_blank), jnp.log(prob_y) - # TODO(matthew_e_hopkins): test fails as of jax 0.4.33 through 0.4.35, revisit - # with runtime_checks(): - # with self.assertRaisesRegex( - # jaxlib.xla_extension.XlaRuntimeError, - # "lm_paddings cannot be all 1s.", - # ): - # jax.jit(jax.vmap(apply_paddings))( - # log_prob_blank=log_prob_blank, - # log_prob_y=log_prob_y, - # am_paddings=am_paddings, - # lm_paddings=lm_paddings, - # ) + # TODO(mattjj): replace with jax.errors.JaxRuntimeError when minimum jax + # version is jax>=0.4.35 + cls = getattr(jax.errors, 'JaxRuntimeError', + jaxlib.xla_extension.XlaRuntimeError) + + with self.assertRaisesRegex(cls, "lm_paddings cannot be all 1s."): + with runtime_checks(): + jax.jit(jax.vmap(apply_paddings))( + log_prob_blank=log_prob_blank, + log_prob_y=log_prob_y, + am_paddings=am_paddings, + lm_paddings=lm_paddings, + ) check_apply_paddings = checkify.checkify(apply_paddings, errors=checkify.user_checks) err, _ = jax.jit(jax.vmap(check_apply_paddings))( log_prob_blank=log_prob_blank, diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 9677a695..d3b864f1 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -141,8 +141,11 @@ def switch(value): jax.config.update("jax_experimental_unsafe_xla_runtime_errors", value) switch(enabled) - yield - switch(old_state) + try: + yield + jax.effects_barrier() + finally: + switch(old_state) @contextlib.contextmanager