Skip to content

Commit

Permalink
Change fastMRI version
Browse files Browse the repository at this point in the history
  • Loading branch information
GiannakopoulosIlias committed Jun 28, 2024
1 parent 84ce023 commit 11c6cc0
Showing 1 changed file with 0 additions and 44 deletions.
44 changes: 0 additions & 44 deletions fastmri/models/feature_varnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,50 +127,6 @@ def forward(self, data: Tensor) -> Tuple[Tensor, Tensor]:

return mean, variance


"""
class RunningChannelStats(nn.Module):
def __init__(self, chans: int, eps: float = 1e-14, freeze_step: int = 20000):
super().__init__()
self.means: Tensor
self.vars: Tensor
self.current_step: Tensor
self.eps = eps
self.chans = chans
self.freeze_step = freeze_step
self.register_buffer("current_step", torch.zeros(1, dtype=torch.int))
self.register_buffer("means", torch.zeros(chans))
self.register_buffer("vars", torch.zeros(chans))
def forward(self, image: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
if image.shape[1] != self.chans:
raise ValueError("Invalid channel number.")
if self.current_step < self.freeze_step and self.training:
stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1)
mean = stats.mean(1)
var = stats.var(1, unbiased=True)
var = var / dist.get_world_size()
self.means.copy_(self.means + (mean - self.means) / (self.current_step + 1))
self.vars.copy_(self.vars + (var - self.vars) / (self.current_step + 1))
self.current_step += 1
if self.current_step == 0 and not self.training:
stats = image.permute(1, 0, 2, 3).reshape(self.chans, -1)
run_mean = stats.mean(1).view(1, -1, 1, 1)
run_var = (stats.var(1, unbiased=True) + self.eps).view(1, -1, 1, 1)
else:
run_mean = self.means.clone().view(1, -1, 1, 1)
run_var = self.vars.clone().view(1, -1, 1, 1) + self.eps
return run_mean, run_var
"""


class FeatureImage(NamedTuple):
features: Tensor
sens_maps: Optional[Tensor] = None
Expand Down

0 comments on commit 11c6cc0

Please sign in to comment.