From e3a73564b9eb849ee1cd9686dbe6b3556c5a3510 Mon Sep 17 00:00:00 2001 From: GiannakopoulosIlias Date: Thu, 23 May 2024 09:53:40 -0400 Subject: [PATCH 1/6] add feature varnet --- LIST_OF_PAPERS.md | 22 + README.md | 3 + fastmri/models/feature_varnet.py | 1575 +++++++++++++++++ fastmri_examples/README.md | 1 + fastmri_examples/feature_varnet/README.md | 72 + .../feature_varnet/pl_modules/__init__.py | 9 + .../pl_modules/feature_varnet_module.py | 158 ++ .../feature_varnet/requirements.txt | 11 + .../feature_varnet/train_feature_varnet.py | 270 +++ 9 files changed, 2121 insertions(+) create mode 100644 fastmri/models/feature_varnet.py create mode 100644 fastmri_examples/feature_varnet/README.md create mode 100644 fastmri_examples/feature_varnet/pl_modules/__init__.py create mode 100644 fastmri_examples/feature_varnet/pl_modules/feature_varnet_module.py create mode 100644 fastmri_examples/feature_varnet/requirements.txt create mode 100644 fastmri_examples/feature_varnet/train_feature_varnet.py diff --git a/LIST_OF_PAPERS.md b/LIST_OF_PAPERS.md index d0e1e31d..28f7117d 100644 --- a/LIST_OF_PAPERS.md +++ b/LIST_OF_PAPERS.md @@ -17,6 +17,7 @@ The following is a short list of fastMRI publications. Clicking on the title wil 13. Bakker, T., Muckley, M.J., Romero-Soriano, A., Drozdzal, M. & Pineda, L. (2022). [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction](#on-learning-adaptive-acquisition-policies-for-undersampled-multi-coil-mri-reconstruction). In * *International Conference on Medical Imaging with Deep Learning*, pages 63-85. 14. Radmanesh, A.\*, Muckley, M. J.\*, Murrell, T., Lindsey, E., Sriram, A., Knoll, F., ... & Lui, Y. W. (2022). [Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI](#exploring-the-acceleration-limits-of-deep-learning-varnet-based-two-dimensional-brain-mri). *Radiology: Artificial Intelligence*, 4(6), page e210313. 15. Johnson, Patricia M., Lin, D. J., Zbontar, J., Zitnick, C. L., Sriram, A., Mucklye, M., ..., & Knoll, F. (2023). [Deep learning reconstruction enables prospectively accelerated clinical knee MRI](#deep-learning-reconstruction-enables-prospectively-accelerated-clinical-knee-mri) *Radiology*, page 220425. +16. Giannakopoulos, I. I., Muckley, M. J., Kim, J., Breen, M., Johnson, P. M., Lui, Y. W., & Lattanzi, R. (2024). [Accelerated MRI reconstructions via variational network and feature domain learning](#accelerated-mri-reconstructions-via-variational-network-and-feature-domain-learning) *Scientific Reports*, 14(1), 10991. ## fastMRI: An open dataset and benchmarks for accelerated MRI @@ -355,3 +356,24 @@ In a clinical setting, deep learning reconstruction enabled a nearly twofold red doi = {10.1148/radiol.220425}, } ``` + +## Accelerated MRI reconstructions via variational network and feature domain learning + +[Publication](https://doi.org/10.1038/s41598-024-59705-0) [Code](https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples/feature_varnet) + +**Abstract** + +We introduce three architecture modifications to enhance the performance of the end-to-end (E2E) variational network (VarNet) for undersampled MRI reconstructions. We first implemented the Feature VarNet, which propagates information throughout the cascades of the network in an N-channel feature-space instead of a 2-channel feature-space. Then, we add an attention layer that utilizes the spatial locations of Cartesian undersampling artifacts to further improve performance. Lastly, we combined the Feature and E2E VarNets into the Feature-Image (FI) VarNet, to facilitate cross-domain learning and boost accuracy. Reconstructions were evaluated on the fastMRI dataset using standard metrics and clinical scoring by three neuroradiologists. Feature and FI VarNets outperformed the E2E VarNet for 4, 5 and 8 Cartesian undersampling in all studied metrics. FI VarNet secured second place in the public fastMRI leaderboard for 4 Cartesian undersampling, outperforming all open-source models in the leaderboard. Radiologists rated FI VarNet brain reconstructions with higher quality and sharpness than the E2E VarNet reconstructions. FI VarNet excelled in preserving anatomical details, including blood vessels, whereas E2E VarNet discarded or blurred them in some cases. The proposed FI VarNet enhances the reconstruction quality of undersampled MRI and could enable clinically acceptable reconstructions at higher acceleration factors than currently possible. + +```BibTeX +@article{giannakopoulos2024accelerated, + title={Accelerated MRI reconstructions via variational network and feature domain learning}, + author={Giannakopoulos, Ilias I and Muckley, Matthew J and Kim, Jesi and Breen, Matthew and Johnson, Patricia M and Lui, Yvonne W and Lattanzi, Riccardo}, + journal={Scientific Reports}, + volume={14}, + number={1}, + pages={10991}, + year={2024}, + publisher={Nature Publishing Group UK London} +} +``` diff --git a/README.md b/README.md index 55e9840f..2734c01e 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ in another repository. * [End-to-End Variational Networks for Accelerated MRI Reconstruction ({A. Sriram*, J. Zbontar*} et al., 2020)](https://github.com/facebookresearch/fastMRI/tree/master/fastmri_examples/varnet/) * [MRI Banding Removal via Adversarial Training (A. Defazio, et al., 2020)](https://github.com/facebookresearch/fastMRI/tree/master/banding_removal) * [Deep Learning Reconstruction Enables Prospectively Accelerated Clinical Knee MRI (P. Johnson et al., 2023)](https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples/RadiologyJohnson2022) + * [Accelerated MRI reconstructions via variational network and feature domain learning (I. Giannakopoulos et al., 2024)](https://github.com/facebookresearch/fastMRI/tree/main/fastmri_examples/feature_varnet) * **Active Acquisition** * (external repository) [Reducing uncertainty in undersampled MRI reconstruction with active acquisition (Z. Zhang et al., 2019)](https://github.com/facebookresearch/active-mri-acquisition/tree/master/activemri/experimental/cvpr19_models) @@ -212,3 +213,5 @@ corresponding abstracts, as well as links to preprints and code can be found 14. Radmanesh, A.\*, Muckley, M. J.\*, Murrell, T., Lindsey, E., Sriram, A., Knoll, F., ... & Lui, Y. W. (2022). [Exploring the Acceleration Limits of Deep Learning VarNet-based Two-dimensional Brain MRI](https://doi.org/10.1148/ryai.210313). *Radiology: Artificial Intelligence*, 4(6), page e210313. 15. Johnson, P.M., Lin, D.J., Zbontar, J., Zitnick, C.L., Sriram, A., Muckley, M., Babb, J.S., Kline, M., Ciavarra, G., Alaia, E., ..., & Knoll, F. (2023). [Deep Learning Reconstruction Enables Prospectively Accelerated Clinical Knee MRI](https://doi.org/10.1148/radiol.220425). *Radiology*, 307(2), page e220425. 16. Tibrewala, R., Dutt, T., Tong, A., Ginocchio, L., Keerthivasan, M.B., Baete, S.H., Lui, Y.W., Sodickson, D.K., Chandarana, H., Johnson, P.M. (2023). [FastMRI Prostate: A Publicly Available, Biparametric MRI Dataset to Advance Machine Learning for Prostate Cancer Imaging](https://arxiv.org/abs/2304.09254). *arXiv preprint, arXiv:2034.09254*. +16. Giannakopoulos, I. I., Muckley, M. J., Kim, J., Breen, M., Johnson, P. M., Lui, Y. W., Lattanzi, R. (2024). [Accelerated MRI reconstructions via variational network and feature domain learning. Scientific Reports](https://www.nature.com/articles/s41598-024-59705-0). *Scientific Reports, 14(1), 10991*. + diff --git a/fastmri/models/feature_varnet.py b/fastmri/models/feature_varnet.py new file mode 100644 index 00000000..a325ea6e --- /dev/null +++ b/fastmri/models/feature_varnet.py @@ -0,0 +1,1575 @@ +from typing import NamedTuple, Optional, Tuple, List +import math +import torch +import torch.nn as nn +from torch import Tensor +torch.set_float32_matmul_precision('high') +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np +import math +from fastmri.data.transforms import center_crop, batched_mask_center +from fastmri.fftc import ifft2c_new as ifft2c +from fastmri.fftc import fft2c_new as fft2c +from fastmri.coil_combine import rss_complex, rss +from fastmri.math import complex_abs, complex_mul, complex_conj + +def image_crop(image: Tensor, crop_size: Optional[Tuple[int, int]] = None) -> Tensor: + if crop_size is None: + return image + return center_crop(image, crop_size).contiguous() + +def _calc_uncrop(crop_height: int, in_height: int) -> Tuple[int, int]: + pad_height = (in_height - crop_height) // 2 + if (in_height - crop_height) % 2 != 0: + pad_height_top = pad_height + 1 + else: + pad_height_top = pad_height + + pad_height = in_height - pad_height + + return pad_height_top, pad_height + +def image_uncrop(image: Tensor, original_image: Tensor) -> Tensor: + """Insert values back into original image.""" + in_shape = original_image.shape + original_image = original_image.clone() + + if in_shape == image.shape: + return image + + pad_height_top, pad_height = _calc_uncrop(image.shape[-2], in_shape[-2]) + pad_height_left, pad_width = _calc_uncrop(image.shape[-1], in_shape[-1]) + + try: + original_image[ + ..., pad_height_top:pad_height, pad_height_left:pad_width + ] = image[...] + except RuntimeError: + print(f"in_shape: {in_shape}, image shape: {image.shape}") + raise + + return original_image + +def norm_fn(image: Tensor, means: Tensor, variances: Tensor) -> Tensor: + means = means.view(1, -1, 1, 1) + variances = variances.view(1, -1, 1, 1) + return (image - means) * torch.rsqrt(variances) + +def unnorm_fn(image: Tensor, means: Tensor, variances: Tensor) -> Tensor: + means = means.view(1, -1, 1, 1) + variances = variances.view(1, -1, 1, 1) + return image * torch.sqrt(variances) + means + +def complex_to_chan_dim(x: Tensor) -> Tensor: + b, c, h, w, two = x.shape + assert two == 2 + assert c == 1 + return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w) + +def chan_complex_to_last_dim(x: Tensor) -> Tensor: + b, c2, h, w = x.shape + assert c2 == 2 + c = c2 // 2 + return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() + +def sens_expand(x: Tensor, sens_maps: Tensor) -> Tensor: + return fft2c(complex_mul(chan_complex_to_last_dim(x), sens_maps)) + +def sens_reduce(x: Tensor, sens_maps: Tensor) -> Tensor: + return complex_to_chan_dim( + complex_mul(ifft2c(x), complex_conj(sens_maps)).sum( + dim=1, keepdim=True + ) + ) + +class NormStats(nn.Module): + def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]: + # group norm + batch, chans, _, _ = data.shape + + if batch != 1: + raise ValueError("Unexpected input dimensions.") + + data = data.view(chans, -1) + + mean = data.mean(dim=1) + variance = data.var(dim=1, unbiased=False) + + assert mean.ndim == 1 + assert variance.ndim == 1 + assert mean.shape[0] == chans + assert variance.shape[0] == chans + + return mean, variance + +class RunningChannelStats(nn.Module): + def __init__(self, chans: int, eps: float = 1e-14, freeze_step: int = 20000): + super().__init__() + + self.means: Tensor + self.vars: Tensor + self.current_step: Tensor + self.eps = eps + self.chans = chans + self.freeze_step = freeze_step + + self.register_buffer("current_step", torch.zeros(1, dtype=torch.int)) + self.register_buffer("means", torch.zeros(chans)) + self.register_buffer("vars", torch.zeros(chans)) + + def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + if image.shape[1] != self.chans: + raise ValueError("Invalid channel number.") + + if self.current_step < self.freeze_step and self.training: + stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1) + mean = stats.mean(1) + var = stats.var(1, unbiased=True) + + var = var / dist.get_world_size() + self.means.copy_(self.means + (mean - self.means) / (self.current_step + 1)) + self.vars.copy_(self.vars + (var - self.vars) / (self.current_step + 1)) + + self.current_step += 1 + + if self.current_step == 0 and not self.training: + stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1) + run_mean = stats.mean(1).view(1, -1, 1, 1) + run_var = (stats.var(1, unbiased=True) + self.eps).view(1, -1, 1, 1) + else: + run_mean = self.means.clone().view(1, -1, 1, 1) + run_var = self.vars.clone().view(1, -1, 1, 1) + self.eps + + return run_mean, run_var + +class FeatureImage(NamedTuple): + features: Tensor + sens_maps: Tensor = None + crop_size: Optional[Tuple[int, int]] = None + means: Tensor = None + variances: Tensor = None + mask: Tensor = None + ref_kspace: Tensor = None + beta: Optional[Tensor] = None + gamma: Optional[Tensor] = None + +class FeatureEncoder(nn.Module): + def __init__(self, in_chans: int, feature_chans: int = 32, drop_prob: float = 0.0): + super().__init__() + self.feature_chans = feature_chans + + self.encoder = nn.Sequential( + nn.Conv2d( + in_channels=in_chans, + out_channels=feature_chans, + kernel_size=5, + padding=2, + bias=True, + ), + ) + + def forward(self, image: Tensor, means: Tensor, variances: Tensor) -> Tensor: + means = means.view(1, -1, 1, 1) + variances = variances.view(1, -1, 1, 1) + return self.encoder((image - means) * torch.rsqrt(variances)) + +class FeatureDecoder(nn.Module): + def __init__(self, feature_chans: int = 32, out_chans: int = 2): + super().__init__() + self.feature_chans = feature_chans + + self.decoder = nn.Conv2d( + in_channels=feature_chans, + out_channels=out_chans, + kernel_size=5, + padding=2, + bias=True, + ) + + def forward(self, features: Tensor, means: Tensor, variances: Tensor) -> Tensor: + means = means.view(1, -1, 1, 1) + variances = variances.view(1, -1, 1, 1) + return self.decoder(features) * torch.sqrt(variances) + means + +class AttentionPE(nn.Module): + def __init__(self, in_chans: int): + super().__init__() + self.in_chans = in_chans + + self.norm = nn.InstanceNorm2d(in_chans) + self.q = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) + self.dilated_conv = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=2, dilation=2) + + def reshape_to_blocks(self, x: Tensor, accel: int) -> Tensor: + chans = x.shape[1] + pad_total = (accel - (x.shape[3] - accel)) % accel + pad_right = pad_total // 2 + pad_left = pad_total - pad_right + x = F.pad(x, (pad_left, pad_right, 0, 0), "reflect") + return (torch.stack(x.chunk(chunks=accel, dim=3), dim=-1).view(chans, -1, accel).permute(1, 0, 2).contiguous()) + + def reshape_from_blocks(self, x: Tensor, image_size: Tuple[int, int], accel: int) -> Tensor: + chans = x.shape[1] + num_freq, num_phase = image_size + x = (x.permute(1, 0, 2).reshape(1, chans, num_freq, -1, accel).permute(0, 1, 2, 4, 3).reshape(1, chans, num_freq, -1)) + padded_phase = x.shape[3] + pad_total = padded_phase - num_phase + pad_right = pad_total // 2 + pad_left = pad_total - pad_right + return x[:, :, :, pad_left : padded_phase - pad_right] + + def get_positional_encodings(self, seq_len: int, embed_dim: int, device: str) -> Tensor: + freqs = torch.tensor([1 / (10000 ** (2 * (i // 2) / embed_dim)) for i in range(embed_dim)], device=device) + freqs = freqs.unsqueeze(0) + positions = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1) + scaled = positions * freqs + sin_encodings = torch.sin(scaled) + cos_encodings = torch.cos(scaled) + encodings = torch.cat([sin_encodings, cos_encodings], dim=1)[:,:embed_dim] + return encodings + + def forward(self, x: Tensor, accel: int) -> Tensor: + im_size = (x.shape[2], x.shape[3]) + h_ = x + h_ = self.norm(h_) + + pos_enc = self.get_positional_encodings(x.shape[2], x.shape[3], h_.device) + + h_ = h_ + pos_enc + + q = self.dilated_conv(self.q(h_)) + k = self.dilated_conv(self.k(h_)) + v = self.dilated_conv(self.v(h_)) + + # compute attention + c = q.shape[1] + q = self.reshape_to_blocks(q, accel) + k = self.reshape_to_blocks(k, accel) + q = q.permute(0, 2, 1) # b,hw,c + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = self.reshape_to_blocks(v, accel) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = self.reshape_from_blocks(h_, im_size, accel) + + h_ = self.proj_out(h_) + + return x + h_ + +class Unet(nn.Module): + """ + PyTorch implementation of a U-Net model. + + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. + """ + + def __init__( + self, + in_chans: int, + out_chans: int, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + ): + """ + Args: + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. + """ + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.chans = chans + self.num_pool_layers = num_pool_layers + self.drop_prob = drop_prob + + self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) + ch = chans + for _ in range(num_pool_layers - 1): + self.down_sample_layers.append(ConvBlock(ch, ch * 2, drop_prob)) + ch *= 2 + self.conv = ConvBlock(ch, ch * 2, drop_prob) + + self.up_conv = nn.ModuleList() + self.up_transpose_conv = nn.ModuleList() + for _ in range(num_pool_layers - 1): + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append(ConvBlock(ch * 2, ch, drop_prob)) + ch //= 2 + + self.up_transpose_conv.append(TransposeConvBlock(ch * 2, ch)) + self.up_conv.append( + nn.Sequential( + ConvBlock(ch * 2, ch, drop_prob), + nn.Conv2d(ch, self.out_chans, kernel_size=1, stride=1), + ) + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + """ + Args: + image: Input 4D tensor of shape `(N, in_chans, H, W)`. + + Returns: + Output tensor of shape `(N, out_chans, H, W)`. + """ + stack = [] + output = image + + # apply down-sampling layers + for layer in self.down_sample_layers: + output = layer(output) + stack.append(output) + output = F.avg_pool2d(output, kernel_size=2, stride=2, padding=0) + + output = self.conv(output) + + # apply up-sampling layers + for transpose_conv, conv in zip(self.up_transpose_conv, self.up_conv): + downsample_layer = stack.pop() + output = transpose_conv(output) + + # reflect pad on the right/botton if needed to handle odd input dimensions + padding = [0, 0, 0, 0] + if output.shape[-1] != downsample_layer.shape[-1]: + padding[1] = 1 # padding right + if output.shape[-2] != downsample_layer.shape[-2]: + padding[3] = 1 # padding bottom + if torch.sum(torch.tensor(padding)) != 0: + output = F.pad(output, padding, "reflect") + + output = torch.cat([output, downsample_layer], dim=1) + output = conv(output) + + return output + +class Unet2d(nn.Module): + def __init__( + self, + in_chans: int, + out_chans: int, + chans: int = 32, + num_pool_layers: int = 4, + drop_prob: float = 0.0, + output_bias: bool = False, + ): + super().__init__() + self.in_chans = in_chans + self.out_planes = out_chans + self.factor = 2**num_pool_layers + + # Build from the middle of the UNet outwards + planes = 2 ** (num_pool_layers) + layer = None + for _ in range(num_pool_layers): + planes = planes // 2 + layer = UnetLevel( + layer, + in_planes=planes * chans, + out_planes=2 * planes * chans, + drop_prob=drop_prob, + ) + + self.layer = UnetLevel( + layer, in_planes=in_chans, out_planes=chans, drop_prob=drop_prob + ) + + if output_bias: + self.final_conv = nn.Conv2d( + in_channels=chans, + out_channels=out_chans, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) + else: + self.final_conv = nn.Sequential( + nn.Conv2d( + in_channels=chans, + out_channels=out_chans, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def pad_input_image(self, image: Tensor) -> Tuple[Tensor, Tuple[int, int]]: + # pad image if it's not divisible by downsamples + _, _, height, width = image.shape + pad_height = (self.factor - (height - self.factor)) % self.factor + pad_width = (self.factor - (width - self.factor)) % self.factor + if pad_height != 0 or pad_width != 0: + image = F.pad(image, (0, pad_width, 0, pad_height), mode="reflect") + + return image, (height, width) + + def forward(self, image: Tensor) -> Tensor: + image, (output_y, output_x) = self.pad_input_image(image) + return self.final_conv(self.layer(image))[:, :, :output_y, :output_x] + +class UnetLevel(nn.Module): + def __init__( + self, + child: Optional[nn.Module], + in_planes: int, + out_planes: int, + drop_prob: float = 0.0, + ): + super().__init__() + self.in_planes = in_planes + self.out_planes = out_planes + + self.left_block = ConvBlock( + in_chans=in_planes, out_chans=out_planes, drop_prob=drop_prob + ) + + self.child = child + + if child is not None: + self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) + self.upsample = TransposeConvBlock( + in_chans=child.out_planes, out_chans=out_planes + ) + self.right_block = ConvBlock( + in_chans=2 * out_planes, out_chans=out_planes, drop_prob=drop_prob + ) + + def down_up(self, image: Tensor) -> Tensor: + return self.upsample(self.child(self.downsample(image))) + + def forward(self, image: Tensor) -> Tensor: + image = self.left_block(image) + + if self.child is not None: + image = self.right_block(torch.cat((image, self.down_up(image)), 1)) + + return image + +class ConvBlock(nn.Module): + def __init__(self, in_chans: int, out_chans: int, drop_prob: float): + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + self.drop_prob = drop_prob + + self.layers = nn.Sequential( + nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Dropout2d(drop_prob), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + +class TransposeConvBlock(nn.Module): + def __init__(self, in_chans: int, out_chans: int): + super().__init__() + + self.in_chans = in_chans + self.out_chans = out_chans + + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_chans, out_chans, kernel_size=2, stride=2, bias=False + ), + nn.InstanceNorm2d(out_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + def forward(self, image: torch.Tensor) -> torch.Tensor: + return self.layers(image) + +class NormUnet(nn.Module): + + def __init__( + self, + chans: int, + num_pools: int, + in_chans: int = 2, + out_chans: int = 2, + drop_prob: float = 0.0, + ): + + super().__init__() + + self.unet = Unet( + in_chans=in_chans, + out_chans=out_chans, + chans=chans, + num_pool_layers=num_pools, + drop_prob=drop_prob, + ) + + def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w, two = x.shape + assert two == 2 + return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w) + + def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: + b, c2, h, w = x.shape + assert c2 % 2 == 0 + c = c2 // 2 + return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() + + def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # group norm + b, c, h, w = x.shape + x = x.view(b, c, h * w) + + mean = x.mean(dim=2).view(b, c, 1, 1) + std = x.std(dim=2).view(b, c, 1, 1) + + x = x.view(b, c, h, w) + + return (x - mean) / std, mean, std + + def unnorm( + self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor + ) -> torch.Tensor: + return x * std + mean + + def pad( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: + _, _, h, w = x.shape + w_mult = ((w - 1) | 15) + 1 + h_mult = ((h - 1) | 15) + 1 + w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] + h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] + x = F.pad(x, w_pad + h_pad) + + return x, (h_pad, w_pad, h_mult, w_mult) + + def unpad( + self, + x: torch.Tensor, + h_pad: List[int], + w_pad: List[int], + h_mult: int, + w_mult: int, + ) -> torch.Tensor: + return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.shape[-1] == 2: + raise ValueError("Last dimension must be 2 for complex.") + + # get shapes for unet and normalize + x = self.complex_to_chan_dim(x) + x, mean, std = self.norm(x) + x, pad_sizes = self.pad(x) + + #attention_goes_here + x = self.unet(x) + + # get shapes back and unnormalize + x = self.unpad(x, *pad_sizes) + x = self.unnorm(x, mean, std) + x = self.chan_complex_to_last_dim(x) + + return x + +class Norm1DUnet(nn.Module): + + def __init__( + self, + chans: int, + num_pools: int, + in_chans: int = 2, + out_chans: int = 2, + drop_prob: float = 0.0, + ): + + super().__init__() + + self.unet = Unet( + in_chans=in_chans, + out_chans=out_chans, + chans=chans, + num_pool_layers=num_pools, + drop_prob=drop_prob, + ) + + def complex_to_chan_dim(self, x: torch.Tensor) -> torch.Tensor: + b, c, h, w, two = x.shape + assert two == 2 + return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w) + + def chan_complex_to_last_dim(self, x: torch.Tensor) -> torch.Tensor: + b, c2, h, w = x.shape + assert c2 % 2 == 0 + c = c2 // 2 + return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() + + def norm(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # group norm + b, c, h, w = x.shape + x = x.view(b, c, h * w) + + mean = x.mean() + std = x.std() + + x = x.view(b, c, h, w) + + return (x - mean) / std, mean, std + + def unnorm( + self, x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor + ) -> torch.Tensor: + return x * std + mean + + def pad( + self, x: torch.Tensor + ) -> Tuple[torch.Tensor, Tuple[List[int], List[int], int, int]]: + _, _, h, w = x.shape + w_mult = ((w - 1) | 15) + 1 + h_mult = ((h - 1) | 15) + 1 + w_pad = [math.floor((w_mult - w) / 2), math.ceil((w_mult - w) / 2)] + h_pad = [math.floor((h_mult - h) / 2), math.ceil((h_mult - h) / 2)] + x = F.pad(x, w_pad + h_pad) + + return x, (h_pad, w_pad, h_mult, w_mult) + + def unpad( + self, + x: torch.Tensor, + h_pad: List[int], + w_pad: List[int], + h_mult: int, + w_mult: int, + ) -> torch.Tensor: + return x[..., h_pad[0] : h_mult - h_pad[1], w_pad[0] : w_mult - w_pad[1]] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if not x.shape[-1] == 2: + raise ValueError("Last dimension must be 2 for complex.") + + # get shapes for unet and normalize + x = self.complex_to_chan_dim(x) + x, mean, std = self.norm(x) + x, pad_sizes = self.pad(x) + + #attention_goes_here + x = self.unet(x) + + # get shapes back and unnormalize + x = self.unpad(x, *pad_sizes) + x = self.unnorm(x, mean, std) + x = self.chan_complex_to_last_dim(x) + + return x + +class SensitivityModel(nn.Module): + """ + Model for learning sensitivity estimation from k-space data. + + This model applies an IFFT to multichannel k-space data and then a U-Net + to the coil images to estimate coil sensitivities. It can be used with the + end-to-end variational network. + """ + + def __init__( + self, + chans: int, + num_pools: int, + in_chans: int = 2, + out_chans: int = 2, + drop_prob: float = 0.0, + mask_center: bool = True, + ): + """ + Args: + chans: Number of output channels of the first convolution layer. + num_pools: Number of down-sampling and up-sampling layers. + in_chans: Number of channels in the input to the U-Net model. + out_chans: Number of channels in the output to the U-Net model. + drop_prob: Dropout probability. + mask_center: Whether to mask center of k-space for sensitivity map + calculation. + """ + super().__init__() + self.mask_center = mask_center + self.norm_unet = NormUnet( + chans, + num_pools, + in_chans=in_chans, + out_chans=out_chans, + drop_prob=drop_prob, + ) + + def chans_to_batch_dim(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]: + b, c, h, w, comp = x.shape + + return x.view(b * c, 1, h, w, comp), b + + def batch_chans_to_chan_dim(self, x: torch.Tensor, batch_size: int) -> torch.Tensor: + bc, _, h, w, comp = x.shape + c = bc // batch_size + + return x.view(batch_size, c, h, w, comp) + + def divide_root_sum_of_squares(self, x: torch.Tensor) -> torch.Tensor: + return x / rss_complex(x, dim=1).unsqueeze(-1).unsqueeze(1) + + def get_pad_and_num_low_freqs( + self, mask: torch.Tensor, num_low_frequencies: Optional[int] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if num_low_frequencies is None or num_low_frequencies == 0: + # get low frequency line locations and mask them out + squeezed_mask = mask[:, 0, 0, :, 0].to(torch.int8) + cent = squeezed_mask.shape[1] // 2 + # running argmin returns the first non-zero + left = torch.argmin(squeezed_mask[:, :cent].flip(1), dim=1) + right = torch.argmin(squeezed_mask[:, cent:], dim=1) + num_low_frequencies_tensor = torch.max( + 2 * torch.min(left, right), torch.ones_like(left) + ) # force a symmetric center unless 1 + else: + num_low_frequencies_tensor = num_low_frequencies * torch.ones( + mask.shape[0], dtype=mask.dtype, device=mask.device + ) + + pad = torch.div(mask.shape[-2] - num_low_frequencies_tensor + 1,2,rounding_mode='trunc') + + return pad, num_low_frequencies_tensor + + def forward( + self, + masked_kspace: torch.Tensor, + mask: torch.Tensor, + num_low_frequencies: Optional[int] = None, + ) -> torch.Tensor: + if self.mask_center: + pad, num_low_freqs = self.get_pad_and_num_low_freqs( + mask, num_low_frequencies + ) + masked_kspace = batched_mask_center( + masked_kspace, pad, pad + num_low_freqs + ) + + # convert to image space + images, batches = self.chans_to_batch_dim(ifft2c(masked_kspace)) + + # estimate sensitivities + return self.divide_root_sum_of_squares( + self.batch_chans_to_chan_dim(self.norm_unet(images), batches) + ) + +class FIVarNet(nn.Module): + def __init__( + self, + num_cascades: int = 12, + sens_chans: int = 8, + sens_pools: int = 4, + chans: int = 18, + pools: int = 4, + acceleration: int = 4, + mask_center: bool = True, + image_conv_cascades: Optional[List[int]] = None, + kspace_mult_factor: float = 1e6, + ): + super().__init__() + if image_conv_cascades is None: + image_conv_cascades = [ind for ind in range(num_cascades) if ind % 3 == 0] + + self.image_conv_cascades = image_conv_cascades + self.kspace_mult_factor = kspace_mult_factor + self.sens_net = SensitivityModel( + chans=sens_chans, + num_pools=sens_pools, + mask_center=mask_center, + ) + self.encoder = FeatureEncoder(in_chans=2, feature_chans=chans) + self.decoder = FeatureDecoder(feature_chans=chans, out_chans=2) + cascades = [] + for ind in range(num_cascades): + use_image_conv = ind in self.image_conv_cascades + cascades.append( + AttentionFeatureVarNetBlock( + encoder=self.encoder, + decoder=self.decoder, + acceleration=acceleration, + feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + attention_layer=AttentionPE(in_chans=chans), + use_extra_feature_conv=use_image_conv, + ) + ) + + self.image_cascades = nn.ModuleList( + [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)] + ) + + self.decode_norm = nn.InstanceNorm2d(chans) + self.cascades = nn.Sequential(*cascades) + self.norm_fn = NormStats() + + def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + image = self.decoder( + self.decode_norm(feature_image.features), + means=feature_image.means, + variances=feature_image.variances, + ) + return sens_expand(image, feature_image.sens_maps) + + def _encode_input( + self, + masked_kspace: Tensor, + mask: Tensor, + crop_size: Optional[Tuple[int, int]], + num_low_frequencies: Optional[int], + ) -> FeatureImage: + sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) + image = sens_reduce(masked_kspace, sens_maps) + # detect FLAIR 203 + if image.shape[-1] < crop_size[1]: + crop_size = (image.shape[-1], image.shape[-1]) + means, variances = self.norm_fn(image) + features = self.encoder(image, means=means, variances=variances) + + return FeatureImage( + features=features, + sens_maps=sens_maps, + crop_size=crop_size, + means=means, + variances=variances, + ref_kspace=masked_kspace, + mask=mask, + ) + + def forward( + self, + masked_kspace: Tensor, + mask: Tensor, + num_low_frequencies: Optional[int] = None, + crop_size: Optional[Tuple[int, int]] = None, + ) -> Tensor: + masked_kspace = masked_kspace * self.kspace_mult_factor + # Encode to features and get sensitivities + feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + # Do DC in feature-space + feature_image = self.cascades(feature_image) + # Find last k-space + kspace_pred = self._decode_output(feature_image) + # Run E2EVN + for cascade in self.image_cascades: + kspace_pred = cascade(kspace_pred, feature_image.ref_kspace, mask, feature_image.sens_maps) + # Return Final Image + kspace_pred = kspace_pred / self.kspace_mult_factor + return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + +class IFVarNet(nn.Module): + def __init__( + self, + num_cascades: int = 12, + sens_chans: int = 8, + sens_pools: int = 4, + chans: int = 18, + pools: int = 4, + acceleration: int = 4, + mask_center: bool = True, + image_conv_cascades: Optional[List[int]] = None, + kspace_mult_factor: float = 1e6, + ): + super().__init__() + if image_conv_cascades is None: + image_conv_cascades = [ind for ind in range(num_cascades) if ind % 3 == 0] + + self.image_conv_cascades = image_conv_cascades + self.kspace_mult_factor = kspace_mult_factor + self.sens_net = SensitivityModel( + chans=sens_chans, + num_pools=sens_pools, + mask_center=mask_center, + ) + self.encoder = FeatureEncoder(in_chans=2, feature_chans=chans) + self.decoder = FeatureDecoder(feature_chans=chans, out_chans=2) + cascades = [] + for ind in range(num_cascades): + use_image_conv = ind in self.image_conv_cascades + cascades.append( + AttentionFeatureVarNetBlock( + encoder=self.encoder, + decoder=self.decoder, + acceleration=acceleration, + feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + attention_layer=AttentionPE(in_chans=chans), + use_extra_feature_conv=use_image_conv, + ) + ) + + self.image_cascades = nn.ModuleList( + [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)] + ) + + self.decode_norm = nn.InstanceNorm2d(chans) + self.cascades = nn.Sequential(*cascades) + self.norm_fn = NormStats() + + def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + image = self.decoder( + self.decode_norm(feature_image.features), + means=feature_image.means, + variances=feature_image.variances, + ) + return sens_expand(image, feature_image.sens_maps) + + def _encode_input( + self, + masked_kspace: Tensor, + ref_kspace: Tensor, + sens_maps: Tensor, + mask: Tensor, + crop_size: Optional[Tuple[int, int]], + ) -> FeatureImage: + image = sens_reduce(masked_kspace, sens_maps) + # detect FLAIR 203 + if image.shape[-1] < crop_size[1]: + crop_size = (image.shape[-1], image.shape[-1]) + means, variances = self.norm_fn(image) + features = self.encoder(image, means=means, variances=variances) + + return FeatureImage( + features=features, + sens_maps=sens_maps, + crop_size=crop_size, + means=means, + variances=variances, + ref_kspace=ref_kspace, + mask=mask, + ) + + def forward( + self, + masked_kspace: Tensor, + mask: Tensor, + num_low_frequencies: Optional[int] = None, + crop_size: Optional[Tuple[int, int]] = None, + ) -> Tensor: + + masked_kspace = masked_kspace*self.kspace_mult_factor + + sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) + kspace_pred = masked_kspace.clone() + # Run E2EVN + for cascade in self.image_cascades: + kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps) + + feature_image = self._encode_input(masked_kspace=kspace_pred,ref_kspace=masked_kspace,sens_maps=sens_maps,mask=mask,crop_size=crop_size) + feature_image = self.cascades(feature_image) + kspace_pred = self._decode_output(feature_image) + kspace_pred = kspace_pred / self.kspace_mult_factor + + return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + +class FeatureVarNet_sh_w(nn.Module): + def __init__( + self, + num_cascades: int = 12, + sens_chans: int = 8, + sens_pools: int = 4, + chans: int = 18, + pools: int = 4, + mask_center: bool = True, + image_conv_cascades: Optional[List[int]] = None, + kspace_mult_factor: float = 1e6, + ): + super().__init__() + if image_conv_cascades is None: + image_conv_cascades = [ind for ind in range(num_cascades) if ind % 3 == 0] + + self.image_conv_cascades = image_conv_cascades + self.kspace_mult_factor = kspace_mult_factor + self.sens_net = SensitivityModel( + chans=sens_chans, + num_pools=sens_pools, + mask_center=mask_center, + ) + self.encoder = FeatureEncoder(in_chans=2, feature_chans=chans) + self.decoder = FeatureDecoder(feature_chans=chans, out_chans=2) + cascades = [] + for ind in range(num_cascades): + use_image_conv = ind in self.image_conv_cascades + cascades.append( + FeatureVarNetBlock( + encoder=self.encoder, + decoder=self.decoder, + feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + use_extra_feature_conv=use_image_conv, + ) + ) + + self.decode_norm = nn.InstanceNorm2d(chans) + self.cascades = nn.Sequential(*cascades) + self.norm_fn = NormStats() + + def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + image = self.decoder( + self.decode_norm(feature_image.features), + means=feature_image.means, + variances=feature_image.variances, + ) + return sens_expand(image, feature_image.sens_maps) + + def _encode_input( + self, + masked_kspace: Tensor, + mask: Tensor, + crop_size: Optional[Tuple[int, int]], + num_low_frequencies: Optional[int], + ) -> FeatureImage: + sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) + image = sens_reduce(masked_kspace, sens_maps) + # detect FLAIR 203 + if image.shape[-1] < crop_size[1]: + crop_size = (image.shape[-1], image.shape[-1]) + means, variances = self.norm_fn(image) + features = self.encoder(image, means=means, variances=variances) + + return FeatureImage( + features=features, + sens_maps=sens_maps, + crop_size=crop_size, + means=means, + variances=variances, + ref_kspace=masked_kspace, + mask=mask, + ) + + def forward( + self, + masked_kspace: Tensor, + mask: Tensor, + num_low_frequencies: Optional[int] = None, + crop_size: Optional[Tuple[int, int]] = None, + ) -> Tensor: + masked_kspace = masked_kspace * self.kspace_mult_factor + # Encode to features and get sensitivities + feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + # Do DC in feature-space + feature_image = self.cascades(feature_image) + # Find last k-space + kspace_pred = self._decode_output(feature_image) + # Return Final Image + kspace_pred = kspace_pred / self.kspace_mult_factor + return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + +class FeatureVarNet_n_sh_w(nn.Module): + def __init__( + self, + num_cascades: int = 12, + sens_chans: int = 8, + sens_pools: int = 4, + chans: int = 18, + pools: int = 4, + mask_center: bool = True, + image_conv_cascades: Optional[List[int]] = None, + kspace_mult_factor: float = 1e6, + ): + super().__init__() + if image_conv_cascades is None: + image_conv_cascades = [ind for ind in range(num_cascades) if ind % 3 == 0] + + self.image_conv_cascades = image_conv_cascades + self.kspace_mult_factor = kspace_mult_factor + self.sens_net = SensitivityModel( + chans=sens_chans, + num_pools=sens_pools, + mask_center=mask_center, + ) + self.encoder = FeatureEncoder(in_chans=2, feature_chans=chans) + self.decoder = FeatureDecoder(feature_chans=chans, out_chans=2) + cascades = [] + for ind in range(num_cascades): + use_image_conv = ind in self.image_conv_cascades + cascades.append( + FeatureVarNetBlock( + encoder=FeatureEncoder(in_chans=2, feature_chans=chans), + decoder=FeatureDecoder(feature_chans=chans, out_chans=2), + feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + use_extra_feature_conv=use_image_conv, + ) + ) + + self.decode_norm = nn.InstanceNorm2d(chans) + self.cascades = nn.Sequential(*cascades) + self.norm_fn = NormStats() + + def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + image = self.decoder( + self.decode_norm(feature_image.features), + means=feature_image.means, + variances=feature_image.variances, + ) + return sens_expand(image, feature_image.sens_maps) + + def _encode_input( + self, + masked_kspace: Tensor, + mask: Tensor, + crop_size: Optional[Tuple[int, int]], + num_low_frequencies: Optional[int], + ) -> FeatureImage: + sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) + image = sens_reduce(masked_kspace, sens_maps) + # detect FLAIR 203 + if image.shape[-1] < crop_size[1]: + crop_size = (image.shape[-1], image.shape[-1]) + means, variances = self.norm_fn(image) + features = self.encoder(image, means=means, variances=variances) + + return FeatureImage( + features=features, + sens_maps=sens_maps, + crop_size=crop_size, + means=means, + variances=variances, + ref_kspace=masked_kspace, + mask=mask, + ) + + def forward( + self, + masked_kspace: Tensor, + mask: Tensor, + num_low_frequencies: Optional[int] = None, + crop_size: Optional[Tuple[int, int]] = None, + ) -> Tensor: + masked_kspace = masked_kspace * self.kspace_mult_factor + # Encode to features and get sensitivities + feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + # Do DC in feature-space + feature_image = self.cascades(feature_image) + # Find last k-space + kspace_pred = self._decode_output(feature_image) + # Return Final Image + kspace_pred = kspace_pred / self.kspace_mult_factor + return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + +class AttentionFeatureVarNet_n_sh_w(nn.Module): + def __init__( + self, + num_cascades: int = 12, + sens_chans: int = 8, + sens_pools: int = 4, + chans: int = 18, + pools: int = 4, + acceleration: int = 4, + mask_center: bool = True, + image_conv_cascades: Optional[List[int]] = None, + kspace_mult_factor: float = 1e6, + ): + super().__init__() + if image_conv_cascades is None: + image_conv_cascades = [ind for ind in range(num_cascades) if ind % 3 == 0] + + self.image_conv_cascades = image_conv_cascades + self.kspace_mult_factor = kspace_mult_factor + self.sens_net = SensitivityModel( + chans=sens_chans, + num_pools=sens_pools, + mask_center=mask_center, + ) + self.encoder = FeatureEncoder(in_chans=2, feature_chans=chans) + self.decoder = FeatureDecoder(feature_chans=chans, out_chans=2) + cascades = [] + for ind in range(num_cascades): + use_image_conv = ind in self.image_conv_cascades + cascades.append( + AttentionFeatureVarNetBlock( + encoder=self.encoder, + decoder=self.decoder, + acceleration=acceleration, + feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + attention_layer=AttentionPE(in_chans=chans), + use_extra_feature_conv=use_image_conv, + ) + ) + + self.decode_norm = nn.InstanceNorm2d(chans) + self.cascades = nn.Sequential(*cascades) + self.norm_fn = NormStats() + + def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + image = self.decoder( + self.decode_norm(feature_image.features), + means=feature_image.means, + variances=feature_image.variances, + ) + return sens_expand(image, feature_image.sens_maps) + + def _encode_input( + self, + masked_kspace: Tensor, + mask: Tensor, + crop_size: Optional[Tuple[int, int]], + num_low_frequencies: Optional[int], + ) -> FeatureImage: + sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) + image = sens_reduce(masked_kspace, sens_maps) + # detect FLAIR 203 + if image.shape[-1] < crop_size[1]: + crop_size = (image.shape[-1], image.shape[-1]) + means, variances = self.norm_fn(image) + features = self.encoder(image, means=means, variances=variances) + + return FeatureImage( + features=features, + sens_maps=sens_maps, + crop_size=crop_size, + means=means, + variances=variances, + ref_kspace=masked_kspace, + mask=mask, + ) + + def forward( + self, + masked_kspace: Tensor, + mask: Tensor, + num_low_frequencies: Optional[int] = None, + crop_size: Optional[Tuple[int, int]] = None, + ) -> Tensor: + masked_kspace = masked_kspace * self.kspace_mult_factor + # Encode to features and get sensitivities + feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + # Do DC in feature-space + feature_image = self.cascades(feature_image) + # Find last k-space + kspace_pred = self._decode_output(feature_image) + # Return Final Image + kspace_pred = kspace_pred / self.kspace_mult_factor + return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + +class E2EVarNet(nn.Module): + """ + A full variational network model. + + This model applies a combination of soft data consistency with a U-Net + regularizer. To use non-U-Net regularizers, use VarNetBlock. + """ + + def __init__( + self, + num_cascades: int = 12, + sens_chans: int = 8, + sens_pools: int = 4, + chans: int = 18, + pools: int = 4, + mask_center: bool = True, + ): + """ + Args: + num_cascades: Number of cascades (i.e., layers) for variational + network. + sens_chans: Number of channels for sensitivity map U-Net. + sens_pools Number of downsampling and upsampling layers for + sensitivity map U-Net. + chans: Number of channels for cascade U-Net. + pools: Number of downsampling and upsampling layers for cascade + U-Net. + mask_center: Whether to mask center of k-space for sensitivity map + calculation. + """ + super().__init__() + + self.sens_net = SensitivityModel( + chans=sens_chans, + num_pools=sens_pools, + mask_center=mask_center, + ) + self.cascades = nn.ModuleList( + [VarNetBlock(NormUnet(chans, pools)) for _ in range(num_cascades)] + ) + + def forward( + self, + masked_kspace: torch.Tensor, + mask: torch.Tensor, + num_low_frequencies: Optional[int] = None, + crop_size: Optional[Tuple[int, int]] = None, + ) -> torch.Tensor: + sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) + kspace_pred = masked_kspace.clone() + + for cascade in self.cascades: + kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps) + + return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + +class AttentionFeatureVarNetBlock(nn.Module): + def __init__( + self, + encoder: FeatureEncoder, + decoder: FeatureDecoder, + acceleration: int, + feature_processor: Unet2d, + attention_layer: AttentionPE, + use_extra_feature_conv: bool = False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.feature_processor = feature_processor + self.attention_layer = attention_layer + self.use_image_conv = use_extra_feature_conv + self.dc_weight = nn.Parameter(torch.ones(1)) + feature_chans = self.encoder.feature_chans + self.acceleration = acceleration + + self.input_norm = nn.InstanceNorm2d(feature_chans) + self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + if use_extra_feature_conv: + self.output_norm = nn.InstanceNorm2d(feature_chans) + self.output_conv = nn.Sequential( + nn.Conv2d( + in_channels=feature_chans, + out_channels=feature_chans, + kernel_size=5, + padding=2, + bias=False, + ), + nn.InstanceNorm2d(feature_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d( + in_channels=feature_chans, + out_channels=feature_chans, + kernel_size=5, + padding=2, + bias=False, + ), + nn.InstanceNorm2d(feature_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + self.zero: Tensor + self.register_buffer("zero", torch.zeros(1, 1, 1, 1, 1)) + + def encode_from_kspace(self, kspace: Tensor, feature_image: FeatureImage) -> Tensor: + image = sens_reduce(kspace, feature_image.sens_maps) + + return self.encoder( + image, means=feature_image.means, variances=feature_image.variances + ) + + def decode_to_kspace(self, feature_image: FeatureImage) -> Tensor: + image = self.decoder( + feature_image.features, + means=feature_image.means, + variances=feature_image.variances, + ) + + return sens_expand(image, feature_image.sens_maps) + + def compute_dc_term(self, feature_image: FeatureImage) -> Tensor: + est_kspace = self.decode_to_kspace(feature_image) + + return self.dc_weight * self.encode_from_kspace( + torch.where( + feature_image.mask, est_kspace - feature_image.ref_kspace, self.zero + ), + feature_image, + ) + + def apply_model_with_crop(self, feature_image: FeatureImage) -> Tensor: + if feature_image.crop_size is not None: + features = image_uncrop( + self.feature_processor( + image_crop(feature_image.features, feature_image.crop_size) + ), + feature_image.features.clone(), + ) + else: + features = self.feature_processor(feature_image.features) + + return features + + def forward(self, feature_image: FeatureImage) -> FeatureImage: + feature_image = feature_image._replace( + features=self.input_norm(feature_image.features) + ) + + new_features = feature_image.features - self.compute_dc_term(feature_image) + """ + new_features_np = feature_image.features.cpu().numpy() + timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + file_name = f'new_features_before_{timestamp}.mat' + savemat(file_name, {'new_features_before': new_features_np}) + + new_ref_kspace = feature_image.ref_kspace.cpu().numpy() + timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + file_name = f'kspace_{timestamp}.mat' + savemat(file_name, {'kspace_': new_ref_kspace}) + """ + feature_image = feature_image._replace(features=self.attention_layer(feature_image.features,self.acceleration)) + new_features = new_features - self.apply_model_with_crop(feature_image) + + if self.use_image_conv: + new_features = self.output_norm(new_features) + new_features = new_features + self.output_conv(new_features) + + return feature_image._replace(features=new_features) + +class FeatureVarNetBlock(nn.Module): + def __init__( + self, + encoder: FeatureEncoder, + decoder: FeatureDecoder, + feature_processor: Unet2d, + use_extra_feature_conv: bool = False, + ): + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.feature_processor = feature_processor + self.use_image_conv = use_extra_feature_conv + self.dc_weight = nn.Parameter(torch.ones(1)) + feature_chans = self.encoder.feature_chans + + self.input_norm = nn.InstanceNorm2d(feature_chans) + self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + if use_extra_feature_conv: + self.output_norm = nn.InstanceNorm2d(feature_chans) + self.output_conv = nn.Sequential( + nn.Conv2d( + in_channels=feature_chans, + out_channels=feature_chans, + kernel_size=5, + padding=2, + bias=False, + ), + nn.InstanceNorm2d(feature_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d( + in_channels=feature_chans, + out_channels=feature_chans, + kernel_size=5, + padding=2, + bias=False, + ), + nn.InstanceNorm2d(feature_chans), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + self.zero: Tensor + self.register_buffer("zero", torch.zeros(1, 1, 1, 1, 1)) + + def encode_from_kspace(self, kspace: Tensor, feature_image: FeatureImage) -> Tensor: + image = sens_reduce(kspace, feature_image.sens_maps) + + return self.encoder( + image, means=feature_image.means, variances=feature_image.variances + ) + + def decode_to_kspace(self, feature_image: FeatureImage) -> Tensor: + image = self.decoder( + feature_image.features, + means=feature_image.means, + variances=feature_image.variances, + ) + + return sens_expand(image, feature_image.sens_maps) + + def compute_dc_term(self, feature_image: FeatureImage) -> Tensor: + est_kspace = self.decode_to_kspace(feature_image) + + return self.dc_weight * self.encode_from_kspace( + torch.where( + feature_image.mask, est_kspace - feature_image.ref_kspace, self.zero + ), + feature_image, + ) + + def apply_model_with_crop(self, feature_image: FeatureImage) -> Tensor: + if feature_image.crop_size is not None: + features = image_uncrop( + self.feature_processor( + image_crop(feature_image.features, feature_image.crop_size) + ), + feature_image.features.clone(), + ) + else: + features = self.feature_processor(feature_image.features) + + return features + + def forward(self, feature_image: FeatureImage) -> FeatureImage: + feature_image = feature_image._replace( + features=self.input_norm(feature_image.features) + ) + + new_features = feature_image.features - self.compute_dc_term(feature_image) - self.apply_model_with_crop(feature_image) + + if self.use_image_conv: + new_features = self.output_norm(new_features) + new_features = new_features + self.output_conv(new_features) + + return feature_image._replace(features=new_features) + +class VarNetBlock(nn.Module): + """ + Model block for end-to-end variational network. + + This model applies a combination of soft data consistency with the input + model as a regularizer. A series of these blocks can be stacked to form + the full variational network. + """ + + def __init__(self, model: nn.Module): + """ + Args: + model: Module for "regularization" component of variational + network. + """ + super().__init__() + + self.model = model + self.dc_weight = nn.Parameter(torch.ones(1)) + + def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: + return fft2c(complex_mul(x, sens_maps)) + + def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor: + return complex_mul(ifft2c(x), complex_conj(sens_maps)).sum(dim=1, keepdim=True) + + def forward( + self, + current_kspace: torch.Tensor, + ref_kspace: torch.Tensor, + mask: torch.Tensor, + sens_maps: torch.Tensor, + ) -> torch.Tensor: + zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace) + soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * self.dc_weight + + model_term = self.sens_expand( + self.model(self.sens_reduce(current_kspace, sens_maps)), sens_maps + ) + + return current_kspace - soft_dc - model_term diff --git a/fastmri_examples/README.md b/fastmri_examples/README.md index d5caf89a..17493c2e 100644 --- a/fastmri_examples/README.md +++ b/fastmri_examples/README.md @@ -15,3 +15,4 @@ further details. * [End-to-End Variational Networks for Accelerated MRI Reconstruction ({A. Sriram*, J. Zbontar*} et al., 2020)](varnet/) * [On learning adaptive acquisition policies for undersampled multi-coil MRI reconstruction (T. Bakker et al., 2021)](adaptive_varnet/) + * [Accelerated MRI reconstructions via variational network and feature domain learning (I. Giannakopoulos et al., 2024)](feature_varnet/) diff --git a/fastmri_examples/feature_varnet/README.md b/fastmri_examples/feature_varnet/README.md new file mode 100644 index 00000000..18066325 --- /dev/null +++ b/fastmri_examples/feature_varnet/README.md @@ -0,0 +1,72 @@ +# Accelerated MRI reconstructions via variational network and feature domain learning + +This directory contains a PyTorch implementation for reproducing the following paper, to be published at MIDL 2022. + +[Accelerated MRI reconstructions via variational network and feature domain learning (I. Giannakopoulos, et al., 2024).][feature_varnet] + +## Installation +We **strongly** recommend creating a separate conda environment for this example, as the +PyTorch Lightning versions required differs from that of the base `fastmri` installation. + +Before installing dependencies, first install PyTorch according to the directions at the +PyTorch Website for your operating system and CUDA setup +(we used `torch` version 1.7.0 for our experiments). Then run + +```bash +pip install -r fastmri_examples/feature_varnet/requirements.txt +``` + + +## Example training commands: + +This code provides a few ablations of the end-to-end variational network, namely, feature varnet with weight sharing, feature varnet without weight sharing, attention feature varnet with weight sharing, feature-image varnet, and image-feature varnet. Train and test each model with the same commands as the end-to-end variational network and include an additional input argument to your input file: +For the end-to-end varnet +> --varnet_type e2e_varnet + +For the feature varnet with weight sharing +> --varnet_type feature_varnet_sh_w + +For the feature varnet without weight sharing +> --varnet_type feature_varnet_n_sh_w + +For the attention feature varnet with weight sharing +> --varnet_type attention_feature_varnet_sh_w + +For the feature-image varnet +> --varnet_type fi_varnet + +For the image-feature varnet +> --varnet_type if_varnet + +See `train_feature_varnet.py` for additional arguments. + + +## Example evaluation commands: + +Evaluate the model as the end-to-end varnet + + +## Paths: + +Data and log paths are defined the fastmri_dirs.yaml + + +## Citing + +If you use this this code in your research, please cite the corresponding +paper: + +```BibTeX +@article{giannakopoulos2024accelerated, + title={Accelerated MRI reconstructions via variational network and feature domain learning}, + author={Giannakopoulos, Ilias I and Muckley, Matthew J and Kim, Jesi and Breen, Matthew and Johnson, Patricia M and Lui, Yvonne W and Lattanzi, Riccardo}, + journal={Scientific Reports}, + volume={14}, + number={1}, + pages={10991}, + year={2024}, + publisher={Nature Publishing Group UK London} +} +``` + +[feature_varnet]: https://www.nature.com/articles/s41598-024-59705-0 diff --git a/fastmri_examples/feature_varnet/pl_modules/__init__.py b/fastmri_examples/feature_varnet/pl_modules/__init__.py new file mode 100644 index 00000000..7aaea906 --- /dev/null +++ b/fastmri_examples/feature_varnet/pl_modules/__init__.py @@ -0,0 +1,9 @@ + +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from .feature_varnet_module import FIVarNetModule diff --git a/fastmri_examples/feature_varnet/pl_modules/feature_varnet_module.py b/fastmri_examples/feature_varnet/pl_modules/feature_varnet_module.py new file mode 100644 index 00000000..ccd99596 --- /dev/null +++ b/fastmri_examples/feature_varnet/pl_modules/feature_varnet_module.py @@ -0,0 +1,158 @@ +import math +from argparse import ArgumentParser + +import torch +torch.set_float32_matmul_precision('high') +import torch.nn as nn +from fastmri.pl_modules.mri_module import MriModule +from fastmri.losses import SSIMLoss +from fastmri.data.transforms import center_crop_to_smallest, center_crop +from fastmri.models import FIVarNet, IFVarNet, FeatureVarNet_sh_w, FeatureVarNet_n_sh_w, AttentionFeatureVarNet_n_sh_w, E2EVarNet + +class FIVarNetModule(MriModule): + def __init__( + self, + fi_varnet: FIVarNet, + lr: float = 0.0003, + weight_decay: float = 0.0, + max_steps: int = 65450, + ramp_steps: int = 2618, + cosine_decay_start: int = 32725, + **kwargs, + ): + super().__init__(**kwargs) + self.lr = lr + self.max_steps = max_steps + self.ramp_steps = ramp_steps + self.cosine_decay_start = cosine_decay_start + self.weight_decay = weight_decay + self.fi_varnet = fi_varnet + self.loss = SSIMLoss() + + def forward(self, masked_kspace, mask, num_low_frequencies): + return self.fi_varnet(masked_kspace, mask, num_low_frequencies) + + def training_step(self, batch, batch_idx): + output = self.fi_varnet(batch.masked_kspace,batch.mask,batch.num_low_frequencies,crop_size=batch.crop_size) + target, output = center_crop_to_smallest(batch.target, output) + loss = self.loss( + output.unsqueeze(1), target.unsqueeze(1).float(), data_range=batch.max_value + ) + self.log("train_loss", loss, sync_dist=True) + return loss + + def on_before_optimizer_step(self, optimizer, optimizer_idx): + for name, param in self.fi_varnet.named_parameters(): + if param.grad is not None: + self.log(f"grads/{name}", torch.norm(param.grad)) + + def validation_step(self, batch, batch_idx): + output = self.fi_varnet(batch.masked_kspace,batch.mask,batch.num_low_frequencies,crop_size=batch.crop_size) + target, output = center_crop_to_smallest(batch.target, output) + return { + "batch_idx": batch_idx, + "fname": batch.fname, + "slice_num": batch.slice_num, + "max_value": batch.max_value, + "output": output, + "target": target, + "val_loss": self.loss( + output.unsqueeze(1), target.unsqueeze(1).float(), data_range=batch.max_value + ), + } + + def test_step(self, batch, batch_idx): + output = self.fi_varnet(batch.masked_kspace,batch.mask,batch.num_low_frequencies,crop_size=batch.crop_size) + if output.shape[-1] < batch.crop_size[1]: + crop_size = (output.shape[-1], output.shape[-1]) + else: + crop_size = batch.crop_size + output = center_crop(output, crop_size) + return { + "fname": batch.fname, + "slice": batch.slice_num, + "output": output.cpu().numpy(), + } + + def configure_optimizers(self): + cosine_steps = self.max_steps - self.cosine_decay_start + def step_fn(step): + if step < self.cosine_decay_start: + return min(step / self.ramp_steps, 1.0) + else: + angle = (step - self.cosine_decay_start) / cosine_steps * math.pi / 2 + return max(math.cos(angle), 1e-8) + + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + lr_scheduler_config = { + "scheduler": torch.optim.lr_scheduler.LambdaLR(optimizer, step_fn), + "interval": "step", + } + return [optimizer], [lr_scheduler_config] + + @staticmethod + def add_model_specific_args(parent_parser): # pragma: no-cover + parser = ArgumentParser(parents=[parent_parser], add_help=False) + parser = MriModule.add_model_specific_args(parser) + parser.add_argument( + "--num_cascades", + default=12, + type=int, + help="Number of VarNet cascades", + ) + parser.add_argument( + "--pools", + default=4, + type=int, + help="Number of U-Net pooling layers in VarNetFiLM blocks", + ) + parser.add_argument( + "--chans", + default=18, + type=int, + help="Number of channels for U-Net in VarNetFiLM blocks", + ) + parser.add_argument( + "--sens_pools", + default=4, + type=int, + help="Number of pooling layers for sense map estimation U-Net in VarNetFiLM", + ) + parser.add_argument( + "--sens_chans", + default=8, + type=float, + help="Number of channels for sense map estimation U-Net in VarNetFiLM", + ) + parser.add_argument( + "--lr", default=0.0003, type=float, help="Adam learning rate" + ) + parser.add_argument( + "--lr_step_size", + default=40, + type=int, + help="Epoch at which to decrease step size", + ) + parser.add_argument( + "--ramp_steps", + default=2618, + type=int, + help="Number of steps for ramping learning rate", + ) + parser.add_argument( + "--cosine_decay_start", + default=32725, + type=int, + help="Step at which to start cosine lr decay", + ) + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Strength of weight decay regularization", + ) + return parser diff --git a/fastmri_examples/feature_varnet/requirements.txt b/fastmri_examples/feature_varnet/requirements.txt new file mode 100644 index 00000000..b4d5d2bb --- /dev/null +++ b/fastmri_examples/feature_varnet/requirements.txt @@ -0,0 +1,11 @@ + +numpy==1.18.5 +scikit_image==0.16.2 +torchvision==0.8.1 +torch==1.7.0 +runstats==2.0.0 +pytorch_lightning==1.0.6 +h5py==2.10.0 +PyYAML==5.4.1 +torchmetrics==0.3.2 +wandb==0.12.7 diff --git a/fastmri_examples/feature_varnet/train_feature_varnet.py b/fastmri_examples/feature_varnet/train_feature_varnet.py new file mode 100644 index 00000000..8f66f5f0 --- /dev/null +++ b/fastmri_examples/feature_varnet/train_feature_varnet.py @@ -0,0 +1,270 @@ +import os +import torch +torch.set_float32_matmul_precision('high') +import pathlib +from argparse import ArgumentParser +from pathlib import Path +from typing import Optional +import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger +from fastmri.models.feature_varnet import FIVarNet, IFVarNet, FeatureVarNet_sh_w, FeatureVarNet_n_sh_w, E2EVarNet, AttentionFeatureVarNet_n_sh_w +from pl_modules import FIVarNetModule +from fastmri.data.subsample import create_mask_for_mask_type +from fastmri.data.transforms import VarNetDataTransform +from fastmri.data.mri_data import fetch_dir +from fastmri.pl_modules.data_module import FastMriDataModule +import subprocess + +def check_gpu_availability(): + command = "nvidia-smi --query-gpu=index --format=csv,noheader | wc -l" + output = subprocess.check_output(command, shell=True).decode("utf-8").strip() + return int(output) + +def reload_state_dict( + module: FIVarNetModule, fname: Path, module_name: str = "fi_varnet." +): + print(f"loading model from {fname}") + lm = len(module_name) + state_dict = torch.load(fname, map_location=torch.device("cpu"))["state_dict"] + state_dict = {k[lm:]: v for k, v in state_dict.items() if k[:lm] == module_name} + module.fi_varnet.load_state_dict(state_dict) + return module + +def fetch_model(args, acceleration): + if args.varnet_type == "fi_varnet": + print(f"BUILDING FI VARNET, chans={args.chans}") + return FIVarNet( + num_cascades=args.num_cascades, + pools=args.pools, + chans=args.chans, + sens_pools=args.sens_pools, + sens_chans=args.sens_chans, + acceleration=acceleration, + ) + if args.varnet_type == "if_varnet": + print(f"BUILDING IF VARNET, chans={args.chans}") + return IFVarNet( + num_cascades=args.num_cascades, + pools=args.pools, + chans=args.chans, + sens_pools=args.sens_pools, + sens_chans=args.sens_chans, + acceleration=acceleration, + ) + elif args.varnet_type == "attention_feature_varnet_sh_w": + print(f"BUILDING ATTENTION FEATURE VARNET WITH WEIGHT SHARING, chans={args.chans}") + return AttentionFeatureVarNet_n_sh_w( + num_cascades=args.num_cascades, + pools=args.pools, + chans=args.chans, + sens_pools=args.sens_pools, + sens_chans=args.sens_chans, + acceleration=acceleration, + ) + elif args.varnet_type == "feature_varnet_n_sh_w": + print(f"BUILDING FEATURE VARNET WITHOUT WEIGHT SHARING, chans={args.chans}") + return FeatureVarNet_n_sh_w( + num_cascades=args.num_cascades, + pools=args.pools, + chans=args.chans, + sens_pools=args.sens_pools, + sens_chans=args.sens_chans, + ) + elif args.varnet_type == "feature_varnet_sh_w": + print(f"BUILDING FEATURE VARNET WITH WEIGHT SHARING, chans={args.chans}") + return FeatureVarNet_sh_w( + num_cascades=args.num_cascades, + pools=args.pools, + chans=args.chans, + sens_pools=args.sens_pools, + sens_chans=args.sens_chans, + ) + elif args.varnet_type == "e2e_varnet": + print(f"BUILDING E2E VARNET, chans={args.chans}") + return E2EVarNet( + num_cascades=args.num_cascades, + pools=args.pools, + chans=args.chans, + sens_pools=args.sens_pools, + sens_chans=args.sens_chans, + ) + else: + raise ValueError("Unrecognized varnet_type") + +def cli_main(args): + pl.seed_everything(args.seed) + + mask = create_mask_for_mask_type( + args.mask_type, args.center_fractions, args.accelerations + ) + train_transform = VarNetDataTransform(mask_func=mask, use_seed=False) + val_transform = VarNetDataTransform(mask_func=mask) + + if args.mode == "test_val": + args.mode = "test" + test_transform = VarNetDataTransform(mask_func=mask) + else: + test_transform = VarNetDataTransform() + + data_module = FastMriDataModule( + data_path=args.data_path, + challenge=args.challenge, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + combine_train_val=True, + test_split=args.test_split, + test_path=args.test_path, + sample_rate=args.sample_rate, + batch_size=args.batch_size, + num_workers=args.num_workers, + distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")), + ) + + acceleration_mean = int(round(sum(args.accelerations) / len(args.accelerations))) + print(acceleration_mean) + pl_module = FIVarNetModule( + fi_varnet=fetch_model(args, acceleration_mean), + lr=args.lr, + weight_decay=args.weight_decay, + max_steps=args.max_steps, + ramp_steps=args.ramp_steps, + cosine_decay_start=args.cosine_decay_start, + ) + + if args.resume_from_checkpoint is not None: + pl_module = reload_state_dict(pl_module, args.resume_from_checkpoint) + trainer = pl.Trainer.from_argparse_args(args) + if args.mode == "train": + trainer.fit(pl_module, datamodule=data_module) + elif args.mode == "test": + trainer.test(pl_module, datamodule=data_module) + else: + raise ValueError(f"unrecognized mode {args.mode}") + +def build_args(model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool = True): + parser = ArgumentParser() + path_config = pathlib.Path("./fastmri_dirs.yaml") + backend = "ddp" + num_gpus = check_gpu_availability() if backend == "ddp" else 1 + batch_size = 1 + data_path = fetch_dir("data_path", path_config) + parser.add_argument( + "--mode", + default="train", + choices=("train", "test", "test_val"), + type=str, + help="Operation mode", + ) + parser.add_argument( + "--mask_type", + choices=("random", "equispaced", "equispaced_fraction"), + default="equispaced_fraction", + type=str, + help="Type of k-space mask", + ) + parser.add_argument( + "--center_fractions", + nargs="+", + default=[0.08], + type=float, + help="Number of center lines to use in mask", + ) + parser.add_argument( + "--accelerations", + nargs="+", + default=[4], + type=int, + help="Acceleration rates to use for masks", + ) + parser.add_argument( + "--varnet_type", + choices=("fi_varnet","if_varnet","feature_varnet_sh_w","feature_varnet_n_sh_w","attention_feature_varnet_sh_w","e2e_varnet"), + default="fi_varnet", + type=str, + help="Type of VarNet to use", + ) + + parser = FastMriDataModule.add_data_specific_args(parser) + + args, _ = parser.parse_known_args() + if args.mode == "test" or args.mode == "test_val": + num_gpus = 1 + if args.varnet_type == "e2e_varnet": + default_root_dir = (fetch_dir("log_path", path_config) / "e2e_varnet") + if args.varnet_type == "fi_varnet": + default_root_dir = (fetch_dir("log_path", path_config) / "fi_varnet") + if args.varnet_type == "if_varnet": + default_root_dir = (fetch_dir("log_path", path_config) / "if_varnet") + elif args.varnet_type == "feature_varnet_sh_w": + default_root_dir = (fetch_dir("log_path", path_config) / "feature_varnet_sh_w") + elif args.varnet_type == "feature_varnet_n_sh_w": + default_root_dir = (fetch_dir("log_path", path_config) / "feature_varnet_n_sh_w") + elif args.varnet_type == "attention_feature_varnet_sh_w": + default_root_dir = (fetch_dir("log_path", path_config) / "attention_feature_varnet_sh_w") + + parser.set_defaults( + data_path=data_path, # path to fastMRI data + mask_type="equispaced_fraction", # knee uses equispaced mask + challenge="multicoil", # only multicoil implemented for VarNet + batch_size=batch_size, # number of samples per batch + test_path=None, # path for test split, overwrites data_path + ) + + parser = FIVarNetModule.add_model_specific_args(parser) + + parser.set_defaults( + num_cascades=12, # number of unrolled iterations + pools=4, # number of pooling layers for U-Net + chans=32, # number of top-level channels for U-Net + sens_pools=4, # number of pooling layers for sense est. U-Net + sens_chans=8, # number of top-level channels for sense est. U-Net + lr=0.0003, # Adam learning rate + ramp_steps=7500, + cosine_decay_start=150000,#150000, + weight_decay=0.0, # weight regularization strength + ) + parser = pl.Trainer.add_argparse_args(parser) + parser.set_defaults( + devices=num_gpus, # number of gpus to use + replace_sampler_ddp=True, # this is necessary for volume dispatch during val + accelerator="gpu", # what distributed version to use + strategy="ddp_find_unused_parameters_false", # what distributed version to use + seed=42, # random seed + # deterministic=True, # makes things slower, but deterministic + default_root_dir=default_root_dir, # directory for logs and checkpoints + max_steps=210000,#210000, # number of steps for 50 knee epochs + detect_anomaly=False, + gradient_clip_val=1.0, + ) + args = parser.parse_args() + print(f"MODEL NAME: {model_name}") + args.logger = TensorBoardLogger( + save_dir=args.default_root_dir, version=f"{model_name}" + ) + checkpoint_dir = args.default_root_dir / "checkpoints" / f"{model_name}" + if not checkpoint_dir.exists(): + checkpoint_dir.mkdir(parents=True) + args.callbacks = [ + pl.callbacks.ModelCheckpoint( + dirpath=checkpoint_dir, + save_last=True, + save_top_k=True, + verbose=True, + monitor="validation_loss", + mode="min", + ), + pl.callbacks.LearningRateMonitor(), + ] + if args.resume_from_checkpoint is None: + ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime) + if ckpt_list: + args.resume_from_checkpoint = str(ckpt_list[-1]) + return args + +def run_cli(): + args = build_args(cluster_launch=True) + cli_main(args) + +if __name__ == "__main__": + run_cli() From c0ba931ddbc22eed0837acc1739ca5f20bf0be8a Mon Sep 17 00:00:00 2001 From: GiannakopoulosIlias Date: Wed, 12 Jun 2024 13:29:33 -0400 Subject: [PATCH 2/6] Included copyright header and fixed linter tests --- fastmri/models/feature_varnet.py | 181 ++++++++++++++---- .../{pl_modules => }/feature_varnet_module.py | 46 ++++- .../feature_varnet/pl_modules/__init__.py | 9 - .../feature_varnet/requirements.txt | 11 -- .../feature_varnet/train_feature_varnet.py | 63 ++++-- 5 files changed, 231 insertions(+), 79 deletions(-) rename fastmri_examples/feature_varnet/{pl_modules => }/feature_varnet_module.py (82%) delete mode 100644 fastmri_examples/feature_varnet/pl_modules/__init__.py delete mode 100644 fastmri_examples/feature_varnet/requirements.txt diff --git a/fastmri/models/feature_varnet.py b/fastmri/models/feature_varnet.py index a325ea6e..20d637f5 100644 --- a/fastmri/models/feature_varnet.py +++ b/fastmri/models/feature_varnet.py @@ -1,9 +1,17 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + from typing import NamedTuple, Optional, Tuple, List import math import torch import torch.nn as nn from torch import Tensor -torch.set_float32_matmul_precision('high') + +torch.set_float32_matmul_precision("high") import torch.nn.functional as F import torch.distributed as dist import numpy as np @@ -14,11 +22,13 @@ from fastmri.coil_combine import rss_complex, rss from fastmri.math import complex_abs, complex_mul, complex_conj + def image_crop(image: Tensor, crop_size: Optional[Tuple[int, int]] = None) -> Tensor: if crop_size is None: return image return center_crop(image, crop_size).contiguous() + def _calc_uncrop(crop_height: int, in_height: int) -> Tuple[int, int]: pad_height = (in_height - crop_height) // 2 if (in_height - crop_height) % 2 != 0: @@ -30,6 +40,7 @@ def _calc_uncrop(crop_height: int, in_height: int) -> Tuple[int, int]: return pad_height_top, pad_height + def image_uncrop(image: Tensor, original_image: Tensor) -> Tensor: """Insert values back into original image.""" in_shape = original_image.shape @@ -42,47 +53,52 @@ def image_uncrop(image: Tensor, original_image: Tensor) -> Tensor: pad_height_left, pad_width = _calc_uncrop(image.shape[-1], in_shape[-1]) try: - original_image[ - ..., pad_height_top:pad_height, pad_height_left:pad_width - ] = image[...] + original_image[..., pad_height_top:pad_height, pad_height_left:pad_width] = ( + image[...] + ) except RuntimeError: print(f"in_shape: {in_shape}, image shape: {image.shape}") raise return original_image + def norm_fn(image: Tensor, means: Tensor, variances: Tensor) -> Tensor: means = means.view(1, -1, 1, 1) variances = variances.view(1, -1, 1, 1) return (image - means) * torch.rsqrt(variances) + def unnorm_fn(image: Tensor, means: Tensor, variances: Tensor) -> Tensor: means = means.view(1, -1, 1, 1) variances = variances.view(1, -1, 1, 1) return image * torch.sqrt(variances) + means + def complex_to_chan_dim(x: Tensor) -> Tensor: b, c, h, w, two = x.shape assert two == 2 assert c == 1 return x.permute(0, 4, 1, 2, 3).reshape(b, 2 * c, h, w) + def chan_complex_to_last_dim(x: Tensor) -> Tensor: b, c2, h, w = x.shape assert c2 == 2 c = c2 // 2 return x.view(b, 2, c, h, w).permute(0, 2, 3, 4, 1).contiguous() + def sens_expand(x: Tensor, sens_maps: Tensor) -> Tensor: return fft2c(complex_mul(chan_complex_to_last_dim(x), sens_maps)) + def sens_reduce(x: Tensor, sens_maps: Tensor) -> Tensor: return complex_to_chan_dim( - complex_mul(ifft2c(x), complex_conj(sens_maps)).sum( - dim=1, keepdim=True - ) + complex_mul(ifft2c(x), complex_conj(sens_maps)).sum(dim=1, keepdim=True) ) + class NormStats(nn.Module): def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]: # group norm @@ -103,6 +119,7 @@ def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]: return mean, variance + class RunningChannelStats(nn.Module): def __init__(self, chans: int, eps: float = 1e-14, freeze_step: int = 20000): super().__init__() @@ -143,6 +160,7 @@ def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: return run_mean, run_var + class FeatureImage(NamedTuple): features: Tensor sens_maps: Tensor = None @@ -154,6 +172,7 @@ class FeatureImage(NamedTuple): beta: Optional[Tensor] = None gamma: Optional[Tensor] = None + class FeatureEncoder(nn.Module): def __init__(self, in_chans: int, feature_chans: int = 32, drop_prob: float = 0.0): super().__init__() @@ -174,6 +193,7 @@ def forward(self, image: Tensor, means: Tensor, variances: Tensor) -> Tensor: variances = variances.view(1, -1, 1, 1) return self.encoder((image - means) * torch.rsqrt(variances)) + class FeatureDecoder(nn.Module): def __init__(self, feature_chans: int = 32, out_chans: int = 2): super().__init__() @@ -192,6 +212,7 @@ def forward(self, features: Tensor, means: Tensor, variances: Tensor) -> Tensor: variances = variances.view(1, -1, 1, 1) return self.decoder(features) * torch.sqrt(variances) + means + class AttentionPE(nn.Module): def __init__(self, in_chans: int): super().__init__() @@ -201,8 +222,12 @@ def __init__(self, in_chans: int): self.q = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) self.k = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) self.v = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) - self.proj_out = nn.Conv2d(in_chans, in_chans, kernel_size=1, stride=1, padding=0) - self.dilated_conv = nn.Conv2d(in_chans, in_chans, kernel_size=3, stride=1, padding=2, dilation=2) + self.proj_out = nn.Conv2d( + in_chans, in_chans, kernel_size=1, stride=1, padding=0 + ) + self.dilated_conv = nn.Conv2d( + in_chans, in_chans, kernel_size=3, stride=1, padding=2, dilation=2 + ) def reshape_to_blocks(self, x: Tensor, accel: int) -> Tensor: chans = x.shape[1] @@ -210,26 +235,43 @@ def reshape_to_blocks(self, x: Tensor, accel: int) -> Tensor: pad_right = pad_total // 2 pad_left = pad_total - pad_right x = F.pad(x, (pad_left, pad_right, 0, 0), "reflect") - return (torch.stack(x.chunk(chunks=accel, dim=3), dim=-1).view(chans, -1, accel).permute(1, 0, 2).contiguous()) + return ( + torch.stack(x.chunk(chunks=accel, dim=3), dim=-1) + .view(chans, -1, accel) + .permute(1, 0, 2) + .contiguous() + ) - def reshape_from_blocks(self, x: Tensor, image_size: Tuple[int, int], accel: int) -> Tensor: + def reshape_from_blocks( + self, x: Tensor, image_size: Tuple[int, int], accel: int + ) -> Tensor: chans = x.shape[1] num_freq, num_phase = image_size - x = (x.permute(1, 0, 2).reshape(1, chans, num_freq, -1, accel).permute(0, 1, 2, 4, 3).reshape(1, chans, num_freq, -1)) + x = ( + x.permute(1, 0, 2) + .reshape(1, chans, num_freq, -1, accel) + .permute(0, 1, 2, 4, 3) + .reshape(1, chans, num_freq, -1) + ) padded_phase = x.shape[3] pad_total = padded_phase - num_phase pad_right = pad_total // 2 pad_left = pad_total - pad_right return x[:, :, :, pad_left : padded_phase - pad_right] - def get_positional_encodings(self, seq_len: int, embed_dim: int, device: str) -> Tensor: - freqs = torch.tensor([1 / (10000 ** (2 * (i // 2) / embed_dim)) for i in range(embed_dim)], device=device) + def get_positional_encodings( + self, seq_len: int, embed_dim: int, device: str + ) -> Tensor: + freqs = torch.tensor( + [1 / (10000 ** (2 * (i // 2) / embed_dim)) for i in range(embed_dim)], + device=device, + ) freqs = freqs.unsqueeze(0) positions = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1) scaled = positions * freqs sin_encodings = torch.sin(scaled) cos_encodings = torch.cos(scaled) - encodings = torch.cat([sin_encodings, cos_encodings], dim=1)[:,:embed_dim] + encodings = torch.cat([sin_encodings, cos_encodings], dim=1)[:, :embed_dim] return encodings def forward(self, x: Tensor, accel: int) -> Tensor: @@ -246,10 +288,10 @@ def forward(self, x: Tensor, accel: int) -> Tensor: v = self.dilated_conv(self.v(h_)) # compute attention - c = q.shape[1] - q = self.reshape_to_blocks(q, accel) - k = self.reshape_to_blocks(k, accel) - q = q.permute(0, 2, 1) # b,hw,c + c = q.shape[1] + q = self.reshape_to_blocks(q, accel) + k = self.reshape_to_blocks(k, accel) + q = q.permute(0, 2, 1) # b,hw,c w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c) ** (-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) @@ -264,6 +306,7 @@ def forward(self, x: Tensor, accel: int) -> Tensor: return x + h_ + class Unet(nn.Module): """ PyTorch implementation of a U-Net model. @@ -358,6 +401,7 @@ def forward(self, image: torch.Tensor) -> torch.Tensor: return output + class Unet2d(nn.Module): def __init__( self, @@ -426,6 +470,7 @@ def forward(self, image: Tensor) -> Tensor: image, (output_y, output_x) = self.pad_input_image(image) return self.final_conv(self.layer(image))[:, :, :output_y, :output_x] + class UnetLevel(nn.Module): def __init__( self, @@ -464,6 +509,7 @@ def forward(self, image: Tensor) -> Tensor: return image + class ConvBlock(nn.Module): def __init__(self, in_chans: int, out_chans: int, drop_prob: float): super().__init__() @@ -486,6 +532,7 @@ def __init__(self, in_chans: int, out_chans: int, drop_prob: float): def forward(self, image: torch.Tensor) -> torch.Tensor: return self.layers(image) + class TransposeConvBlock(nn.Module): def __init__(self, in_chans: int, out_chans: int): super().__init__() @@ -504,6 +551,7 @@ def __init__(self, in_chans: int, out_chans: int): def forward(self, image: torch.Tensor) -> torch.Tensor: return self.layers(image) + class NormUnet(nn.Module): def __init__( @@ -584,7 +632,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x, mean, std = self.norm(x) x, pad_sizes = self.pad(x) - #attention_goes_here + # attention_goes_here x = self.unet(x) # get shapes back and unnormalize @@ -594,6 +642,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + class Norm1DUnet(nn.Module): def __init__( @@ -674,7 +723,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x, mean, std = self.norm(x) x, pad_sizes = self.pad(x) - #attention_goes_here + # attention_goes_here x = self.unet(x) # get shapes back and unnormalize @@ -684,6 +733,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x + class SensitivityModel(nn.Module): """ Model for learning sensitivity estimation from k-space data. @@ -754,7 +804,9 @@ def get_pad_and_num_low_freqs( mask.shape[0], dtype=mask.dtype, device=mask.device ) - pad = torch.div(mask.shape[-2] - num_low_frequencies_tensor + 1,2,rounding_mode='trunc') + pad = torch.div( + mask.shape[-2] - num_low_frequencies_tensor + 1, 2, rounding_mode="trunc" + ) return pad, num_low_frequencies_tensor @@ -768,9 +820,7 @@ def forward( pad, num_low_freqs = self.get_pad_and_num_low_freqs( mask, num_low_frequencies ) - masked_kspace = batched_mask_center( - masked_kspace, pad, pad + num_low_freqs - ) + masked_kspace = batched_mask_center(masked_kspace, pad, pad + num_low_freqs) # convert to image space images, batches = self.chans_to_batch_dim(ifft2c(masked_kspace)) @@ -780,6 +830,7 @@ def forward( self.batch_chans_to_chan_dim(self.norm_unet(images), batches) ) + class FIVarNet(nn.Module): def __init__( self, @@ -814,7 +865,9 @@ def __init__( encoder=self.encoder, decoder=self.decoder, acceleration=acceleration, - feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + feature_processor=Unet2d( + in_chans=chans, out_chans=chans, num_pool_layers=pools + ), attention_layer=AttentionPE(in_chans=chans), use_extra_feature_conv=use_image_conv, ) @@ -870,18 +923,26 @@ def forward( ) -> Tensor: masked_kspace = masked_kspace * self.kspace_mult_factor # Encode to features and get sensitivities - feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + feature_image = self._encode_input( + masked_kspace=masked_kspace, + mask=mask, + crop_size=crop_size, + num_low_frequencies=num_low_frequencies, + ) # Do DC in feature-space feature_image = self.cascades(feature_image) # Find last k-space kspace_pred = self._decode_output(feature_image) # Run E2EVN for cascade in self.image_cascades: - kspace_pred = cascade(kspace_pred, feature_image.ref_kspace, mask, feature_image.sens_maps) + kspace_pred = cascade( + kspace_pred, feature_image.ref_kspace, mask, feature_image.sens_maps + ) # Return Final Image kspace_pred = kspace_pred / self.kspace_mult_factor return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + class IFVarNet(nn.Module): def __init__( self, @@ -916,7 +977,9 @@ def __init__( encoder=self.encoder, decoder=self.decoder, acceleration=acceleration, - feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + feature_processor=Unet2d( + in_chans=chans, out_chans=chans, num_pool_layers=pools + ), attention_layer=AttentionPE(in_chans=chans), use_extra_feature_conv=use_image_conv, ) @@ -971,7 +1034,7 @@ def forward( crop_size: Optional[Tuple[int, int]] = None, ) -> Tensor: - masked_kspace = masked_kspace*self.kspace_mult_factor + masked_kspace = masked_kspace * self.kspace_mult_factor sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) kspace_pred = masked_kspace.clone() @@ -979,13 +1042,20 @@ def forward( for cascade in self.image_cascades: kspace_pred = cascade(kspace_pred, masked_kspace, mask, sens_maps) - feature_image = self._encode_input(masked_kspace=kspace_pred,ref_kspace=masked_kspace,sens_maps=sens_maps,mask=mask,crop_size=crop_size) + feature_image = self._encode_input( + masked_kspace=kspace_pred, + ref_kspace=masked_kspace, + sens_maps=sens_maps, + mask=mask, + crop_size=crop_size, + ) feature_image = self.cascades(feature_image) kspace_pred = self._decode_output(feature_image) kspace_pred = kspace_pred / self.kspace_mult_factor return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + class FeatureVarNet_sh_w(nn.Module): def __init__( self, @@ -1018,7 +1088,9 @@ def __init__( FeatureVarNetBlock( encoder=self.encoder, decoder=self.decoder, - feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + feature_processor=Unet2d( + in_chans=chans, out_chans=chans, num_pool_layers=pools + ), use_extra_feature_conv=use_image_conv, ) ) @@ -1069,7 +1141,12 @@ def forward( ) -> Tensor: masked_kspace = masked_kspace * self.kspace_mult_factor # Encode to features and get sensitivities - feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + feature_image = self._encode_input( + masked_kspace=masked_kspace, + mask=mask, + crop_size=crop_size, + num_low_frequencies=num_low_frequencies, + ) # Do DC in feature-space feature_image = self.cascades(feature_image) # Find last k-space @@ -1078,6 +1155,7 @@ def forward( kspace_pred = kspace_pred / self.kspace_mult_factor return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + class FeatureVarNet_n_sh_w(nn.Module): def __init__( self, @@ -1110,7 +1188,9 @@ def __init__( FeatureVarNetBlock( encoder=FeatureEncoder(in_chans=2, feature_chans=chans), decoder=FeatureDecoder(feature_chans=chans, out_chans=2), - feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + feature_processor=Unet2d( + in_chans=chans, out_chans=chans, num_pool_layers=pools + ), use_extra_feature_conv=use_image_conv, ) ) @@ -1161,7 +1241,12 @@ def forward( ) -> Tensor: masked_kspace = masked_kspace * self.kspace_mult_factor # Encode to features and get sensitivities - feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + feature_image = self._encode_input( + masked_kspace=masked_kspace, + mask=mask, + crop_size=crop_size, + num_low_frequencies=num_low_frequencies, + ) # Do DC in feature-space feature_image = self.cascades(feature_image) # Find last k-space @@ -1170,6 +1255,7 @@ def forward( kspace_pred = kspace_pred / self.kspace_mult_factor return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + class AttentionFeatureVarNet_n_sh_w(nn.Module): def __init__( self, @@ -1204,7 +1290,9 @@ def __init__( encoder=self.encoder, decoder=self.decoder, acceleration=acceleration, - feature_processor=Unet2d(in_chans=chans, out_chans=chans, num_pool_layers=pools), + feature_processor=Unet2d( + in_chans=chans, out_chans=chans, num_pool_layers=pools + ), attention_layer=AttentionPE(in_chans=chans), use_extra_feature_conv=use_image_conv, ) @@ -1256,7 +1344,12 @@ def forward( ) -> Tensor: masked_kspace = masked_kspace * self.kspace_mult_factor # Encode to features and get sensitivities - feature_image = self._encode_input(masked_kspace=masked_kspace,mask=mask,crop_size=crop_size,num_low_frequencies=num_low_frequencies) + feature_image = self._encode_input( + masked_kspace=masked_kspace, + mask=mask, + crop_size=crop_size, + num_low_frequencies=num_low_frequencies, + ) # Do DC in feature-space feature_image = self.cascades(feature_image) # Find last k-space @@ -1265,6 +1358,7 @@ def forward( kspace_pred = kspace_pred / self.kspace_mult_factor return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + class E2EVarNet(nn.Module): """ A full variational network model. @@ -1321,6 +1415,7 @@ def forward( return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + class AttentionFeatureVarNetBlock(nn.Module): def __init__( self, @@ -1426,7 +1521,9 @@ def forward(self, feature_image: FeatureImage) -> FeatureImage: file_name = f'kspace_{timestamp}.mat' savemat(file_name, {'kspace_': new_ref_kspace}) """ - feature_image = feature_image._replace(features=self.attention_layer(feature_image.features,self.acceleration)) + feature_image = feature_image._replace( + features=self.attention_layer(feature_image.features, self.acceleration) + ) new_features = new_features - self.apply_model_with_crop(feature_image) if self.use_image_conv: @@ -1435,6 +1532,7 @@ def forward(self, feature_image: FeatureImage) -> FeatureImage: return feature_image._replace(features=new_features) + class FeatureVarNetBlock(nn.Module): def __init__( self, @@ -1524,7 +1622,11 @@ def forward(self, feature_image: FeatureImage) -> FeatureImage: features=self.input_norm(feature_image.features) ) - new_features = feature_image.features - self.compute_dc_term(feature_image) - self.apply_model_with_crop(feature_image) + new_features = ( + feature_image.features + - self.compute_dc_term(feature_image) + - self.apply_model_with_crop(feature_image) + ) if self.use_image_conv: new_features = self.output_norm(new_features) @@ -1532,6 +1634,7 @@ def forward(self, feature_image: FeatureImage) -> FeatureImage: return feature_image._replace(features=new_features) + class VarNetBlock(nn.Module): """ Model block for end-to-end variational network. diff --git a/fastmri_examples/feature_varnet/pl_modules/feature_varnet_module.py b/fastmri_examples/feature_varnet/feature_varnet_module.py similarity index 82% rename from fastmri_examples/feature_varnet/pl_modules/feature_varnet_module.py rename to fastmri_examples/feature_varnet/feature_varnet_module.py index ccd99596..3533608b 100644 --- a/fastmri_examples/feature_varnet/pl_modules/feature_varnet_module.py +++ b/fastmri_examples/feature_varnet/feature_varnet_module.py @@ -1,13 +1,29 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + import math from argparse import ArgumentParser import torch -torch.set_float32_matmul_precision('high') + +torch.set_float32_matmul_precision("high") import torch.nn as nn from fastmri.pl_modules.mri_module import MriModule from fastmri.losses import SSIMLoss from fastmri.data.transforms import center_crop_to_smallest, center_crop -from fastmri.models import FIVarNet, IFVarNet, FeatureVarNet_sh_w, FeatureVarNet_n_sh_w, AttentionFeatureVarNet_n_sh_w, E2EVarNet +from fastmri.models import ( + FIVarNet, + IFVarNet, + FeatureVarNet_sh_w, + FeatureVarNet_n_sh_w, + AttentionFeatureVarNet_n_sh_w, + E2EVarNet, +) + class FIVarNetModule(MriModule): def __init__( @@ -33,7 +49,12 @@ def forward(self, masked_kspace, mask, num_low_frequencies): return self.fi_varnet(masked_kspace, mask, num_low_frequencies) def training_step(self, batch, batch_idx): - output = self.fi_varnet(batch.masked_kspace,batch.mask,batch.num_low_frequencies,crop_size=batch.crop_size) + output = self.fi_varnet( + batch.masked_kspace, + batch.mask, + batch.num_low_frequencies, + crop_size=batch.crop_size, + ) target, output = center_crop_to_smallest(batch.target, output) loss = self.loss( output.unsqueeze(1), target.unsqueeze(1).float(), data_range=batch.max_value @@ -47,7 +68,12 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx): self.log(f"grads/{name}", torch.norm(param.grad)) def validation_step(self, batch, batch_idx): - output = self.fi_varnet(batch.masked_kspace,batch.mask,batch.num_low_frequencies,crop_size=batch.crop_size) + output = self.fi_varnet( + batch.masked_kspace, + batch.mask, + batch.num_low_frequencies, + crop_size=batch.crop_size, + ) target, output = center_crop_to_smallest(batch.target, output) return { "batch_idx": batch_idx, @@ -57,12 +83,19 @@ def validation_step(self, batch, batch_idx): "output": output, "target": target, "val_loss": self.loss( - output.unsqueeze(1), target.unsqueeze(1).float(), data_range=batch.max_value + output.unsqueeze(1), + target.unsqueeze(1).float(), + data_range=batch.max_value, ), } def test_step(self, batch, batch_idx): - output = self.fi_varnet(batch.masked_kspace,batch.mask,batch.num_low_frequencies,crop_size=batch.crop_size) + output = self.fi_varnet( + batch.masked_kspace, + batch.mask, + batch.num_low_frequencies, + crop_size=batch.crop_size, + ) if output.shape[-1] < batch.crop_size[1]: crop_size = (output.shape[-1], output.shape[-1]) else: @@ -76,6 +109,7 @@ def test_step(self, batch, batch_idx): def configure_optimizers(self): cosine_steps = self.max_steps - self.cosine_decay_start + def step_fn(step): if step < self.cosine_decay_start: return min(step / self.ramp_steps, 1.0) diff --git a/fastmri_examples/feature_varnet/pl_modules/__init__.py b/fastmri_examples/feature_varnet/pl_modules/__init__.py deleted file mode 100644 index 7aaea906..00000000 --- a/fastmri_examples/feature_varnet/pl_modules/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ - -""" -Copyright (c) Facebook, Inc. and its affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. -""" - -from .feature_varnet_module import FIVarNetModule diff --git a/fastmri_examples/feature_varnet/requirements.txt b/fastmri_examples/feature_varnet/requirements.txt deleted file mode 100644 index b4d5d2bb..00000000 --- a/fastmri_examples/feature_varnet/requirements.txt +++ /dev/null @@ -1,11 +0,0 @@ - -numpy==1.18.5 -scikit_image==0.16.2 -torchvision==0.8.1 -torch==1.7.0 -runstats==2.0.0 -pytorch_lightning==1.0.6 -h5py==2.10.0 -PyYAML==5.4.1 -torchmetrics==0.3.2 -wandb==0.12.7 diff --git a/fastmri_examples/feature_varnet/train_feature_varnet.py b/fastmri_examples/feature_varnet/train_feature_varnet.py index 8f66f5f0..1d6054b5 100644 --- a/fastmri_examples/feature_varnet/train_feature_varnet.py +++ b/fastmri_examples/feature_varnet/train_feature_varnet.py @@ -1,25 +1,42 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + import os import torch -torch.set_float32_matmul_precision('high') + +torch.set_float32_matmul_precision("high") import pathlib from argparse import ArgumentParser from pathlib import Path from typing import Optional import pytorch_lightning as pl from pytorch_lightning.loggers import TensorBoardLogger -from fastmri.models.feature_varnet import FIVarNet, IFVarNet, FeatureVarNet_sh_w, FeatureVarNet_n_sh_w, E2EVarNet, AttentionFeatureVarNet_n_sh_w -from pl_modules import FIVarNetModule +from fastmri.models.feature_varnet import ( + FIVarNet, + IFVarNet, + FeatureVarNet_sh_w, + FeatureVarNet_n_sh_w, + E2EVarNet, + AttentionFeatureVarNet_n_sh_w, +) from fastmri.data.subsample import create_mask_for_mask_type from fastmri.data.transforms import VarNetDataTransform from fastmri.data.mri_data import fetch_dir from fastmri.pl_modules.data_module import FastMriDataModule +from .feature_varnet_module import FIVarNetModule import subprocess + def check_gpu_availability(): command = "nvidia-smi --query-gpu=index --format=csv,noheader | wc -l" output = subprocess.check_output(command, shell=True).decode("utf-8").strip() return int(output) + def reload_state_dict( module: FIVarNetModule, fname: Path, module_name: str = "fi_varnet." ): @@ -30,6 +47,7 @@ def reload_state_dict( module.fi_varnet.load_state_dict(state_dict) return module + def fetch_model(args, acceleration): if args.varnet_type == "fi_varnet": print(f"BUILDING FI VARNET, chans={args.chans}") @@ -52,7 +70,9 @@ def fetch_model(args, acceleration): acceleration=acceleration, ) elif args.varnet_type == "attention_feature_varnet_sh_w": - print(f"BUILDING ATTENTION FEATURE VARNET WITH WEIGHT SHARING, chans={args.chans}") + print( + f"BUILDING ATTENTION FEATURE VARNET WITH WEIGHT SHARING, chans={args.chans}" + ) return AttentionFeatureVarNet_n_sh_w( num_cascades=args.num_cascades, pools=args.pools, @@ -91,6 +111,7 @@ def fetch_model(args, acceleration): else: raise ValueError("Unrecognized varnet_type") + def cli_main(args): pl.seed_everything(args.seed) @@ -142,7 +163,10 @@ def cli_main(args): else: raise ValueError(f"unrecognized mode {args.mode}") -def build_args(model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool = True): + +def build_args( + model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool = True +): parser = ArgumentParser() path_config = pathlib.Path("./fastmri_dirs.yaml") backend = "ddp" @@ -179,7 +203,14 @@ def build_args(model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool ) parser.add_argument( "--varnet_type", - choices=("fi_varnet","if_varnet","feature_varnet_sh_w","feature_varnet_n_sh_w","attention_feature_varnet_sh_w","e2e_varnet"), + choices=( + "fi_varnet", + "if_varnet", + "feature_varnet_sh_w", + "feature_varnet_n_sh_w", + "attention_feature_varnet_sh_w", + "e2e_varnet", + ), default="fi_varnet", type=str, help="Type of VarNet to use", @@ -191,17 +222,19 @@ def build_args(model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool if args.mode == "test" or args.mode == "test_val": num_gpus = 1 if args.varnet_type == "e2e_varnet": - default_root_dir = (fetch_dir("log_path", path_config) / "e2e_varnet") + default_root_dir = fetch_dir("log_path", path_config) / "e2e_varnet" if args.varnet_type == "fi_varnet": - default_root_dir = (fetch_dir("log_path", path_config) / "fi_varnet") + default_root_dir = fetch_dir("log_path", path_config) / "fi_varnet" if args.varnet_type == "if_varnet": - default_root_dir = (fetch_dir("log_path", path_config) / "if_varnet") + default_root_dir = fetch_dir("log_path", path_config) / "if_varnet" elif args.varnet_type == "feature_varnet_sh_w": - default_root_dir = (fetch_dir("log_path", path_config) / "feature_varnet_sh_w") + default_root_dir = fetch_dir("log_path", path_config) / "feature_varnet_sh_w" elif args.varnet_type == "feature_varnet_n_sh_w": - default_root_dir = (fetch_dir("log_path", path_config) / "feature_varnet_n_sh_w") + default_root_dir = fetch_dir("log_path", path_config) / "feature_varnet_n_sh_w" elif args.varnet_type == "attention_feature_varnet_sh_w": - default_root_dir = (fetch_dir("log_path", path_config) / "attention_feature_varnet_sh_w") + default_root_dir = ( + fetch_dir("log_path", path_config) / "attention_feature_varnet_sh_w" + ) parser.set_defaults( data_path=data_path, # path to fastMRI data @@ -221,7 +254,7 @@ def build_args(model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool sens_chans=8, # number of top-level channels for sense est. U-Net lr=0.0003, # Adam learning rate ramp_steps=7500, - cosine_decay_start=150000,#150000, + cosine_decay_start=150000, # 150000, weight_decay=0.0, # weight regularization strength ) parser = pl.Trainer.add_argparse_args(parser) @@ -233,7 +266,7 @@ def build_args(model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool seed=42, # random seed # deterministic=True, # makes things slower, but deterministic default_root_dir=default_root_dir, # directory for logs and checkpoints - max_steps=210000,#210000, # number of steps for 50 knee epochs + max_steps=210000, # 210000, # number of steps for 50 knee epochs detect_anomaly=False, gradient_clip_val=1.0, ) @@ -262,9 +295,11 @@ def build_args(model_name: Optional[str] = "VarNet DDP x4", cluster_launch: bool args.resume_from_checkpoint = str(ckpt_list[-1]) return args + def run_cli(): args = build_args(cluster_launch=True) cli_main(args) + if __name__ == "__main__": run_cli() From dba55c6c8a872a9f23c99449a8f3c445f9aa2b16 Mon Sep 17 00:00:00 2001 From: GiannakopoulosIlias Date: Wed, 12 Jun 2024 15:47:00 -0400 Subject: [PATCH 3/6] Re-adding fastmri/models/feature_varnet.py to ensure CI passes --- fastmri/models/feature_varnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastmri/models/feature_varnet.py b/fastmri/models/feature_varnet.py index 20d637f5..f9334566 100644 --- a/fastmri/models/feature_varnet.py +++ b/fastmri/models/feature_varnet.py @@ -938,7 +938,7 @@ def forward( kspace_pred = cascade( kspace_pred, feature_image.ref_kspace, mask, feature_image.sens_maps ) - # Return Final Image + # Divide with k-space factor and Return Final Image kspace_pred = kspace_pred / self.kspace_mult_factor return rss(complex_abs(ifft2c(kspace_pred)), dim=1) From 5abfbb48254c62390e9f50df17fed64d0aa06bf4 Mon Sep 17 00:00:00 2001 From: GiannakopoulosIlias Date: Wed, 19 Jun 2024 11:32:10 -0400 Subject: [PATCH 4/6] Re-adding fastmri/models/feature_varnet.py to ensure CI passes - reinstalled black 22.3.0 --- fastmri/models/feature_varnet.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/fastmri/models/feature_varnet.py b/fastmri/models/feature_varnet.py index f9334566..343841d3 100644 --- a/fastmri/models/feature_varnet.py +++ b/fastmri/models/feature_varnet.py @@ -53,9 +53,9 @@ def image_uncrop(image: Tensor, original_image: Tensor) -> Tensor: pad_height_left, pad_width = _calc_uncrop(image.shape[-1], in_shape[-1]) try: - original_image[..., pad_height_top:pad_height, pad_height_left:pad_width] = ( - image[...] - ) + original_image[ + ..., pad_height_top:pad_height, pad_height_left:pad_width + ] = image[...] except RuntimeError: print(f"in_shape: {in_shape}, image shape: {image.shape}") raise @@ -553,7 +553,6 @@ def forward(self, image: torch.Tensor) -> torch.Tensor: class NormUnet(nn.Module): - def __init__( self, chans: int, @@ -644,7 +643,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Norm1DUnet(nn.Module): - def __init__( self, chans: int, From 84ce0233058bb6741cf68852a9c3cf0f064cb246 Mon Sep 17 00:00:00 2001 From: GiannakopoulosIlias Date: Thu, 20 Jun 2024 11:32:09 -0400 Subject: [PATCH 5/6] Re-adding fastmri/models/feature_varnet.py to ensure compatibility with mypy 1.1.1 --- fastmri/models/feature_varnet.py | 124 ++++++++++++++++++++----------- 1 file changed, 82 insertions(+), 42 deletions(-) diff --git a/fastmri/models/feature_varnet.py b/fastmri/models/feature_varnet.py index 343841d3..acc0f527 100644 --- a/fastmri/models/feature_varnet.py +++ b/fastmri/models/feature_varnet.py @@ -11,7 +11,6 @@ import torch.nn as nn from torch import Tensor -torch.set_float32_matmul_precision("high") import torch.nn.functional as F import torch.distributed as dist import numpy as np @@ -53,9 +52,18 @@ def image_uncrop(image: Tensor, original_image: Tensor) -> Tensor: pad_height_left, pad_width = _calc_uncrop(image.shape[-1], in_shape[-1]) try: - original_image[ - ..., pad_height_top:pad_height, pad_height_left:pad_width - ] = image[...] + if len(in_shape) == 2: # Assuming 2D images + original_image[pad_height_top:pad_height, pad_height_left:pad_width] = image + elif len(in_shape) == 3: # Assuming 3D images with channels + original_image[ + :, pad_height_top:pad_height, pad_height_left:pad_width + ] = image + elif len(in_shape) == 4: # Assuming 4D images with batch size + original_image[ + :, :, pad_height_top:pad_height, pad_height_left:pad_width + ] = image + else: + raise RuntimeError(f"Unsupported tensor shape: {in_shape}") except RuntimeError: print(f"in_shape: {in_shape}, image shape: {image.shape}") raise @@ -120,6 +128,7 @@ def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]: return mean, variance +""" class RunningChannelStats(nn.Module): def __init__(self, chans: int, eps: float = 1e-14, freeze_step: int = 20000): super().__init__() @@ -159,16 +168,17 @@ def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: run_var = self.vars.clone().view(1, -1, 1, 1) + self.eps return run_mean, run_var +""" class FeatureImage(NamedTuple): features: Tensor - sens_maps: Tensor = None + sens_maps: Optional[Tensor] = None crop_size: Optional[Tuple[int, int]] = None - means: Tensor = None - variances: Tensor = None - mask: Tensor = None - ref_kspace: Tensor = None + means: Optional[Tensor] = None + variances: Optional[Tensor] = None + mask: Optional[Tensor] = None + ref_kspace: Optional[Tensor] = None beta: Optional[Tensor] = None gamma: Optional[Tensor] = None @@ -279,7 +289,7 @@ def forward(self, x: Tensor, accel: int) -> Tensor: h_ = x h_ = self.norm(h_) - pos_enc = self.get_positional_encodings(x.shape[2], x.shape[3], h_.device) + pos_enc = self.get_positional_encodings(x.shape[2], x.shape[3], h_.device.type) h_ = h_ + pos_enc @@ -434,13 +444,15 @@ def __init__( ) if output_bias: - self.final_conv = nn.Conv2d( - in_channels=chans, - out_channels=out_chans, - kernel_size=1, - stride=1, - padding=0, - bias=True, + self.final_conv = nn.Sequential( + nn.Conv2d( + in_channels=chans, + out_channels=out_chans, + kernel_size=1, + stride=1, + padding=0, + bias=True, + ) ) else: self.final_conv = nn.Sequential( @@ -491,15 +503,24 @@ def __init__( if child is not None: self.downsample = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) - self.upsample = TransposeConvBlock( - in_chans=child.out_planes, out_chans=out_planes - ) + if isinstance(child, UnetLevel): # Ensure child is an instance of UnetLevel + self.upsample = TransposeConvBlock( + in_chans=child.out_planes, out_chans=out_planes + ) + else: + raise TypeError("Child must be an instance of UnetLevel") + self.right_block = ConvBlock( in_chans=2 * out_planes, out_chans=out_planes, drop_prob=drop_prob ) def down_up(self, image: Tensor) -> Tensor: - return self.upsample(self.child(self.downsample(image))) + if self.child is None: + raise ValueError("self.child is None, cannot call down_up.") + downsampled = self.downsample(image) + child_output = self.child(downsampled) + upsampled = self.upsample(child_output) + return upsampled def forward(self, image: Tensor) -> Tensor: image = self.left_block(image) @@ -879,7 +900,7 @@ def __init__( self.cascades = nn.Sequential(*cascades) self.norm_fn = NormStats() - def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + def _decode_output(self, feature_image: FeatureImage) -> Tensor: image = self.decoder( self.decode_norm(feature_image.features), means=feature_image.means, @@ -897,7 +918,7 @@ def _encode_input( sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) image = sens_reduce(masked_kspace, sens_maps) # detect FLAIR 203 - if image.shape[-1] < crop_size[1]: + if crop_size is not None and image.shape[-1] < crop_size[1]: crop_size = (image.shape[-1], image.shape[-1]) means, variances = self.norm_fn(image) features = self.encoder(image, means=means, variances=variances) @@ -937,8 +958,12 @@ def forward( kspace_pred, feature_image.ref_kspace, mask, feature_image.sens_maps ) # Divide with k-space factor and Return Final Image - kspace_pred = kspace_pred / self.kspace_mult_factor - return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + kspace_pred = ( + kspace_pred / self.kspace_mult_factor + ) # Ensure kspace_pred is a Tensor + return rss( + complex_abs(ifft2c(kspace_pred)), dim=1 + ) # Ensure kspace_pred is a Tensor class IFVarNet(nn.Module): @@ -991,7 +1016,7 @@ def __init__( self.cascades = nn.Sequential(*cascades) self.norm_fn = NormStats() - def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + def _decode_output(self, feature_image: FeatureImage) -> Tensor: image = self.decoder( self.decode_norm(feature_image.features), means=feature_image.means, @@ -1009,7 +1034,7 @@ def _encode_input( ) -> FeatureImage: image = sens_reduce(masked_kspace, sens_maps) # detect FLAIR 203 - if image.shape[-1] < crop_size[1]: + if crop_size is not None and image.shape[-1] < crop_size[1]: crop_size = (image.shape[-1], image.shape[-1]) means, variances = self.norm_fn(image) features = self.encoder(image, means=means, variances=variances) @@ -1049,9 +1074,12 @@ def forward( ) feature_image = self.cascades(feature_image) kspace_pred = self._decode_output(feature_image) - kspace_pred = kspace_pred / self.kspace_mult_factor - - return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + kspace_pred = ( + kspace_pred / self.kspace_mult_factor + ) # Ensure kspace_pred is a Tensor + return rss( + complex_abs(ifft2c(kspace_pred)), dim=1 + ) # Ensure kspace_pred is a Tensor class FeatureVarNet_sh_w(nn.Module): @@ -1097,7 +1125,7 @@ def __init__( self.cascades = nn.Sequential(*cascades) self.norm_fn = NormStats() - def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + def _decode_output(self, feature_image: FeatureImage) -> Tensor: image = self.decoder( self.decode_norm(feature_image.features), means=feature_image.means, @@ -1115,7 +1143,7 @@ def _encode_input( sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) image = sens_reduce(masked_kspace, sens_maps) # detect FLAIR 203 - if image.shape[-1] < crop_size[1]: + if crop_size is not None and image.shape[-1] < crop_size[1]: crop_size = (image.shape[-1], image.shape[-1]) means, variances = self.norm_fn(image) features = self.encoder(image, means=means, variances=variances) @@ -1150,8 +1178,12 @@ def forward( # Find last k-space kspace_pred = self._decode_output(feature_image) # Return Final Image - kspace_pred = kspace_pred / self.kspace_mult_factor - return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + kspace_pred = ( + kspace_pred / self.kspace_mult_factor + ) # Ensure kspace_pred is a Tensor + return rss( + complex_abs(ifft2c(kspace_pred)), dim=1 + ) # Ensure kspace_pred is a Tensor class FeatureVarNet_n_sh_w(nn.Module): @@ -1197,7 +1229,7 @@ def __init__( self.cascades = nn.Sequential(*cascades) self.norm_fn = NormStats() - def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + def _decode_output(self, feature_image: FeatureImage) -> Tensor: image = self.decoder( self.decode_norm(feature_image.features), means=feature_image.means, @@ -1215,7 +1247,7 @@ def _encode_input( sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) image = sens_reduce(masked_kspace, sens_maps) # detect FLAIR 203 - if image.shape[-1] < crop_size[1]: + if crop_size is not None and image.shape[-1] < crop_size[1]: crop_size = (image.shape[-1], image.shape[-1]) means, variances = self.norm_fn(image) features = self.encoder(image, means=means, variances=variances) @@ -1250,8 +1282,12 @@ def forward( # Find last k-space kspace_pred = self._decode_output(feature_image) # Return Final Image - kspace_pred = kspace_pred / self.kspace_mult_factor - return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + kspace_pred = ( + kspace_pred / self.kspace_mult_factor + ) # Ensure kspace_pred is a Tensor + return rss( + complex_abs(ifft2c(kspace_pred)), dim=1 + ) # Ensure kspace_pred is a Tensor class AttentionFeatureVarNet_n_sh_w(nn.Module): @@ -1300,7 +1336,7 @@ def __init__( self.cascades = nn.Sequential(*cascades) self.norm_fn = NormStats() - def _decode_output(self, feature_image: FeatureImage) -> Tuple[Tensor, Tensor]: + def _decode_output(self, feature_image: FeatureImage) -> Tensor: image = self.decoder( self.decode_norm(feature_image.features), means=feature_image.means, @@ -1318,7 +1354,7 @@ def _encode_input( sens_maps = self.sens_net(masked_kspace, mask, num_low_frequencies) image = sens_reduce(masked_kspace, sens_maps) # detect FLAIR 203 - if image.shape[-1] < crop_size[1]: + if crop_size is not None and image.shape[-1] < crop_size[1]: crop_size = (image.shape[-1], image.shape[-1]) means, variances = self.norm_fn(image) features = self.encoder(image, means=means, variances=variances) @@ -1353,8 +1389,12 @@ def forward( # Find last k-space kspace_pred = self._decode_output(feature_image) # Return Final Image - kspace_pred = kspace_pred / self.kspace_mult_factor - return rss(complex_abs(ifft2c(kspace_pred)), dim=1) + kspace_pred = ( + kspace_pred / self.kspace_mult_factor + ) # Ensure kspace_pred is a Tensor + return rss( + complex_abs(ifft2c(kspace_pred)), dim=1 + ) # Ensure kspace_pred is a Tensor class E2EVarNet(nn.Module): From 11c6cc096d3fd0f64d1e1e92c216d18e812229e8 Mon Sep 17 00:00:00 2001 From: GiannakopoulosIlias Date: Fri, 28 Jun 2024 14:30:55 -0400 Subject: [PATCH 6/6] Change fastMRI version --- fastmri/models/feature_varnet.py | 44 -------------------------------- 1 file changed, 44 deletions(-) diff --git a/fastmri/models/feature_varnet.py b/fastmri/models/feature_varnet.py index acc0f527..fee94505 100644 --- a/fastmri/models/feature_varnet.py +++ b/fastmri/models/feature_varnet.py @@ -127,50 +127,6 @@ def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]: return mean, variance - -""" -class RunningChannelStats(nn.Module): - def __init__(self, chans: int, eps: float = 1e-14, freeze_step: int = 20000): - super().__init__() - - self.means: Tensor - self.vars: Tensor - self.current_step: Tensor - self.eps = eps - self.chans = chans - self.freeze_step = freeze_step - - self.register_buffer("current_step", torch.zeros(1, dtype=torch.int)) - self.register_buffer("means", torch.zeros(chans)) - self.register_buffer("vars", torch.zeros(chans)) - - def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]: - if image.shape[1] != self.chans: - raise ValueError("Invalid channel number.") - - if self.current_step < self.freeze_step and self.training: - stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1) - mean = stats.mean(1) - var = stats.var(1, unbiased=True) - - var = var / dist.get_world_size() - self.means.copy_(self.means + (mean - self.means) / (self.current_step + 1)) - self.vars.copy_(self.vars + (var - self.vars) / (self.current_step + 1)) - - self.current_step += 1 - - if self.current_step == 0 and not self.training: - stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1) - run_mean = stats.mean(1).view(1, -1, 1, 1) - run_var = (stats.var(1, unbiased=True) + self.eps).view(1, -1, 1, 1) - else: - run_mean = self.means.clone().view(1, -1, 1, 1) - run_var = self.vars.clone().view(1, -1, 1, 1) + self.eps - - return run_mean, run_var -""" - - class FeatureImage(NamedTuple): features: Tensor sens_maps: Optional[Tensor] = None