Merge branch 'fix-doc'
guanhuaw committed Apr 13, 2022
2 parents f79f46a + f82bb1c commit e2c0083
Showing 8 changed files with 19 additions and 18 deletions.
README.md: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ Documentation: https://mirtorch.readthedocs.io/en/latest/
 We recommend [pre-installing `PyTorch`](https://pytorch.org/) first.
 To install the `MIRTorch` package, after cloning the repo, run `pip install -e .`

-`requirements.txt` details the package dependencies.
+`requirements.txt` details the package dependencies. We recommend installing [pytorch_wavelets](https://github.com/fbcotter/pytorch_wavelets) directly from the source code rather than via `pip`.

 ### Features
mirtorch/alg/cg.py: 2 additions & 0 deletions
@@ -47,6 +47,7 @@ def cg_block(x0, b, A, tol, max_iter, alert, eval_func, P):
             num_loop = num_loop + 1
             if eval_func is not None:
                 saved.append(eval_func(rk))
+            if alert:
                 print("Residual at %dth iter: %10.3e." % (num_loop, rktrk))
         else:
             r0 = b - A * x0
@@ -74,6 +75,7 @@ def cg_block(x0, b, A, tol, max_iter, alert, eval_func, P):
             num_loop = num_loop + 1
             if eval_func is not None:
                 saved.append(eval_func(rk))
+            if alert:
                 print("Residual at %dth iter: %10.3e." % (num_loop, rktzk))

         if eval_func is not None:
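Both hunks wrap the residual printout in the new `if alert:` guard, so per-iteration logging becomes opt-in. As a sketch of the same pattern in context, here is a minimal standalone conjugate-gradient solver (a hypothetical `solve_cg`, not the MIRTorch API; it assumes `A` is a callable applying a symmetric positive-definite operator):

```python
import torch

def solve_cg(A, b, x0, tol=1e-6, max_iter=100, alert=False):
    """Solve A x = b, where A is a callable applying an SPD operator."""
    x = x0.clone()
    r = b - A(x)                      # initial residual
    p = r.clone()                     # initial search direction
    rtr = torch.sum(r * r)
    for num_loop in range(1, max_iter + 1):
        Ap = A(p)
        step = rtr / torch.sum(p * Ap)
        x = x + step * p
        r = r - step * Ap
        rtr_new = torch.sum(r * r)
        if alert:                     # mirrors the guard added above
            print("Residual at %dth iter: %10.3e." % (num_loop, rtr_new.item()))
        if torch.sqrt(rtr_new) < tol:
            break
        p = r + (rtr_new / rtr) * p   # update the search direction
        rtr = rtr_new
    return x
```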
mirtorch/alg/fbpd.py: 5 additions & 6 deletions
@@ -49,15 +49,14 @@ def __init__(self,
         self.G = G
         self.G_norm = G_norm
         if tau is None:
-            self.tau = 2 / (g_L + 2)
+            self.tau = 2.0 / (g_L + 2.0)
         else:
             self.tau = tau
-        self.sigma = (1 / self.tau - self.g_L / 2) / self.G_norm
+        self.sigma = (1.0 / self.tau - self.g_L / 2.0) / self.G_norm
         self.eval_func = eval_func

     def run(self,
-            x0: torch.Tensor,
-            u0: torch.Tensor):
+            x0: torch.Tensor):
         """
         Run the algorithm
         Args:
@@ -67,12 +66,12 @@ def run(self,
             xk: results
             saved: (optional) a list of intermediate results, calculated by the eval_func.
         """
-        uold = u0
+        uold = self.G * x0
         xold = x0
         if self.eval_func is not None:
             saved = []
         for i in range(1, self.max_iter + 1):
-            xold_bar = self.g_grad(xold - self.G.H * uold)
+            xold_bar = self.g_grad(xold) - self.G.H * uold
             xnew = self.f_prox(xold - self.tau * xold_bar, self.tau)
             uold_bar = self.G * (2 * xnew - xold)
             unew = self.h_conj_prox(uold + self.sigma * uold_bar, self.sigma)
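With these fixes, `run` needs only the primal initialization: the dual variable starts at `G x0` instead of being a second argument, and the gradient of `g` is evaluated at `xold` alone before the adjoint term is subtracted. A minimal sketch of the corrected loop, using a dense real matrix `G` (so `G.T` plays the role of `G.H`) and hypothetical `f_prox`, `g_grad`, `h_conj_prox` callables in place of the MIRTorch objects:

```python
import torch

def fbpd_sketch(x0, f_prox, g_grad, h_conj_prox, G, g_L, max_iter=100):
    G_norm = torch.linalg.matrix_norm(G, 2).item() ** 2  # squared spectral norm
    tau = 2.0 / (g_L + 2.0)
    sigma = (1.0 / tau - g_L / 2.0) / G_norm
    x = x0
    u = G @ x0                           # dual variable built from x0
    for _ in range(max_iter):
        x_bar = g_grad(x) - G.T @ u      # gradient taken at x alone
        x_new = f_prox(x - tau * x_bar, tau)
        u = h_conj_prox(u + sigma * (G @ (2 * x_new - x)), sigma)
        x = x_new
    return x
```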
mirtorch/linear/__init__.py: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 __all__ = ['LinearMap',
            'Identity',
            'Diff1d',
-           'Diffnd'
+           'Diffnd',
            'Diff2dframe', 'Diff3dframe',
            'Diag',
            'Convolve1d', 'Convolve2d', 'Convolve3d',
mirtorch/linear/basics.py: 2 additions & 2 deletions
@@ -52,9 +52,9 @@ class Diffnd(LinearMap):
     def __init__(self,
                  size_in: Sequence[int],
                  dims: Sequence[int]):
-        self.dims = dims
+        self.dims = sorted(dims)
         size_out = copy.copy(list(size_in))
-        size_out[self.dims[0]] = size_out[self.dims[0]] * 2
+        size_out[self.dims[0]] = size_out[self.dims[0]] * len(dims)
         super(Diffnd, self).__init__(size_in, size_out)

     def _apply(self, x):
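The corrected size reflects that `Diffnd` stacks one finite-difference map per entry of `dims` along the first (sorted) dim, so the output grows there by a factor of `len(dims)` rather than a fixed 2. A toy check of just this shape bookkeeping (mirroring the lines above, not the full operator):

```python
def diffnd_size_out(size_in, dims):
    dims = sorted(dims)
    size_out = list(size_in)
    size_out[dims[0]] = size_out[dims[0]] * len(dims)  # one map per dim
    return size_out

print(diffnd_size_out([3, 64, 64, 64], dims=[2, 1, 3]))  # [3, 192, 64, 64]
```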
mirtorch/linear/mri.py: 3 additions & 3 deletions
@@ -70,13 +70,13 @@ def __init__(self,
             size_out = list(smaps.shape)
             dims = tuple(np.arange(2, len(smaps.shape)))
             self.masks = masks.unsqueeze(1)
-            assert size_in == list(self.masks.shape), "size of sensitivity maps and mask not matched!"
+            assert smaps.shape[2:] == masks.shape[1:], "sizes of sensitivity maps and mask do not match!"
         else:
             size_in = list(smaps.shape[1:])
             size_out = list(smaps.shape)
             dims = tuple(np.arange(1, len(smaps.shape)))
             self.masks = masks
-            assert size_in == list(self.masks.shape), "size of sensitivity maps and mask not matched!"
+            assert smaps.shape[1:] == masks.shape, "sizes of sensitivity maps and mask do not match!"
         super(Sense, self).__init__(size_in, size_out)
         self.norm = norm
         self.dims = dims
@@ -246,7 +246,7 @@ def _apply_adjoint(self, y: Tensor) -> Tensor:
         if self.batchmode:
             return self.AT(y, self.traj, smaps=self.smaps, norm=self.norm)
         else:
-            return self.AT(y.unsqueeze(0), self.traj, smaps=self.smaps.unsqueeze(0), norm=self.norm).squeeze(0)
+            return self.AT(y.unsqueeze(0), self.traj, smaps=self.smaps.unsqueeze(0), norm=self.norm).squeeze(0).squeeze(0)


 class NuSenseFrame(LinearMap):
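The new asserts compare only the spatial dimensions, since `smaps` carries a coil axis (and, in batch mode, a batch axis) that the sampling mask does not. A small sketch of the assumed shape convention in batch mode:

```python
import torch

smaps = torch.randn(4, 8, 320, 320, dtype=torch.complex64)  # (batch, coil, H, W)
masks = torch.ones(4, 320, 320)                             # (batch, H, W)
assert smaps.shape[2:] == masks.shape[1:]   # only spatial sizes must agree
masks = masks.unsqueeze(1)                  # (4, 1, 320, 320): broadcasts over coils
```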
mirtorch/linear/util.py: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def forward(ctx, x, dim, mode):
         return finitediff(x, dim, mode)

     @staticmethod
-    def backward(ctx, dx, mode):
+    def backward(ctx, dx):
         return finitediff_adj(dx, ctx.dim, ctx.mode), None, None
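Dropping the stray `mode` argument brings `backward` in line with the `torch.autograd.Function` contract: it receives exactly one gradient per output of `forward` and returns one entry per input, with `None` for the non-tensor `dim` and `mode`. Reassembled for context (the class name here is illustrative, and `finitediff`/`finitediff_adj` are assumed to be the helpers defined elsewhere in `util.py`):

```python
import torch

class FiniteDiffFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, mode):
        ctx.dim = dim                    # stash non-tensor args for backward
        ctx.mode = mode
        return finitediff(x, dim, mode)

    @staticmethod
    def backward(ctx, dx):               # one grad for the single output
        # one return value per forward input: grad w.r.t. x, None for dim and mode
        return finitediff_adj(dx, ctx.dim, ctx.mode), None, None
```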
mirtorch/prox/prox.py: 4 additions & 4 deletions
@@ -84,7 +84,7 @@ class L1Regularizer(Prox):
     def __init__(self, Lambda, T: LinearMap = None, P: LinearMap = None):
         super().__init__(T, P)
         self.Lambda = float(Lambda)
-        assert self.Lambda > 0, "alpha should be greater than 0"
+        assert self.Lambda >= 0, f"Lambda should be nonnegative; here it is {self.Lambda}."
         if P is not None:
             self.Lambda = P(Lambda * torch.ones(P.size_in))

@@ -94,7 +94,7 @@ def _softshrink(self, x, lambd):
         return mask1.float() * (-lambd) + mask1.float() * x + mask2.float() * lambd + mask2.float() * x

     def _apply(self, v, alpha):
-        assert alpha >= 0, "alpha should be greater than 0"
+        assert alpha >= 0, f"alpha should be nonnegative; here it is {alpha}."
         if type(self.Lambda) is not torch.Tensor and type(alpha) is not torch.Tensor:
             # The softshrink function does not support a tensor Lambda.
             thresh = torch.nn.Softshrink(self.Lambda * alpha)

@@ -131,7 +131,7 @@ def _hardshrink(self,
         return mask1 * x + mask2 * x

     def _apply(self, v: torch.Tensor, alpha: FloatLike):
-        assert alpha >= 0, "alpha should be greater than 0"
+        assert alpha >= 0, f"alpha should be nonnegative; here it is {alpha}."
         if type(self.Lambda) is not torch.Tensor and type(alpha) is not torch.Tensor:
             # The hardshrink function does not support a tensor Lambda.
             thresh = torch.nn.Hardshrink(self.Lambda * alpha)

@@ -290,6 +290,6 @@ def __init__(self, prox: Prox):
         super().__init__()

     def _apply(self, v, alpha):
-        assert alpha > 0, "alpha should be greater than 0"
+        assert alpha > 0, f"alpha should be strictly positive; here it is {alpha}."
         return v - alpha * self.prox(v / alpha, 1 / alpha)
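Strict positivity matters in the last hunk because `Conj` evaluates the conjugate prox via the extended Moreau decomposition, prox_{alpha f*}(v) = v - alpha * prox_{f/alpha}(v / alpha), which divides by `alpha`. A quick numerical check with f = |.|, whose conjugate prox is the projection onto [-1, 1]:

```python
import torch

v = torch.tensor([-2.0, -0.5, 0.0, 0.7, 3.0])
alpha = 1.5
soft = torch.nn.Softshrink(1.0 / alpha)   # prox of (1/alpha) * |.|
conj = v - alpha * soft(v / alpha)        # Moreau decomposition
assert torch.allclose(conj, v.clamp(-1.0, 1.0))
```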
