Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding jax backend #562

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
062c3bb
feat: started to add support for jax
mrava87 Dec 20, 2023
6a0eac4
feat: added jaxoperator
mrava87 Dec 21, 2023
372c39f
minor: fix backend to return jnp
mrava87 Dec 21, 2023
d3b3b66
minor: added ncp to fft2d
mrava87 Dec 21, 2023
780efa3
minor: added ncp to fftnd
mrava87 Dec 21, 2023
c2b7038
feat: added jax to backends
mrava87 Dec 21, 2023
9c008af
feat: add jax backend for convolve1d
mrava87 Dec 21, 2023
e1fa060
minor: remove prints from convolve1d
mrava87 Dec 21, 2023
68efee3
feat: remove cusignal since it is incorporated in cupy
mrava87 Dec 21, 2023
648c7fd
feature: continue integrating jax
mrava87 Mar 16, 2024
3093bd8
feat: adapted first derivative and vstack for jax
mrava87 Apr 14, 2024
208660d
feat: enabled jax for more basicoperators
mrava87 Jun 21, 2024
0ef312f
doc: added docstring and types to jaxoperator
mrava87 Jun 21, 2024
61b5807
feat: enabled jax for firstderivative
mrava87 Jun 21, 2024
cfb09e3
feat: enable jax in FirstDerivative and SecondDerivative
mrava87 Jun 23, 2024
597317f
minor: fix inconsistency in convolve1d
mrava87 Jun 23, 2024
a99ff5d
doc: added jax example in doc
mrava87 Jun 23, 2024
11c7e1d
feat: adapted nonstatconvolve1d to jax
mrava87 Jun 23, 2024
72c30d1
feat: enable Shift with jax arrays
mrava87 Jun 24, 2024
f5205ab
minor: add explicit and clinear to JaxOperator
mrava87 Jun 24, 2024
3d703d4
minor: fix FFT2D rmatvec for jax
mrava87 Jun 24, 2024
7604e0f
fix: fix get_block_diag from using cupy without checking if available
mrava87 Jun 24, 2024
a0813ec
feat: finalized jax integration of signalprocessing
mrava87 Jun 29, 2024
b52c537
feat: jax integration of avo
mrava87 Jun 29, 2024
83ea4c7
feat: jax integration of waveeqprocessing
mrava87 Jun 29, 2024
7a31c3b
doc: added jax tutorial
mrava87 Jul 1, 2024
4f2aabd
build: temporarely force numpy/scipy versions
mrava87 Jul 1, 2024
8bd76f1
doc: added jax tutorial
mrava87 Jul 2, 2024
aa458f5
test: added tests for JaxOperator
mrava87 Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pylops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .config import *
from .linearoperator import *
from .torchoperator import *
from .jaxoperator import *
from .basicoperators import *
from . import (
avo,
Expand Down
19 changes: 19 additions & 0 deletions pylops/jaxoperator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
__all__ = [
"JaxOperator",
]


import jax

from pylops import LinearOperator


class JaxOperator(LinearOperator):
def __init__(self, Op):
super().__init__(dtype=Op.dtype, dims=Op.dims, dimsd=Op.dimsd, name=Op.name)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you put a comma at the end it will be more readable

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what you mean...

self._matvec = jax.jit(Op._matvec)
self._rmatvec = jax.jit(Op._rmatvec)

def _rmatvecad(self, x, y):
mrava87 marked this conversation as resolved.
Show resolved Hide resolved
_, f_vjp = jax.vjp(self._matvec, x)
return jax.jit(f_vjp)(y)
19 changes: 11 additions & 8 deletions pylops/signalprocessing/convolve1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def _choose_convfunc(


def _pad_along_axis(array: np.ndarray, pad_size: tuple, axis: int = 0) -> np.ndarray:

ncp = get_array_module(array)
npad = [(0, 0)] * array.ndim
npad[axis] = pad_size
return np.pad(array, pad_width=npad)
return ncp.pad(array, pad_width=npad)


class _Convolve1Dshort(LinearOperator):
Expand All @@ -67,6 +67,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
ncp = get_array_module(h)
dims = _value_or_sized_to_tuple(dims)
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dims, name=name)
self.axis = axis
Expand All @@ -83,7 +84,7 @@ def __init__(
(max(self.offset, 0), -min(self.offset, 0)),
axis=-1 if h.ndim == 1 else axis,
)
self.hstar = np.flip(self.h, axis=-1)
self.hstar = ncp.flip(self.h, axis=-1)

# add dimensions to filter to match dimensions of model and data
if self.h.ndim == 1:
Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__(
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
ncp = get_array_module(h)
dims = _value_or_sized_to_tuple(dims)
dimsd = h.shape
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd, name=name)
Expand All @@ -140,13 +142,13 @@ def __init__(
self.offset = 2 * (self.dims[self.axis] // 2 - int(offset))
if self.dims[self.axis] % 2 == 0:
self.offset -= 1
self.hstar = np.flip(self.h, axis=-1)
self.hstar = ncp.flip(self.h, axis=-1)

self.pad = np.zeros((len(dims), 2), dtype=int)
self.pad = ncp.zeros((len(dims), 2), dtype=int)
self.pad[self.axis, 0] = max(self.offset, 0)
self.pad[self.axis, 1] = -min(self.offset, 0)

self.padd = np.zeros((len(dims), 2), dtype=int)
self.padd = ncp.zeros((len(dims), 2), dtype=int)
self.padd[self.axis, 1] = max(self.offset, 0)
self.padd[self.axis, 0] = -min(self.offset, 0)

Expand All @@ -162,12 +164,13 @@ def __init__(

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if type(self.h) is not type(x):
self.h = to_cupy_conditional(x, self.h)
self.convfunc, self.method = _choose_convfunc(
self.h, self.method, self.dims
)
x = np.pad(x, self.pad)
x = ncp.pad(x, self.pad)
y = self.convfunc(self.h, x, mode="same")
return y

Expand All @@ -179,7 +182,7 @@ def _rmatvec(self, x: NDArray) -> NDArray:
self.convfunc, self.method = _choose_convfunc(
self.hstar, self.method, self.dims
)
x = np.pad(x, self.padd)
x = ncp.pad(x, self.padd)
y = self.convfunc(self.hstar, x)
if self.dims[self.axis] % 2 == 0:
y = ncp.take(
Expand Down
39 changes: 21 additions & 18 deletions pylops/signalprocessing/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFT, _FFTNorms
from pylops.utils import deps
from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

Expand Down Expand Up @@ -63,50 +64,52 @@ def __init__(

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.ifftshift_before:
x = np.fft.ifftshift(x, axes=self.axis)
x = ncp.fft.ifftshift(x, axes=self.axis)
if not self.clinear:
x = np.real(x)
x = ncp.real(x)
if self.real:
y = np.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
y = ncp.fft.rfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
y = np.swapaxes(y, -1, self.axis)
y[..., 1 : 1 + (self.nfft - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axis, -1)
y = ncp.swapaxes(y, -1, self.axis)
y[..., 1 : 1 + (self.nfft - 1) // 2] *= ncp.sqrt(2)
y = ncp.swapaxes(y, self.axis, -1)
else:
y = np.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
y = ncp.fft.fft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
if self.fftshift_after:
y = np.fft.fftshift(y, axes=self.axis)
y = ncp.fft.fftshift(y, axes=self.axis)
y = y.astype(self.cdtype)
return y

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.fftshift_after:
x = np.fft.ifftshift(x, axes=self.axis)
x = ncp.fft.ifftshift(x, axes=self.axis)
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
x = np.swapaxes(x, -1, self.axis)
x[..., 1 : 1 + (self.nfft - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axis, -1)
y = np.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
x = ncp.swapaxes(x, -1, self.axis)
x[..., 1 : 1 + (self.nfft - 1) // 2] /= ncp.sqrt(2)
x = ncp.swapaxes(x, self.axis, -1)
y = ncp.fft.irfft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
else:
y = np.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
y = ncp.fft.ifft(x, n=self.nfft, axis=self.axis, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale

if self.nfft > self.dims[self.axis]:
y = np.take(y, range(0, self.dims[self.axis]), axis=self.axis)
y = ncp.take(y, range(0, self.dims[self.axis]), axis=self.axis)
elif self.nfft < self.dims[self.axis]:
y = np.pad(y, self.ifftpad)
y = ncp.pad(y, self.ifftpad)

if not self.clinear:
y = np.real(y)
y = ncp.real(y)
if self.ifftshift_before:
y = np.fft.fftshift(y, axes=self.axis)
y = ncp.fft.fftshift(y, axes=self.axis)
y = y.astype(self.rdtype)
return y

Expand Down
41 changes: 22 additions & 19 deletions pylops/signalprocessing/fft2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pylops import LinearOperator
from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
from pylops.utils.backend import get_array_module
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike

Expand Down Expand Up @@ -67,51 +68,53 @@ def __init__(

@reshaped
def _matvec(self, x):
ncp = get_array_module(x)
if self.ifftshift_before.any():
x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
if not self.clinear:
x = np.real(x)
x = ncp.real(x)
if self.real:
y = np.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.rfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
y = np.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axes[-1], -1)
y = ncp.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
y = ncp.swapaxes(y, self.axes[-1], -1)
else:
y = np.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.fft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
if self.fftshift_after.any():
y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after])
y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after])
return y

@reshaped
def _rmatvec(self, x):
ncp = get_array_module(x)
if self.fftshift_after.any():
x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
x = np.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axes[-1], -1)
y = np.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
x = ncp.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
x = ncp.swapaxes(x, self.axes[-1], -1)
y = ncp.fft.irfft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
y = np.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.ifft2(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
if self.nffts[0] > self.dims[self.axes[0]]:
y = np.take(y, range(self.dims[self.axes[0]]), axis=self.axes[0])
y = ncp.take(y, range(self.dims[self.axes[0]]), axis=self.axes[0])
if self.nffts[1] > self.dims[self.axes[1]]:
y = np.take(y, range(self.dims[self.axes[1]]), axis=self.axes[1])
y = ncp.take(y, range(self.dims[self.axes[1]]), axis=self.axes[1])
if self.doifftpad:
y = np.pad(y, self.ifftpad)
y = ncp.pad(y, self.ifftpad)
if not self.clinear:
y = np.real(y)
y = ncp.real(y)
y = y.astype(self.rdtype)
if self.ifftshift_before.any():
y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
return y

def __truediv__(self, y):
Expand Down
40 changes: 21 additions & 19 deletions pylops/signalprocessing/fftnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy.typing as npt

from pylops.signalprocessing._baseffts import _BaseFFTND, _FFTNorms
from pylops.utils.backend import get_sp_fft
from pylops.utils.backend import get_array_module, get_sp_fft
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

Expand Down Expand Up @@ -56,50 +56,52 @@ def __init__(

@reshaped
def _matvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.ifftshift_before.any():
x = np.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
x = ncp.fft.ifftshift(x, axes=self.axes[self.ifftshift_before])
if not self.clinear:
x = np.real(x)
x = ncp.real(x)
if self.real:
y = np.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.rfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
# Apply scaling to obtain a correct adjoint for this operator
y = np.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= np.sqrt(2)
y = np.swapaxes(y, self.axes[-1], -1)
y = ncp.swapaxes(y, -1, self.axes[-1])
y[..., 1 : 1 + (self.nffts[-1] - 1) // 2] *= ncp.sqrt(2)
y = ncp.swapaxes(y, self.axes[-1], -1)
else:
y = np.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.fftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.ONE_OVER_N:
y *= self._scale
y = y.astype(self.cdtype)
if self.fftshift_after.any():
y = np.fft.fftshift(y, axes=self.axes[self.fftshift_after])
y = ncp.fft.fftshift(y, axes=self.axes[self.fftshift_after])
return y

@reshaped
def _rmatvec(self, x: NDArray) -> NDArray:
ncp = get_array_module(x)
if self.fftshift_after.any():
x = np.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
x = ncp.fft.ifftshift(x, axes=self.axes[self.fftshift_after])
if self.real:
# Apply scaling to obtain a correct adjoint for this operator
x = x.copy()
x = np.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= np.sqrt(2)
x = np.swapaxes(x, self.axes[-1], -1)
y = np.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
x = ncp.swapaxes(x, -1, self.axes[-1])
x[..., 1 : 1 + (self.nffts[-1] - 1) // 2] /= ncp.sqrt(2)
x = ncp.swapaxes(x, self.axes[-1], -1)
y = ncp.fft.irfftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
else:
y = np.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
y = ncp.fft.ifftn(x, s=self.nffts, axes=self.axes, **self._norm_kwargs)
if self.norm is _FFTNorms.NONE:
y *= self._scale
for ax, nfft in zip(self.axes, self.nffts):
if nfft > self.dims[ax]:
y = np.take(y, range(self.dims[ax]), axis=ax)
y = ncp.take(y, range(self.dims[ax]), axis=ax)
if self.doifftpad:
y = np.pad(y, self.ifftpad)
y = ncp.pad(y, self.ifftpad)
if not self.clinear:
y = np.real(y)
y = ncp.real(y)
y = y.astype(self.rdtype)
if self.ifftshift_before.any():
y = np.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
y = ncp.fft.fftshift(y, axes=self.axes[self.ifftshift_before])
return y

def __truediv__(self, y: npt.ArrayLike) -> npt.ArrayLike:
Expand Down
Loading
Loading