Skip to content

Commit

Permalink
Pulled out callbacks for more customisation.
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed May 31, 2022
1 parent a5beb39 commit 55addf4
Show file tree
Hide file tree
Showing 39 changed files with 791 additions and 750 deletions.
48 changes: 24 additions & 24 deletions cca_zoo/data/simulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@


def generate_covariance_data(
n: int,
view_features: List[int],
latent_dims: int = 1,
view_sparsity: List[Union[int, float]] = None,
correlation: Union[List[float], float] = 1,
structure: Union[str, List[str]] = None,
sigma: Union[List[float], float] = None,
decay: float = 0.5,
positive=None,
random_state: Union[int, np.random.RandomState] = None,
n: int,
view_features: List[int],
latent_dims: int = 1,
view_sparsity: List[Union[int, float]] = None,
correlation: Union[List[float], float] = 1,
structure: Union[str, List[str]] = None,
sigma: Union[List[float], float] = None,
decay: float = 0.5,
positive=None,
random_state: Union[int, np.random.RandomState] = None,
):
"""
Function to generate CCA dataset with defined population correlations
Expand Down Expand Up @@ -59,7 +59,7 @@ def generate_covariance_data(
covs = []
true_features = []
for view_p, sparsity, view_structure, view_positive, view_sigma in zip(
view_features, view_sparsity, structure, positive, sigma
view_features, view_sparsity, structure, positive, sigma
):
# Covariance Bit
if view_structure == "identity":
Expand Down Expand Up @@ -111,12 +111,12 @@ def generate_covariance_data(
# Cross Bit
cross += covs[i] @ A @ covs[j]
cov[
splits[i]: splits[i] + view_features[i],
splits[j]: splits[j] + view_features[j],
splits[i] : splits[i] + view_features[i],
splits[j] : splits[j] + view_features[j],
] = cross
cov[
splits[j]: splits[j] + view_features[j],
splits[i]: splits[i] + view_features[i],
splits[j] : splits[j] + view_features[j],
splits[i] : splits[i] + view_features[i],
] = cross.T

X = np.zeros((n, sum(view_features)))
Expand All @@ -136,12 +136,12 @@ def generate_covariance_data(


def generate_simple_data(
n: int,
view_features: List[int],
view_sparsity: List[Union[int, float]] = None,
eps: float = 0,
transform=False,
random_state=None,
n: int,
view_features: List[int],
view_sparsity: List[Union[int, float]] = None,
eps: float = 0,
transform=False,
random_state=None,
):
"""
Simple latent variable model to generate data with one latent factor
Expand Down Expand Up @@ -202,9 +202,9 @@ def _gaussian(x, mu, sig, dn):
:param dn:
"""
return (
np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sig, 2.0)))
* dn
/ (np.sqrt(2 * np.pi) * sig)
np.exp(-np.power(x - mu, 2.0) / (2 * np.power(sig, 2.0)))
* dn
/ (np.sqrt(2 * np.pi) * sig)
)


Expand Down
66 changes: 33 additions & 33 deletions cca_zoo/deepmodels/_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ def forward(self, x):

class Encoder(BaseEncoder):
def __init__(
self,
latent_dims: int,
variational: bool = False,
feature_size: int = 784,
layer_sizes: tuple = None,
activation=nn.LeakyReLU(),
dropout=0,
self,
latent_dims: int,
variational: bool = False,
feature_size: int = 784,
layer_sizes: tuple = None,
activation=nn.LeakyReLU(),
dropout=0,
):
super(Encoder, self).__init__(latent_dims, variational=variational)
if layer_sizes is None:
Expand Down Expand Up @@ -80,12 +80,12 @@ def forward(self, x):

class Decoder(BaseDecoder):
def __init__(
self,
latent_dims: int,
feature_size: int = 784,
layer_sizes: tuple = None,
activation=nn.LeakyReLU(),
dropout=0,
self,
latent_dims: int,
feature_size: int = 784,
layer_sizes: tuple = None,
activation=nn.LeakyReLU(),
dropout=0,
):
super(Decoder, self).__init__(latent_dims)
if layer_sizes is None:
Expand All @@ -109,16 +109,16 @@ def forward(self, x):

class CNNEncoder(BaseEncoder):
def __init__(
self,
latent_dims: int,
variational: bool = False,
feature_size: Iterable = (28, 28),
channels: tuple = None,
kernel_sizes: tuple = None,
stride: tuple = None,
padding: tuple = None,
activation=nn.LeakyReLU(),
dropout=0,
self,
latent_dims: int,
variational: bool = False,
feature_size: Iterable = (28, 28),
channels: tuple = None,
kernel_sizes: tuple = None,
stride: tuple = None,
padding: tuple = None,
activation=nn.LeakyReLU(),
dropout=0,
):
super(CNNEncoder, self).__init__(latent_dims, variational=variational)
if channels is None:
Expand Down Expand Up @@ -187,15 +187,15 @@ def forward(self, x):

class CNNDecoder(BaseDecoder):
def __init__(
self,
latent_dims: int,
feature_size: Iterable = (28, 28),
channels: tuple = None,
kernel_sizes=None,
strides=None,
paddings=None,
activation=nn.LeakyReLU(),
dropout=0,
self,
latent_dims: int,
feature_size: Iterable = (28, 28),
channels: tuple = None,
kernel_sizes=None,
strides=None,
paddings=None,
activation=nn.LeakyReLU(),
dropout=0,
):
super(CNNDecoder, self).__init__(latent_dims)
if channels is None:
Expand All @@ -210,7 +210,7 @@ def __init__(
current_channels = 1
current_size = feature_size[0]
for l_id, (channel, kernel, stride, padding) in reversed(
list(enumerate(zip(channels, kernel_sizes, strides, paddings)))
list(enumerate(zip(channels, kernel_sizes, strides, paddings)))
):
conv_layers.append(
torch.nn.Sequential(
Expand Down
42 changes: 21 additions & 21 deletions cca_zoo/deepmodels/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@

class _BaseDeep(pl.LightningModule):
def __init__(
self,
latent_dims: int,
optimizer="adam",
scheduler=None,
lr=1e-3,
weight_decay=0,
extra_optimizer_kwargs=None,
max_epochs=1000,
min_lr=1e-9,
lr_decay_steps=None,
correlation=True,
*args,
**kwargs,
self,
latent_dims: int,
optimizer="adam",
scheduler=None,
lr=1e-3,
weight_decay=0,
extra_optimizer_kwargs=None,
max_epochs=1000,
min_lr=1e-9,
lr_decay_steps=None,
correlation=True,
*args,
**kwargs,
):
super().__init__()
if extra_optimizer_kwargs is None:
Expand Down Expand Up @@ -80,9 +80,9 @@ def test_step(self, batch, batch_idx):
return loss["objective"]

def transform(
self,
loader: torch.utils.data.DataLoader,
train=False,
self,
loader: torch.utils.data.DataLoader,
train=False,
):
"""
:param loader: a dataloader that matches the structure of that used for training
Expand Down Expand Up @@ -140,21 +140,21 @@ def configure_callbacks(self):


