Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for passing array attributes via ffi_call and an example demonstrating the different attribute APIs supported by the FFI #23951

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

dfm
Copy link
Collaborator

@dfm dfm commented Sep 26, 2024

The XLA FFI supports a rich interface for specifying "attributes", but this interface hasn't been widely used from the JAX side. We discovered that passing array attributes wasn't supported by the current front end, so I added support for that and took advantage of this change to add a more complete example of the e2e usage of attributes.

The core JAX change that was needed was to add support for passing np.array objects (which aren't hashable!) as keyword arguments to the ffi_call primitive. To do this, I wrap any unhashable keyword arguments in the _HashableByObjectId helper which creates a hashable object based on the id.

"Why not pass these unhashable objects as positional arguments?" you ask. Well, the problem is that we need access to the concrete value of this array when lowering, so we can't have them turning into tracers.

I don't expect this to introduce any real caching issues, but if anyone has concerns or other suggestions for a better API, I'd love to hear them!

@dfm dfm self-assigned this Sep 26, 2024
@gspschmid
Copy link
Contributor

I don't expect this to introduce any real caching issues, but if anyone has concerns or other suggestions for a better API, I'd love to hear them!

Just to clarify, does this defeat the hashing at the JAX level and would any ffi_call using arrays or dicts hence be uncacheable at the JAX level (but presumably cached fine at the MLIR level)? If I recall correctly, there's a caching mechanism for jitted functions both within JAX itself (on jaxprs?) and another on the emitted StableHLO.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants