Skip to content

Commit

Permalink
adding back certain y parameters to fix scikit-learn integration
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Feb 28, 2022
1 parent b848907 commit 0ce2061
Show file tree
Hide file tree
Showing 18 changed files with 291 additions and 253 deletions.
14 changes: 7 additions & 7 deletions cca_zoo/deepmodels/dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ class DCCA(_DCCA_base):
"""

def __init__(
self,
latent_dims: int,
objective=objectives.MCCA,
encoders=None,
r: float = 0,
eps: float = 1e-5,
self,
latent_dims: int,
objective=objectives.MCCA,
encoders=None,
r: float = 0,
eps: float = 1e-5,
):
"""
Constructor class for DCCA
Expand Down Expand Up @@ -53,7 +53,7 @@ def loss(self, *args):
:return:
"""
z = self(*args)
return {'objective': self.objective.loss(*z)}
return {"objective": self.objective.loss(*z)}

def post_transform(self, z_list, train=False):
if train:
Expand Down
14 changes: 9 additions & 5 deletions cca_zoo/deepmodels/dcca_barlow_twins.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ class BarlowTwins(DCCA):
"""

def __init__(
self,
latent_dims: int,
encoders: Iterable[BaseEncoder] = [Encoder, Encoder],
lam=1,
self,
latent_dims: int,
encoders: Iterable[BaseEncoder] = [Encoder, Encoder],
lam=1,
):
"""
Constructor class for Barlow Twins
Expand Down Expand Up @@ -48,4 +48,8 @@ def loss(self, *args):
covariance = torch.sum(
torch.triu(torch.pow(cross_cov, 2), diagonal=1)
) + torch.sum(torch.tril(torch.pow(cross_cov, 2), diagonal=-1))
return {'objective': invariance + covariance, 'invariance': invariance, 'covariance': covariance}
return {
"objective": invariance + covariance,
"invariance": invariance,
"covariance": covariance,
}
2 changes: 1 addition & 1 deletion cca_zoo/deepmodels/dcca_noi.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def loss(self, *args):
covariance_inv = [mat_pow(cov, -0.5, self.eps) for cov in self.covs]
preds = [z_ @ covariance_inv[i] for i, z_ in enumerate(z_copy)]
loss = self.mse(z[0], preds[1]) + self.mse(z[1], preds[0])
return {'objective':loss}
return {"objective": loss}

def _update_covariances(self, *z, train=True):
b = z[0].shape[0]
Expand Down
20 changes: 10 additions & 10 deletions cca_zoo/deepmodels/dcca_sdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ class DCCA_SDL(DCCA_NOI):
"""

def __init__(
self,
latent_dims: int,
N: int,
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-5,
shared_target: bool = False,
lam=0.5,
self,
latent_dims: int,
N: int,
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-5,
shared_target: bool = False,
lam=0.5,
):
"""
Constructor class for DCCA
Expand Down Expand Up @@ -65,7 +65,7 @@ def loss(self, *args):
SDL_loss = self._sdl_loss(self.covs)
l2_loss = F.mse_loss(z[0], z[1])
loss = l2_loss + self.lam * SDL_loss
return {'objective': loss, 'l2':l2_loss, 'sdl':SDL_loss}
return {"objective": loss, "l2": l2_loss, "sdl": SDL_loss}

def _sdl_loss(self, covs):
loss = 0
Expand Down
24 changes: 14 additions & 10 deletions cca_zoo/deepmodels/dccae.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ class DCCAE(_DCCA_base):
"""

def __init__(
self,
latent_dims: int,
objective=objectives.MCCA,
encoders=None,
decoders=None,
r: float = 0,
eps: float = 1e-5,
lam=0.5,
self,
latent_dims: int,
objective=objectives.MCCA,
encoders=None,
decoders=None,
r: float = 0,
eps: float = 1e-5,
lam=0.5,
):
"""
:param latent_dims: # latent dimensions
Expand Down Expand Up @@ -75,8 +75,12 @@ def loss(self, *args):
z = self(*args)
recon = self._decode(*z)
recon_loss = self._recon_loss(args[: len(recon)], recon)
return {'objective': self.lam * recon_loss + (1 - self.lam) * self.objective.loss(*z),
'reconstruction': recon_loss, 'correlation loss': self.objective.loss(*z)}
return {
"objective": self.lam * recon_loss
+ (1 - self.lam) * self.objective.loss(*z),
"reconstruction": recon_loss,
"correlation loss": self.objective.loss(*z),
}

@staticmethod
def _recon_loss(x, recon):
Expand Down
58 changes: 33 additions & 25 deletions cca_zoo/deepmodels/dvcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ class DVCCA(_DCCA_base):
"""

def __init__(
self,
latent_dims: int,
encoders=None,
decoders=None,
private_encoders: Iterable[BaseEncoder] = None,
self,
latent_dims: int,
encoders=None,
decoders=None,
private_encoders: Iterable[BaseEncoder] = None,
):
"""
:param latent_dims: # latent dimensions
Expand Down Expand Up @@ -57,14 +57,14 @@ def forward(self, *args, mle=True):
mu = dict()
logvar = dict()
# Used when we get reconstructions
mu['shared'], logvar['shared'] = self._encode(*args)
z['shared'] = self._sample(mu['shared'], logvar['shared'], mle)
mu["shared"], logvar["shared"] = self._encode(*args)
z["shared"] = self._sample(mu["shared"], logvar["shared"], mle)
# If using single encoder repeat representation n times
if len(self.encoders) == 1:
z['shared'] = z['shared'] * len(args)
z["shared"] = z["shared"] * len(args)
if self.private_encoders:
mu['private'], logvar['private'] = self._encode_private(*args)
z['private'] = self._sample(mu['private'], logvar['private'], mle)
mu["private"], logvar["private"] = self._encode_private(*args)
z["private"] = self._sample(mu["private"], logvar["private"], mle)
return z, mu, logvar

def _sample(self, mu, logvar, mle):
Expand All @@ -78,8 +78,10 @@ def _sample(self, mu, logvar, mle):
if mle:
return mu
else:
return [dist.Normal(mu_, torch.exp(0.5 * logvar_)).rsample() for mu_, logvar_ in
zip(mu, logvar)]
return [
dist.Normal(mu_, torch.exp(0.5 * logvar_)).rsample()
for mu_, logvar_ in zip(mu, logvar)
]

def _encode(self, *args):
"""
Expand Down Expand Up @@ -114,10 +116,12 @@ def _decode(self, z):
"""
x = []
for i, decoder in enumerate(self.decoders):
if 'private' in z:
x_i = F.sigmoid(decoder(torch.cat((z['shared'][i], z['private'][i]),dim=-1)))
if "private" in z:
x_i = F.sigmoid(
decoder(torch.cat((z["shared"][i], z["private"][i]), dim=-1))
)
else:
x_i = F.sigmoid(decoder(z['shared'][i]))
x_i = F.sigmoid(decoder(z["shared"][i]))
x.append(x_i)
return x

Expand All @@ -136,20 +140,24 @@ def loss(self, *args):
"""
z, mu, logvar = self(*args, mle=False)
loss = dict()
loss['reconstruction'] = self.recon_loss(args, z)
loss['kl shared'] = self.kl_loss(mu['shared'], logvar['shared'])
if 'private' in z:
loss['kl private'] = self.kl_loss(mu['private'], logvar['private'])
loss['objective'] = torch.stack(tuple(loss.values())).sum()
loss["reconstruction"] = self.recon_loss(args, z)
loss["kl shared"] = self.kl_loss(mu["shared"], logvar["shared"])
if "private" in z:
loss["kl private"] = self.kl_loss(mu["private"], logvar["private"])
loss["objective"] = torch.stack(tuple(loss.values())).sum()
return loss

@staticmethod
def kl_loss(mu, logvar):
return torch.stack([torch.mean(
-0.5
* torch.sum(1 + logvar_ - logvar_.exp() - mu_.pow(2), dim=1),
dim=0,
) for mu_, logvar_ in zip(mu, logvar)]).sum()
return torch.stack(
[
torch.mean(
-0.5 * torch.sum(1 + logvar_ - logvar_.exp() - mu_.pow(2), dim=1),
dim=0,
)
for mu_, logvar_ in zip(mu, logvar)
]
).sum()

def recon_loss(self, x, z):
recon = self._decode(z)
Expand Down
2 changes: 1 addition & 1 deletion cca_zoo/deepmodels/splitae.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def loss(self, *args):
z = self(*args)
recon = self._decode(*z)
recon_loss = self.recon_loss(args, recon)
return {'objective':recon_loss}
return {"objective": recon_loss}

@staticmethod
def recon_loss(x, recon):
Expand Down
44 changes: 22 additions & 22 deletions cca_zoo/deepmodels/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

class CCALightning(LightningModule):
def __init__(
self,
model: _DCCA_base,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler = None,
self,
model: _DCCA_base,
optimizer: torch.optim.Optimizer = None,
lr_scheduler: torch.optim.lr_scheduler = None,
):
"""
Expand Down Expand Up @@ -49,17 +49,17 @@ def configure_optimizers(self):
def training_step(self, batch, batch_idx):
loss = self.model.loss(*batch["views"])
self.log("train", loss, prog_bar=True)
return loss['objective']
return loss["objective"]

def validation_step(self, batch, batch_idx):
loss = self.model.loss(*batch["views"])
self.log("val", loss)
return loss['objective']
return loss["objective"]

def test_step(self, batch, batch_idx):
loss = self.model.loss(*batch["views"])
self.log("test", loss)
return loss['objective']
return loss["objective"]

def on_train_epoch_end(self, unused: Optional = None) -> None:
score = self.score(self.trainer.train_dataloader, train=True).sum()
Expand All @@ -75,9 +75,9 @@ def on_validation_epoch_end(self, unused: Optional = None) -> None:
self.log("val corr", score)

def correlations(
self,
loader: torch.utils.data.DataLoader,
train=False,
self,
loader: torch.utils.data.DataLoader,
train=False,
):
"""
Expand All @@ -89,16 +89,16 @@ def correlations(
return None
all_corrs = []
for x, y in itertools.product(transformed_views, repeat=2):
all_corrs.append(np.diag(np.corrcoef(x.T, y.T)[: x.shape[1], y.shape[1]:]))
all_corrs.append(np.diag(np.corrcoef(x.T, y.T)[: x.shape[1], y.shape[1] :]))
all_corrs = np.array(all_corrs).reshape(
(len(transformed_views), len(transformed_views), -1)
)
return all_corrs

def transform(
self,
loader: torch.utils.data.DataLoader,
train=False,
self,
loader: torch.utils.data.DataLoader,
train=False,
):
"""
Expand All @@ -111,7 +111,7 @@ def transform(
views = [view.to(self.device) for view in batch["views"]]
z = self.model(*views)
if isinstance(z[0], dict):
z = z[0]['shared']
z = z[0]["shared"]
if batch_idx == 0:
z_list = [z_i.detach().cpu().numpy() for i, z_i in enumerate(z)]
else:
Expand All @@ -123,9 +123,9 @@ def transform(
return z_list

def score(
self,
loader: torch.utils.data.DataLoader,
train=False,
self,
loader: torch.utils.data.DataLoader,
train=False,
):
"""
Expand All @@ -140,13 +140,13 @@ def score(
n_views = pair_corrs.shape[0]
# sum all the pairwise correlations for each dimension. Subtract the self correlations. Divide by the number of views. Gives average correlation
dim_corrs = (
pair_corrs.sum(axis=tuple(range(pair_corrs.ndim - 1))) - n_views
) / (n_views ** 2 - n_views)
pair_corrs.sum(axis=tuple(range(pair_corrs.ndim - 1))) - n_views
) / (n_views ** 2 - n_views)
return dim_corrs

def recon(
self,
loader: torch.utils.data.DataLoader,
self,
loader: torch.utils.data.DataLoader,
):
with torch.no_grad():
for batch_idx, batch in enumerate(loader):
Expand Down
3 changes: 1 addition & 2 deletions cca_zoo/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,6 @@ def fit(self, Xs, y=None, *, groups=None, **fit_params):
self._check_refit_for_multimetric(scorers)
refit_metric = self.refit

Xs[0], y, groups = indexable(Xs[0], y, groups)
fit_params = _check_fit_params(Xs[0], fit_params)

cv_orig = check_cv(self.cv, y, classifier=is_classifier(estimator))
Expand Down Expand Up @@ -286,7 +285,7 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
("estimator", clone(base_estimator)),
]
)

pipeline.fit(np.hstack(Xs))
out = parallel(
delayed(_fit_and_score)(
pipeline,
Expand Down
Loading

0 comments on commit 0ce2061

Please sign in to comment.