Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restore test_apply_paddings_check runtime_checks test #771

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

mattjj
Copy link

@mattjj mattjj commented Oct 22, 2024

The main idea is that we need to call jax.effects_barrier(), because the error may be raised in an XLA computation that is asynchronous with the main Python thread and therefore we need to block. (There may have been a recent change in behavior, where JAX runs more computations asynchronously on the CPU backend.) We could put that call to jax.effects_barrier() in the test code (and corresponding user code), or we could bulid it into the runtime_checks context manager. Currently this commit does the latter.

I also tweaked the runtime_checks logic to use a try/finally pattern to restore the state when the context is exited, even when it's exited via exception. We may want to do the same to context managers like numeric_checks.

While the test now passes, there is a gross warning printed about "Exception ignored in atexit callback". That may be a JAX internal bug, or it may be some quirk of CPython 3.10; I haven't investigated further. Let me know if that seems like a problem.

What do you think?

The main idea is that we need to call `jax.effects_barrier()`, because the
error may be raised in an XLA computation that is asynchronous with the main
Python thread and therefore we need to block. (There may have been a recent
change in behavior, where JAX runs more computations asynchronously on the CPU
backend.) We could put that call to `jax.effects_barrier()` in the test code
(and corresponding user code), or we could bulid it into the `runtime_checks`
context manager.  Currently this commit does the latter.

I also tweaked the `runtime_checks` logic to use a `try/finally` pattern to
restore the state when the context is exited, even when it's exited via
exception. We may want to do the same to context managers like
`numeric_checks`.

While the test now passes, there is a gross warning printed about "Exception
ignored in atexit callback". That may be a JAX internal bug, or it may be some
quirk of CPython 3.10; I haven't investigated further. Let me know if that
seems like a problem.
Copy link
Contributor

@markblee markblee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mattjj! While the warning isn't pleasant, I think we can live with it for now.

Copy link
Contributor

@kelvin-zou kelvin-zou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@mattjj
Copy link
Author

mattjj commented Oct 22, 2024

My intrepid teammates @yashk2810 and @hawkinsp noticed that in the most recent release of JAX we no longer raise jaxlib.xla_extension.XlaRuntimeError but rather jax.errors.JaxRuntimeError (EDIT: or maybe that's just a public-facing alias for the same object...). See jax-ml/jax#23943. I'll try to update that in this PR (under a version switch), or send a follow-up PR if this PR gets merged before I make the fix.

@matthew-e-hopkins
Copy link
Contributor

Thank you!

Copy link
Contributor

@ruomingp ruomingp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants