Add support for passing array attributes via ffi_call and an example demonstrating the different attribute APIs supported by the FFI #23951
+211
−5
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The XLA FFI supports a rich interface for specifying "attributes", but this interface hasn't been widely used from the JAX side. We discovered that passing array attributes wasn't supported by the current front end, so I added support for that and took advantage of this change to add a more complete example of the e2e usage of attributes.
The core JAX change that was needed was to add support for passing
np.array
objects (which aren't hashable!) as keyword arguments to theffi_call
primitive. To do this, I wrap any unhashable keyword arguments in the_HashableByObjectId
helper which creates a hashable object based on theid
."Why not pass these unhashable objects as positional arguments?" you ask. Well, the problem is that we need access to the concrete value of this array when lowering, so we can't have them turning into tracers.
I don't expect this to introduce any real caching issues, but if anyone has concerns or other suggestions for a better API, I'd love to hear them!