Unnecessary device transfer in `_postprocess` method (#71)
Chipping in to say that I am almost done writing a library that computes matrix-free SVD/eigh using random matrices, and this was one of my main issues with curvlinops: my library is PyTorch-"native" and does make use of the Hessian from curvlinops, but not of the SciPy routines, so that step was also redundant. A few thoughts:
Thoughts? Snippet:

```python
import torch


class TorchLinOpWrapper:
    """
    Since this library operates mainly with PyTorch tensors, but some useful
    LinOps interface with NumPy arrays instead, this mixin class acts as a
    wrapper on the ``__matmul__`` and ``__rmatmul__`` operators,
    so that the operator expects and returns torch tensors, even when the
    wrapped operator interfaces with NumPy. Usage example::

        # extend NumPy linear operator via multiple inheritance
        class TorchWrappedLinOp(TorchLinOpWrapper, LinOp):
            pass

        lop = TorchWrappedLinOp(...)  # instantiate normally
        w = lop @ v  # now v can be a PyTorch tensor
    """

    @staticmethod
    def _input_wrapper(x):
        """Convert a torch tensor to a NumPy array, remembering its device."""
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy(), x.device
        return x, None

    @staticmethod
    def _output_wrapper(x, torch_device=None):
        """Convert a NumPy result back to a torch tensor on the given device."""
        if torch_device is not None:
            return torch.from_numpy(x).to(torch_device)
        return x

    def __matmul__(self, x):
        """Forward product that accepts and returns torch tensors."""
        x, device = self._input_wrapper(x)
        return self._output_wrapper(super().__matmul__(x), device)

    def __rmatmul__(self, x):
        """Adjoint-side product that accepts and returns torch tensors."""
        x, device = self._input_wrapper(x)
        return self._output_wrapper(super().__rmatmul__(x), device)
```
This is a fair point, and I would be happy to offer a PyTorch-only mode of linear operators. I agree with @andres-fr that there should be a clean way to define a sub-class that works purely in PyTorch. One downside is that we will have to replicate most of the …
Actually, I didn't mention that option, but now that you mention it, it is clearly the best one: extend the linops. I think a good target would be to reach this backend-agnosticism with minimal code overhead. But we really want to avoid running into the "TensorFlow backend" problem, which leads to an extremely messy API. Just a couple of ideas to bounce around:
In the `_postprocess` method, the result of a matrix-vector product is always transferred to CPU. While this is consistent with the SciPy interface, in many use cases where we only operate with torch GPU tensors this is not desirable, as it creates unnecessary overhead.
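A minimal sketch of the kind of change this issue asks for. The function signature and the `return_torch` flag are hypothetical (not actual curvlinops API): the idea is simply that the postprocessing step only performs the device transfer and NumPy conversion when the SciPy interface requires it.

```python
import torch


def _postprocess(result: torch.Tensor, return_torch: bool = False):
    """Hypothetical sketch: keep the result on its device for torch-only
    callers; only move to CPU and convert to NumPy for the SciPy path."""
    if return_torch:
        return result  # no transfer: stays a torch tensor on its device
    return result.cpu().numpy()  # SciPy-compatible path (current behavior)


out_torch = _postprocess(torch.ones(2), return_torch=True)
out_numpy = _postprocess(torch.ones(2))
print(type(out_torch), type(out_numpy))
```

With such a flag, a torch-native consumer (like the randomized SVD/eigh library mentioned above) could skip both the CPU round-trip and the tensor/array conversion entirely.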