class _GenerativeMixin:
def recon_loss(self, x, recon, loss='mse', reduction='mean', **kwargs):
def recon_loss(self, x, recon, loss="mse", reduction="mean", **kwargs):
if loss == "mse":
return self.mse_loss(x, recon, reduction=reduction)
elif loss == "bce":
return self.mse_loss(x, recon, reduction=reduction)
elif loss == "nll":
return self.mse_loss(x, recon, reduction=reduction)

def mse_loss(self, x, recon, reduction='mean'):
def mse_loss(self, x, recon, reduction="mean"):
return F.mse_loss(recon, x, reduction=reduction)

def bce_loss(self, x, recon, reduction='mean'):
def bce_loss(self, x, recon, reduction="mean"):
return F.binary_cross_entropy(recon, x, reduction=reduction)

def nll_loss(self, x, recon, reduction='mean'):
def nll_loss(self, x, recon, reduction="mean"):
return F.nll_loss(recon, x, reduction=reduction)

@staticmethod
Expand Down Expand Up @@ -189,4 +189,4 @@ def collate_all(z, z_):
z = [np.append(z[i], z_i, axis=0) for i, z_i in enumerate(z_)]
else:
z = np.append(z, z_, axis=0)
return z
return z
4 changes: 2 additions & 2 deletions cca_zoo/deepmodels/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class GenerativeCallback(Callback):
def on_validation_epoch_end(
self, trainer: Trainer, pl_module: LightningModule
) -> None:
if hasattr(pl_module, 'img_dim') and pl_module.img_dim is not None:
if hasattr(pl_module, "img_dim") and pl_module.img_dim is not None:
z = dict()
z["shared"] = Variable(torch.randn(64, pl_module.latent_dims))
if pl_module.private_encoders:
Expand All @@ -51,4 +51,4 @@ def on_validation_epoch_end(
)
pl_module.logger.experiment.add_image(
"generated_images_2", grid2, pl_module.current_epoch
)
)
26 changes: 13 additions & 13 deletions cca_zoo/deepmodels/_dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ class DCCA(_BaseDeep):
"""

def __init__(
self,
latent_dims: int,
objective=_objectives.MCCA,
encoders=None,
r: float = 0,
eps: float = 1e-5,
**kwargs,
self,
latent_dims: int,
objective=_objectives.MCCA,
encoders=None,
r: float = 0,
eps: float = 1e-5,
**kwargs,
):
"""
Constructor class for DCCA
Expand Down Expand Up @@ -66,9 +66,9 @@ def post_transform(self, z, train=False):
return z

