Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Minimal linear operator interface for PyTorch #130

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

f-dangel
Copy link
Owner

@f-dangel f-dangel commented Sep 21, 2024

Long-term, I want to add native PyTorch support for linear operators in curvlinops to address inefficiencies like #71, but also to clearly separate PyTorch from SciPy so that it will be easier to tackle features like supporting distributed settings.

This PR is a first step towards this goal.

From an API perspective, I plan to keep the constructor of all existing linear operators identical. The only backward-incompatible change will be that the produced linear operator will be purely PyTorch. To obtain the old behaviour one has to call .to_scipy() after the constructor.

Old: H = HessianLinearOperator(...)
Planned new: H = HessianLinearOperator(...).to_scipy()

The PR defines a linear operator interface in PyTorch which allows easy export to SciPy linear operators.
Importantly, the interface can multiply onto vectors/matrices represented by single Tensors, or a List[Tensor], which is more common in PyTorch. It verifies the input and output formats and all methods that need to be implemented assume the (more natural) tensor list format.

The next steps will be:

  • Define a base class CurvatureLinearOperator that replicates curvlinops._base._LinearOperator but inherits from our PyTorchLinearOperator, rather than scipy.sparse.linalg.LinearOperator.
  • Migrate each supported linear operator to inherit from CurvatureLinearOperator. I already tried that for the Hessian and was able to migrate without breaking the tests. I will set up a separate PR to keep the diffs manageable
  • Once all operators have been migrated (and probably we can get rid of a lot of boilerplate to check shapes, e.g. in KFAC), we can remove the current base class in curvlinops._base.

Let me know if this makes sense.

@coveralls
Copy link

coveralls commented Sep 21, 2024

Pull Request Test Coverage Report for Build 10968720506

Details

  • 57 of 77 (74.03%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-0.7%) to 88.295%

Changes Missing Coverage Covered Lines Changed/Added Lines %
curvlinops/_torch_base.py 57 77 74.03%
Totals Coverage Status
Change from base Build 10967050112: -0.7%
Covered Lines: 1403
Relevant Lines: 1589

💛 - Coveralls

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants