
Batch axes slicing #20

Open
dongwoonhyun opened this issue Jun 1, 2024 · 1 comment

Comments

@dongwoonhyun

Hi Brent,

Thanks for the fantastic library. We're finding it very helpful in our ultrasound imaging research. The updated batch axes handling is great for our use case and solves the issues I was trying to hack around in a fork.

Another feature that would be nice is slicing into the batch axes of a MatrixLieGroup object.

For instance,

    import jaxlie as jl
    import numpy as onp
    a = jl.SE3(onp.random.random((4, 1, 10, 7)))
    a_slice = jl.SE3(a.wxyz_xyz[..., :5, :])  # current syntax
    a_slice = a[..., :5]  # desired syntax

It's a minor thing, but it would improve our ergonomics a lot (e.g., inspecting the pose of an individual sensor in a large array). I think the following code snippet should add that functionality to MatrixLieGroup:

    def __getitem__(self, key):
        """Allow batch axes slicing using [] indexing operator."""
        # If the key contains an ellipsis (i.e. "..."), append a full slice so the
        # ellipsis does not expand over the trailing parameter dimension.
        # Assumes `import dataclasses` at module level.
        if isinstance(key, tuple) and any(k is Ellipsis for k in key):
            key = (*key, slice(None))
        return self.__class__(
            **{f.name: getattr(self, f.name)[key] for f in dataclasses.fields(self)}
        )
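A quick NumPy check of why the ellipsis branch is needed: `...` expands to fill every axis that the rest of the key does not index, including the trailing parameter axis. Forwarding `[..., :5]` unchanged would therefore slice the 7 parameters of every pose instead of the last batch axis. This sketch uses a plain NumPy array with the same shape as the SE3 example, not jaxlie itself.

```python
import numpy as np

# Same layout as the SE3 example: batch shape (4, 1, 10), 7 parameters per pose.
params = np.random.random((4, 1, 10, 7))

# Forwarded as-is, the ellipsis swallows all batch axes and ":5" lands on the
# parameter axis: every pose keeps only its first 5 parameters.
naive = params[..., :5]

# What a[..., :5] passes to __getitem__, then the snippet's trailing full slice.
key = (Ellipsis, slice(None, 5))
fixed = params[(*key, slice(None))]

print(naive.shape)  # (4, 1, 10, 5)
print(fixed.shape)  # (4, 1, 5, 7)
```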

What do you think?

@brentyi
Owner

brentyi commented Jun 1, 2024

Hi @dongwoonhyun, thanks for the suggestion! Glad the library is useful.

Sure, I'm happy to support this. Do you have time for a PR? At a glance, the implementation you wrote looks good to me. My only suggestions are:

  • Type annotations, e.g. typing_extensions.Self for the output.
  • Perhaps some checks to make sure the last parameter dimension isn't sliced into accidentally?
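One possible shape for these two suggestions, sketched under assumptions: `ToySE3` is a hypothetical frozen dataclass standing in for `jaxlie.SE3` (a single `wxyz_xyz` field whose last axis holds the 7 parameters), NumPy stands in for JAX, and `Self` is imported from `typing_extensions` for type checkers only. The parameter-axis check works by first indexing a zero-stride dummy array of the batch shape, so NumPy itself raises `IndexError` when a key touches more axes than there are batch axes.

```python
from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    from typing_extensions import Self


@dataclasses.dataclass(frozen=True)
class ToySE3:
    """Hypothetical stand-in for jaxlie.SE3: one parameter array, last axis = 7."""

    wxyz_xyz: np.ndarray

    def __getitem__(self, key: Any) -> Self:
        """Index into the batch axes only; the parameter axis is left intact."""
        params = self.wxyz_xyz
        batch_shape = params.shape[:-1]

        # Check the key against the batch axes alone. A zero-stride dummy array
        # of the batch shape costs no memory to create, and NumPy raises
        # IndexError if the key indexes more axes than the batch has.
        np.broadcast_to(np.zeros(()), batch_shape)[key]

        # An ellipsis would otherwise expand over the parameter axis too, so
        # pin that axis with an explicit trailing full slice.
        if not isinstance(key, tuple):
            key = (key,)
        if any(k is Ellipsis for k in key):
            key = (*key, slice(None))
        return dataclasses.replace(self, wxyz_xyz=params[key])


a = ToySE3(np.random.random((4, 1, 10, 7)))
print(a[..., :5].wxyz_xyz.shape)  # (4, 1, 5, 7)
print(a[0, 0, 3].wxyz_xyz.shape)  # (7,)
```

Validating against a batch-shaped dummy, rather than only checking the result's trailing dimension, also rejects keys such as `[:, 0, 0, ::-1]`, which would keep all 7 parameters but silently reorder them.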