def batch_correlation(
self,
loader: torch.utils.data.DataLoader,
train=False,
self,
loader: torch.utils.data.DataLoader,
train=False,
):
"""
Expand All @@ -80,15 +80,15 @@ def batch_correlation(
pair_corrs = []
for x, y in itertools.product(transformed_views, repeat=2):
pair_corrs.append(
np.diag(np.corrcoef(x.T, y.T)[: x.shape[1], y.shape[1]:])
np.diag(np.corrcoef(x.T, y.T)[: x.shape[1], y.shape[1] :])
)
pair_corrs = np.array(pair_corrs).reshape(
(len(transformed_views), len(transformed_views), -1)
)
n_views = pair_corrs.shape[0]
dim_corrs = (
pair_corrs.sum(axis=tuple(range(pair_corrs.ndim - 1))) - n_views
) / (n_views ** 2 - n_views)
pair_corrs.sum(axis=tuple(range(pair_corrs.ndim - 1))) - n_views
) / (n_views ** 2 - n_views)
return dim_corrs

def configure_callbacks(self):
Expand Down
10 changes: 5 additions & 5 deletions cca_zoo/deepmodels/_dcca_barlow_twins.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class BarlowTwins(DCCA):
"""

def __init__(
self,
latent_dims: int,
encoders=None,
lam=1,
**kwargs,
self,
latent_dims: int,
encoders=None,
lam=1,
**kwargs,
):
"""
Constructor class for Barlow Twins
Expand Down
20 changes: 10 additions & 10 deletions cca_zoo/deepmodels/_dcca_noi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class DCCA_NOI(DCCA):
"""

def __init__(
self,
latent_dims: int,
N: int,
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-9,
shared_target: bool = False,
**kwargs,
self,
latent_dims: int,
N: int,
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-9,
shared_target: bool = False,
**kwargs,
):
"""
Constructor class for DCCA
Expand Down Expand Up @@ -61,7 +61,7 @@ def forward(self, views, **kwargs):
z = []
# Users architecture + final linear layer
for i, (encoder, linear_layer) in enumerate(
zip(self.encoders, self.linear_layers)
zip(self.encoders, self.linear_layers)
):
z.append(linear_layer(encoder(views[i])))
return z
Expand Down
Loading

0 comments on commit 55addf4

Please sign in to comment.