Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restore test_apply_paddings_check runtime_checks test #771

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions axlearn/common/transducer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# pylint: disable=duplicate-code,invalid-name

import jax
import jaxlib
import numpy as np
import tensorflow as tf
from absl import logging
Expand All @@ -28,7 +29,7 @@
log_prob_suffix_alignments,
log_probs_from_blank_and_tokens,
)
from axlearn.common.utils import NestedTensor, Tensor
from axlearn.common.utils import NestedTensor, Tensor, runtime_checks


def numpy_log_prob_prefix_alignments(
Expand Down Expand Up @@ -299,18 +300,19 @@ def test_apply_paddings_check(self):
)
log_prob_blank, log_prob_y = jnp.log(prob_blank), jnp.log(prob_y)

# TODO(matthew_e_hopkins): test fails as of jax 0.4.33 through 0.4.35, revisit
# with runtime_checks():
# with self.assertRaisesRegex(
# jaxlib.xla_extension.XlaRuntimeError,
# "lm_paddings cannot be all 1s.",
# ):
# jax.jit(jax.vmap(apply_paddings))(
# log_prob_blank=log_prob_blank,
# log_prob_y=log_prob_y,
# am_paddings=am_paddings,
# lm_paddings=lm_paddings,
# )
# TODO(mattjj): replace with jax.errors.JaxRuntimeError when minimum jax
# version is jax>=0.4.35
cls = getattr(jax.errors, 'JaxRuntimeError',
jaxlib.xla_extension.XlaRuntimeError)

with self.assertRaisesRegex(cls, "lm_paddings cannot be all 1s."):
with runtime_checks():
jax.jit(jax.vmap(apply_paddings))(
log_prob_blank=log_prob_blank,
log_prob_y=log_prob_y,
am_paddings=am_paddings,
lm_paddings=lm_paddings,
)
check_apply_paddings = checkify.checkify(apply_paddings, errors=checkify.user_checks)
err, _ = jax.jit(jax.vmap(check_apply_paddings))(
log_prob_blank=log_prob_blank,
Expand Down
7 changes: 5 additions & 2 deletions axlearn/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,11 @@ def switch(value):
jax.config.update("jax_experimental_unsafe_xla_runtime_errors", value)

switch(enabled)
yield
switch(old_state)
try:
yield
jax.effects_barrier()
finally:
switch(old_state)


@contextlib.contextmanager
Expand Down