Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __array_namespace__ to ArrayBox #647

Merged

Conversation

yaugenst-flex
Copy link
Contributor

@yaugenst-flex yaugenst-flex commented Oct 10, 2024

Small change that adds the __array_namespace__ method from the Array API standard to autograd's ArrayBox.

This makes autograd (partially) compatible with libraries that check for array API compatibility, such as xarray, because xarray will coerce types that do not implement the array API into regular numpy arrays, see relevant discussion here.

The following code snippet:

import xarray as xr

import autograd.numpy as np
from autograd import grad


def f(x):
    data = xr.DataArray(x, dims=range(x.ndim)).data
    print(data)
    return np.sum(data)


x = np.random.uniform(-1, 1, 10)
grad(f)(x)

prints

[<autograd.numpy.numpy_boxes.ArrayBox object at 0x7f9465347640>
 <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7000>
 <autograd.numpy.numpy_boxes.ArrayBox object at 0x7f94652f7c40>]

currently, because the traced ArrayBox is being coerced into a regular numpy array of dtype=object with ArrayBox elements. Differentiating through this works (with some limitations), but is very inefficient.

With the addition suggested in this PR, the snippet instead prints:

Autograd ArrayBox with value [1. 1. 1.]

i.e. data is now a "proper" ArrayBox.

I might also be interested in working on a full xarray wrapper in the future, but for now just this change would already help a lot.

@fjosw
Copy link
Collaborator

fjosw commented Oct 13, 2024

Hi @yaugenst-flex, thanks for contributing to autograd! I made the type hints compatible with python<3.10 and now all tests seem to pass 👍 Any objections against merging this @agriyakhetarpal @j-towns ?

@j-towns
Copy link
Collaborator

j-towns commented Oct 14, 2024

No objections from me.

Copy link
Collaborator

@agriyakhetarpal agriyakhetarpal left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No objections from me either! Thanks, @yaugenst-flex!

@agriyakhetarpal agriyakhetarpal merged commit 317be57 into HIPS:master Oct 15, 2024
30 checks passed
@agriyakhetarpal
Copy link
Collaborator

I've created https://github.com/HIPS/autograd/milestone/1 to track things we can add to a v1.8.0 release (or to see which things we can split up and add to a v1.7.1 release). This sounds like a useful improvement.

@agriyakhetarpal agriyakhetarpal added this to the v1.8.0 milestone Oct 15, 2024
@yaugenst-flex
Copy link
Contributor Author

This is great, thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants