diff --git a/fastmri/models/feature_varnet.py b/fastmri_examples/feature_varnet/feature_varnet.py
similarity index 99%
rename from fastmri/models/feature_varnet.py
rename to fastmri_examples/feature_varnet/feature_varnet.py
index 5f993e83..1ddb656d 100644
--- a/fastmri/models/feature_varnet.py
+++ b/fastmri_examples/feature_varnet/feature_varnet.py
@@ -5,21 +5,19 @@
 LICENSE file in the root directory of this source tree.
 """
 
-from typing import NamedTuple, Optional, Tuple, List
 import math
+from typing import List, NamedTuple, Optional, Tuple
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
-import torch.nn.functional as F
-import torch.distributed as dist
-import numpy as np
-import math
 
-from fastmri.data.transforms import center_crop, batched_mask_center
-from fastmri.fftc import ifft2c_new as ifft2c
+from fastmri.coil_combine import rss, rss_complex
+from fastmri.data.transforms import batched_mask_center, center_crop
 from fastmri.fftc import fft2c_new as fft2c
-from fastmri.coil_combine import rss_complex, rss
-from fastmri.math import complex_abs, complex_mul, complex_conj
+from fastmri.fftc import ifft2c_new as ifft2c
+from fastmri.math import complex_abs, complex_conj, complex_mul
 
 
 def image_crop(image: Tensor, crop_size: Optional[Tuple[int, int]] = None) -> Tensor:
@@ -55,9 +53,9 @@ def image_uncrop(image: Tensor, original_image: Tensor) -> Tensor:
     if len(in_shape) == 2:  # Assuming 2D images
         original_image[pad_height_top:pad_height, pad_height_left:pad_width] = image
     elif len(in_shape) == 3:  # Assuming 3D images with channels
-        original_image[
-            :, pad_height_top:pad_height, pad_height_left:pad_width
-        ] = image
+        original_image[:, pad_height_top:pad_height, pad_height_left:pad_width] = (
+            image
+        )
     elif len(in_shape) == 4:  # Assuming 4D images with batch size
         original_image[
             :, :, pad_height_top:pad_height, pad_height_left:pad_width
diff --git a/fastmri_examples/feature_varnet/feature_varnet_module.py b/fastmri_examples/feature_varnet/feature_varnet_module.py
index 3533608b..7fc33015 100644
--- a/fastmri_examples/feature_varnet/feature_varnet_module.py
+++ b/fastmri_examples/feature_varnet/feature_varnet_module.py
@@ -11,18 +11,11 @@ import torch
 
 torch.set_float32_matmul_precision("high")
 
-import torch.nn as nn
-from fastmri.pl_modules.mri_module import MriModule
+from feature_varnet import FIVarNet
+
+from fastmri.data.transforms import center_crop, center_crop_to_smallest
 from fastmri.losses import SSIMLoss
-from fastmri.data.transforms import center_crop_to_smallest, center_crop
-from fastmri.models import (
-    FIVarNet,
-    IFVarNet,
-    FeatureVarNet_sh_w,
-    FeatureVarNet_n_sh_w,
-    AttentionFeatureVarNet_n_sh_w,
-    E2EVarNet,
-)
+from fastmri.pl_modules.mri_module import MriModule
 
 
 class FIVarNetModule(MriModule):
diff --git a/fastmri_examples/feature_varnet/train_feature_varnet.py b/fastmri_examples/feature_varnet/train_feature_varnet.py
index 1d6054b5..42447c9b 100644
--- a/fastmri_examples/feature_varnet/train_feature_varnet.py
+++ b/fastmri_examples/feature_varnet/train_feature_varnet.py
@@ -6,29 +6,33 @@
 """
 
 import os
+
 import torch
 
 torch.set_float32_matmul_precision("high")
 
 import pathlib
+import subprocess
 from argparse import ArgumentParser
 from pathlib import Path
 from typing import Optional
+
 import pytorch_lightning as pl
-from pytorch_lightning.loggers import TensorBoardLogger
-from fastmri.models.feature_varnet import (
+from feature_varnet import (
+    AttentionFeatureVarNet_n_sh_w,
+    E2EVarNet,
+    FeatureVarNet_n_sh_w,
+    FeatureVarNet_sh_w,
     FIVarNet,
     IFVarNet,
-    FeatureVarNet_sh_w,
-    FeatureVarNet_n_sh_w,
-    E2EVarNet,
-    AttentionFeatureVarNet_n_sh_w,
 )
+from pytorch_lightning.loggers import TensorBoardLogger
+
+from fastmri.data.mri_data import fetch_dir
 from fastmri.data.subsample import create_mask_for_mask_type
 from fastmri.data.transforms import VarNetDataTransform
-from fastmri.data.mri_data import fetch_dir
 from fastmri.pl_modules.data_module import FastMriDataModule
+
 from .feature_varnet_module import FIVarNetModule
-import subprocess
 
 
 def check_gpu_availability():