Commit b848907

Removing unnecessary y parameter. Rejigged deep models to separate losses where they have composite parts. Big simplification to DVCCA.

jameschapman19 committed Feb 28, 2022
1 parent 16f712a commit b848907
Showing 18 changed files with 157 additions and 188 deletions.
14 changes: 7 additions & 7 deletions cca_zoo/deepmodels/dcca.py
@@ -17,12 +17,12 @@ class DCCA(_DCCA_base):
    """

    def __init__(
        self,
        latent_dims: int,
        objective=objectives.MCCA,
        encoders=None,
        r: float = 0,
        eps: float = 1e-5,
    ):
"""
Constructor class for DCCA
Expand Down Expand Up @@ -53,7 +53,7 @@ def loss(self, *args):
:return:
"""
z = self(*args)
return self.objective.loss(*z)
return {'objective': self.objective.loss(*z)}

def post_transform(self, z_list, train=False):
if train:
Expand Down
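The pattern above repeats across this commit: `loss` now returns a dict whose 'objective' entry is the term to optimise, with any composite parts exposed under their own keys for logging. A minimal sketch of a training step that consumes this convention (`model`, `views` and `opt` are illustrative names, not from the diff):

import torch

def training_step(model, views, opt):
    opt.zero_grad()
    losses = model.loss(*views)      # e.g. {'objective': tensor(...), ...}
    losses['objective'].backward()   # only the composite objective is optimised
    opt.step()
    return {k: v.item() for k, v in losses.items()}  # scalar components for logging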
10 changes: 5 additions & 5 deletions cca_zoo/deepmodels/dcca_barlow_twins.py
@@ -17,10 +17,10 @@ class BarlowTwins(DCCA):
    """

    def __init__(
        self,
        latent_dims: int,
        encoders: Iterable[BaseEncoder] = [Encoder, Encoder],
        lam=1,
    ):
        """
        Constructor class for Barlow Twins
@@ -48,4 +48,4 @@ def loss(self, *args):
        covariance = torch.sum(
            torch.triu(torch.pow(cross_cov, 2), diagonal=1)
        ) + torch.sum(torch.tril(torch.pow(cross_cov, 2), diagonal=-1))
-        return invariance + covariance
+        return {'objective': invariance + covariance, 'invariance': invariance, 'covariance': covariance}
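The two parts now returned separately correspond to the usual Barlow Twins decomposition: 'invariance' pulls the diagonal of the cross-correlation matrix towards one, while 'covariance' (computed above) pushes the off-diagonal entries towards zero. A self-contained sketch on random embeddings; the standardisation and the invariance term sit above this hunk, so their exact form here is an assumption:

import torch

N, D = 256, 32
z1, z2 = torch.randn(N, D), torch.randn(N, D)
z1 = (z1 - z1.mean(0)) / z1.std(0)   # standardise each latent dimension
z2 = (z2 - z2.mean(0)) / z2.std(0)
cross_cov = z1.T @ z2 / N            # cross-correlation matrix of the two views
invariance = torch.sum(torch.pow(1 - torch.diag(cross_cov), 2))
covariance = torch.sum(torch.triu(torch.pow(cross_cov, 2), diagonal=1)) + torch.sum(
    torch.tril(torch.pow(cross_cov, 2), diagonal=-1))
print(invariance.item(), covariance.item())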
2 changes: 1 addition & 1 deletion cca_zoo/deepmodels/dcca_noi.py
@@ -70,7 +70,7 @@ def loss(self, *args):
        covariance_inv = [mat_pow(cov, -0.5, self.eps) for cov in self.covs]
        preds = [z_ @ covariance_inv[i] for i, z_ in enumerate(z_copy)]
        loss = self.mse(z[0], preds[1]) + self.mse(z[1], preds[0])
-        return loss
+        return {'objective':loss}

    def _update_covariances(self, *z, train=True):
        b = z[0].shape[0]
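DCCA by non-linear orthogonal iterations keeps its mechanics here: each view's latent is whitened by the inverse square root of its covariance, and each view is regressed onto the other view's whitened targets. A hedged sketch of the same computation, with an explicit eigendecomposition standing in for the library's `mat_pow` helper (in the real model the covariances are running estimates and the targets are detached):

import torch

def inv_sqrt(cov, eps=1e-5):
    # symmetric inverse square root cov^(-1/2) via eigendecomposition
    vals, vecs = torch.linalg.eigh(cov)
    return vecs @ torch.diag((vals + eps).rsqrt()) @ vecs.T

N, D = 512, 10
z = [torch.randn(N, D), torch.randn(N, D)]
covs = [z_.T @ z_ / N for z_ in z]
preds = [z_ @ inv_sqrt(cov) for z_, cov in zip(z, covs)]
mse = torch.nn.functional.mse_loss
loss = mse(z[0], preds[1]) + mse(z[1], preds[0])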
21 changes: 11 additions & 10 deletions cca_zoo/deepmodels/dcca_sdl.py
@@ -15,15 +15,15 @@ class DCCA_SDL(DCCA_NOI):
    """

    def __init__(
        self,
        latent_dims: int,
        N: int,
        encoders=None,
        r: float = 0,
        rho: float = 0.2,
        eps: float = 1e-5,
        shared_target: bool = False,
        lam=0.5,
    ):
        """
        Constructor class for DCCA
@@ -64,7 +64,8 @@ def loss(self, *args):
        self._update_covariances(*z, train=self.training)
        SDL_loss = self._sdl_loss(self.covs)
        l2_loss = F.mse_loss(z[0], z[1])
-        return l2_loss + self.lam * SDL_loss
+        loss = l2_loss + self.lam * SDL_loss
+        return {'objective': loss, 'l2':l2_loss, 'sdl':SDL_loss}

    def _sdl_loss(self, covs):
        loss = 0
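Here the dict separates the two ingredients of the SDL objective: an L2 alignment term between the paired latents, plus `lam` times the soft decorrelation loss. `_sdl_loss` itself is cut off below; as a rough illustration only, penalties in this family punish the off-diagonal entries of each view's latent covariance, along the lines of:

import torch

def soft_decorrelation(cov):
    # Assumption: the actual _sdl_loss body is outside this hunk; this is the
    # generic idea, an L1-style penalty on off-diagonal covariance entries.
    mask = 1 - torch.eye(cov.shape[0])
    return (cov * mask).abs().mean()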
19 changes: 10 additions & 9 deletions cca_zoo/deepmodels/dccae.py
@@ -17,14 +17,14 @@ class DCCAE(_DCCA_base):
    """

    def __init__(
        self,
        latent_dims: int,
        objective=objectives.MCCA,
        encoders=None,
        decoders=None,
        r: float = 0,
        eps: float = 1e-5,
        lam=0.5,
    ):
        """
        :param latent_dims: # latent dimensions
@@ -75,7 +75,8 @@ def loss(self, *args):
        z = self(*args)
        recon = self._decode(*z)
        recon_loss = self._recon_loss(args[: len(recon)], recon)
-        return self.lam * recon_loss + self.objective.loss(*z)
+        return {'objective': self.lam * recon_loss + (1 - self.lam) * self.objective.loss(*z),
+                'reconstruction': recon_loss, 'correlation loss': self.objective.loss(*z)}

    @staticmethod
    def _recon_loss(x, recon):
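Note this hunk changes behaviour as well as packaging: the objective moves from lam * recon + corr to the convex combination lam * recon + (1 - lam) * corr, so `lam` now trades the two terms off directly, with lam=1 a pure autoencoder and lam=0 pure DCCA. A toy illustration with made-up loss values:

import torch

recon_loss = torch.tensor(0.8)
corr_loss = torch.tensor(-1.5)   # CCA-style objectives are minimised, so more negative is better
for lam in (0.0, 0.5, 1.0):
    print(lam, (lam * recon_loss + (1 - lam) * corr_loss).item())
# lam=0.0 -> correlation only; lam=1.0 -> reconstruction only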
148 changes: 56 additions & 92 deletions cca_zoo/deepmodels/dvcca.py
@@ -23,11 +23,11 @@ class DVCCA(_DCCA_base):
    """

    def __init__(
        self,
        latent_dims: int,
        encoders=None,
        decoders=None,
        private_encoders: Iterable[BaseEncoder] = None,
    ):
        """
        :param latent_dims: # latent dimensions
@@ -53,25 +53,33 @@ def forward(self, *args, mle=True):
        :param mle:
        :return:
        """
+        z = dict()
+        mu = dict()
+        logvar = dict()
        # Used when we get reconstructions
-        mu, logvar = self._encode(*args)
-        if mle:
-            z = mu
-        else:
-            z_dist = dist.Normal(mu, torch.exp(0.5 * logvar))
-            z = z_dist.rsample()
+        mu['shared'], logvar['shared'] = self._encode(*args)
+        z['shared'] = self._sample(mu['shared'], logvar['shared'], mle)
        # If using single encoder repeat representation n times
        if len(self.encoders) == 1:
-            z = z * len(args)
+            z['shared'] = z['shared'] * len(args)
        if self.private_encoders:
-            mu_p, logvar_p = self._encode_private(*args)
-            if mle:
-                z_p = mu_p
-            else:
-                z_dist = dist.Normal(mu_p, torch.exp(0.5 * logvar_p))
-                z_p = z_dist.rsample()
-            z = [torch.cat((z_, z_p_), dim=-1) for z_, z_p_ in zip(z, z_p)]
-        return z
+            mu['private'], logvar['private'] = self._encode_private(*args)
+            z['private'] = self._sample(mu['private'], logvar['private'], mle)
+        return z, mu, logvar
+
+    def _sample(self, mu, logvar, mle):
+        """
+        :param mu:
+        :param logvar:
+        :param mle: whether to return the maximum likelihood (i.e. mean) or whether to sample
+        :return: a sample from latent vector
+        """
+        if mle:
+            return mu
+        else:
+            return [dist.Normal(mu_, torch.exp(0.5 * logvar_)).rsample() for mu_, logvar_ in
+                    zip(mu, logvar)]

    def _encode(self, *args):
        """

@@ -106,92 +114,48 @@ def _decode(self, z):
"""
x = []
for i, decoder in enumerate(self.decoders):
x_i = F.sigmoid(decoder(z))
if 'private' in z:
x_i = F.sigmoid(decoder(torch.cat((z['shared'][i], z['private'][i]),dim=-1)))
else:
x_i = F.sigmoid(decoder(z['shared'][i]))
x.append(x_i)
return x

def recon(self, *args):
def recon(self, *args, mle=True):
"""
:param args:
:return:
"""
z = self(*args)
return [self._decode(z_) for z_ in z]
z, _, _ = self(*args, mle=mle)
return self._decode(z)

    def loss(self, *args):
        """
        :param args:
        :return:
        """
-        mus, logvars = self._encode(*args)
-        if self.private_encoders:
-            mus_p, logvars_p = self._encode_private(*args)
-            losses = [
-                self.vcca_private_loss(
-                    *args, mu=mu, logvar=logvar, mu_p=mu_p, logvar_p=logvar_p
-                )
-                for (mu, logvar, mu_p, logvar_p) in zip(mus, logvars, mus_p, logvars_p)
-            ]
-        else:
-            losses = [
-                self.vcca_loss(*args, mu=mu, logvar=logvar)
-                for (mu, logvar) in zip(mus, logvars)
-            ]
-        return torch.stack(losses).mean()
+        z, mu, logvar = self(*args, mle=False)
+        loss = dict()
+        loss['reconstruction'] = self.recon_loss(args, z)
+        loss['kl shared'] = self.kl_loss(mu['shared'], logvar['shared'])
+        if 'private' in z:
+            loss['kl private'] = self.kl_loss(mu['private'], logvar['private'])
+        loss['objective'] = torch.stack(tuple(loss.values())).sum()
+        return loss

-    def vcca_loss(self, *args, mu, logvar):
-        """
-        :param args:
-        :param mu:
-        :param logvar:
-        :return:
-        """
-        batch_n = mu.shape[0]
-        z_dist = dist.Normal(mu, torch.exp(0.5 * logvar))
-        z = z_dist.rsample()
-        kl = torch.mean(
-            -0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0
-        )
-        recons = self._decode(z)
-        bces = torch.stack(
-            [
-                F.binary_cross_entropy(recon, arg, reduction="mean")
-                for recon, arg in zip(recons, args)
-            ]
-        ).sum()
-        return kl + bces
-
-    def vcca_private_loss(self, *args, mu, logvar, mu_p, logvar_p):
-        """
-        :param args:
-        :param mu:
-        :param logvar:
-        :return:
-        """
-        batch_n = mu.shape[0]
-        z_dist = dist.Normal(mu, torch.exp(0.5 * logvar))
-        z = z_dist.rsample()
-        z_p_dist = dist.Normal(mu_p, torch.exp(0.5 * logvar_p))
-        z_p = z_p_dist.rsample()
-        kl_p = torch.stack(
-            [
-                torch.mean(
-                    -0.5
-                    * torch.sum(1 + logvar_p - logvar_p.exp() - mu_p.pow(2), dim=1),
-                    dim=0,
-                )
-                for i, _ in enumerate(self.private_encoders)
-            ]
-        ).sum()
-        kl = torch.mean(
-            -0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0
-        )
-        z_combined = torch.cat([z, z_p], dim=-1)
-        recon = self._decode(z_combined)
-        bces = torch.stack(
-            [
-                F.binary_cross_entropy(recon[i], args[i], reduction="sum") / batch_n
-                for i, _ in enumerate(self.decoders)
-            ]
-        ).sum()
-        return kl + kl_p + bces
+    @staticmethod
+    def kl_loss(mu, logvar):
+        return torch.stack([torch.mean(
+            -0.5
+            * torch.sum(1 + logvar_ - logvar_.exp() - mu_.pow(2), dim=1),
+            dim=0,
+        ) for mu_, logvar_ in zip(mu, logvar)]).sum()
+
+    def recon_loss(self, x, z):
+        recon = self._decode(z)
+        return torch.stack(
+            [
+                F.binary_cross_entropy(recon, arg, reduction="mean")
+                for recon, arg in zip(recon, x)
+            ]
+        ).sum()
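Taken together, the simplified DVCCA API: `forward` returns (z, mu, logvar), each a dict keyed by 'shared' (plus 'private' when private encoders are configured), and `loss` sums one reconstruction term with a KL term per latent group. A hedged usage sketch; the model construction is assumed, only the call pattern comes from the diff:

import torch

# assume `model` is a constructed DVCCA instance with two encoders/decoders
x1, x2 = torch.rand(32, 784), torch.rand(32, 784)
z, mu, logvar = model(x1, x2, mle=False)  # dicts keyed 'shared' (+ 'private')
loss = model.loss(x1, x2)                 # 'reconstruction', 'kl shared', ('kl private',) 'objective'
loss['objective'].backward()
recons = model.recon(x1, x2)              # decodes from the posterior means (mle=True default)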
2 changes: 1 addition & 1 deletion cca_zoo/deepmodels/splitae.py
@@ -55,7 +55,7 @@ def loss(self, *args):
        z = self(*args)
        recon = self._decode(*z)
        recon_loss = self.recon_loss(args, recon)
-        return recon_loss
+        return {'objective':recon_loss}

    @staticmethod
    def recon_loss(x, recon):