Skip to content

Commit

Permalink
Add normalization options for FFT (#146)
Browse files Browse the repository at this point in the history
* Add normalization options for FFT

* Harmonize function signatures
  • Loading branch information
mmuckley authored Jul 20, 2021
1 parent 0da79ae commit c2432ff
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions fastmri/fftc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,58 +14,70 @@
import torch.fft # type: ignore


def fft2c_old(data: torch.Tensor) -> torch.Tensor:
def fft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
"""
Apply centered 2 dimensional Fast Fourier Transform.
Args:
data: Complex valued input data containing at least 3 dimensions:
dimensions -3 & -2 are spatial dimensions and dimension -1 has size
2. All other dimensions are assumed to be batch dimensions.
norm: Whether to include normalization. Must be one of ``"backward"``
or ``"ortho"``. See ``torch.fft.fft`` on PyTorch 1.9.0 for details.
Returns:
The FFT of the input.
"""
if not data.shape[-1] == 2:
raise ValueError("Tensor does not have separate complex dim.")
if norm not in ("ortho", "backward"):
raise ValueError("norm must be 'ortho' or 'backward'.")
normalized = True if norm == "ortho" else False

data = ifftshift(data, dim=[-3, -2])
data = torch.fft(data, 2, normalized=True)
data = torch.fft(data, 2, normalized=normalized)
data = fftshift(data, dim=[-3, -2])

return data


def ifft2c_old(data: torch.Tensor) -> torch.Tensor:
def ifft2c_old(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
"""
Apply centered 2-dimensional Inverse Fast Fourier Transform.
Args:
data: Complex valued input data containing at least 3 dimensions:
dimensions -3 & -2 are spatial dimensions and dimension -1 has size
2. All other dimensions are assumed to be batch dimensions.
norm: Whether to include normalization. Must be one of ``"backward"``
or ``"ortho"``. See ``torch.fft.ifft`` on PyTorch 1.9.0 for
details.
Returns:
The IFFT of the input.
"""
if not data.shape[-1] == 2:
raise ValueError("Tensor does not have separate complex dim.")
if norm not in ("ortho", "backward"):
raise ValueError("norm must be 'ortho' or 'backward'.")
normalized = True if norm == "ortho" else False

data = ifftshift(data, dim=[-3, -2])
data = torch.ifft(data, 2, normalized=True)
data = torch.ifft(data, 2, normalized=normalized)
data = fftshift(data, dim=[-3, -2])

return data


def fft2c_new(data: torch.Tensor) -> torch.Tensor:
def fft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
"""
Apply centered 2 dimensional Fast Fourier Transform.
Args:
data: Complex valued input data containing at least 3 dimensions:
dimensions -3 & -2 are spatial dimensions and dimension -1 has size
2. All other dimensions are assumed to be batch dimensions.
norm: Normalization mode. See ``torch.fft.fft``.
Returns:
The FFT of the input.
Expand All @@ -76,22 +88,23 @@ def fft2c_new(data: torch.Tensor) -> torch.Tensor:
data = ifftshift(data, dim=[-3, -2])
data = torch.view_as_real(
torch.fft.fftn( # type: ignore
torch.view_as_complex(data), dim=(-2, -1), norm="ortho"
torch.view_as_complex(data), dim=(-2, -1), norm=norm
)
)
data = fftshift(data, dim=[-3, -2])

return data


def ifft2c_new(data: torch.Tensor) -> torch.Tensor:
def ifft2c_new(data: torch.Tensor, norm: str = "ortho") -> torch.Tensor:
"""
Apply centered 2-dimensional Inverse Fast Fourier Transform.
Args:
data: Complex valued input data containing at least 3 dimensions:
dimensions -3 & -2 are spatial dimensions and dimension -1 has size
2. All other dimensions are assumed to be batch dimensions.
norm: Normalization mode. See ``torch.fft.ifft``.
Returns:
The IFFT of the input.
Expand All @@ -102,7 +115,7 @@ def ifft2c_new(data: torch.Tensor) -> torch.Tensor:
data = ifftshift(data, dim=[-3, -2])
data = torch.view_as_real(
torch.fft.ifftn( # type: ignore
torch.view_as_complex(data), dim=(-2, -1), norm="ortho"
torch.view_as_complex(data), dim=(-2, -1), norm=norm
)
)
data = fftshift(data, dim=[-3, -2])
Expand Down

0 comments on commit c2432ff

Please sign in to comment.