Merge pull request #9 from ziatdinovmax/master
Add joint (discrete + continuous) VAE
ziatdinovmax authored Feb 26, 2021
2 parents 88cd357 + e2f9ad1 commit 7fb6dcd
Showing 28 changed files with 1,906 additions and 563 deletions.
4 changes: 0 additions & 4 deletions .gitpod.yml
@@ -1,6 +1,2 @@
image:
file: .gitpod.Dockerfile
ports:
- port: 8080
tasks:
- command: jupyter lab --no-browser --port 8080
2 changes: 1 addition & 1 deletion atomai/__version__.py
@@ -1 +1 @@
version = '0.6.2'
version = '0.6.5'
9 changes: 7 additions & 2 deletions atomai/losses_metrics/__init__.py
@@ -1,4 +1,9 @@
from .losses import dice_loss, focal_loss, select_loss, vae_loss, rvae_loss
from .losses import dice_loss, focal_loss, select_loss
from .metrics import IoU
from .vi_losses import (joint_rvae_loss, joint_vae_loss, kld_discrete,
kld_normal, kld_rot, rvae_loss, vae_loss,
reconstruction_loss)

__all__ = ['focal_loss', 'dice_loss', 'select_seg_loss', 'IoU']
__all__ = ['focal_loss', 'dice_loss', 'select_loss', "vae_loss",
"rvae_loss", "joint_vae_loss", "joint_rvae_loss", "IoU",
"kld_normal", "kld_discrete", "kld_rot", "reconstruction_loss"]
74 changes: 1 addition & 73 deletions atomai/losses_metrics/losses.py
@@ -4,8 +4,6 @@
Custom Pytorch loss functions
"""
from typing import Tuple
import numpy as np
import torch
import torch.nn.functional as F

@@ -89,79 +87,9 @@ def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return (1 - dice_loss)


def vae_loss(reconstr_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor
) -> torch.Tensor:
"""
Calculates ELBO
"""
batch_dim = x.size(0)
if len(args) == 2:
z_mean, z_logsd = args
else:
z_mean = z_logsd = torch.zeros((batch_dim, 1))
z_sd = torch.exp(z_logsd)
if reconstr_loss == "mse":
reconstr_error = -0.5 * torch.sum(
(x_reconstr.reshape(batch_dim, -1) - x.reshape(batch_dim, -1))**2, 1).mean()
elif reconstr_loss == "ce":
px_size = np.product(in_dim)
rs = (np.product(in_dim[:2]),)
if len(in_dim) == 3:
rs = rs + (in_dim[-1],)
reconstr_error = -F.binary_cross_entropy_with_logits(
x_reconstr.reshape(-1, *rs), x.reshape(-1, *rs)) * px_size
else:
raise NotImplementedError("Reconstruction loss must be 'mse' or 'ce'")
kl_z = -z_logsd + 0.5 * z_sd**2 + 0.5 * z_mean**2 - 0.5
kl_z = torch.sum(kl_z, 1).mean()
return reconstr_error - kl_z


def rvae_loss(reconstr_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
**kwargs: float) -> torch.Tensor:
"""
Calculates ELBO
"""
batch_dim = x.size(0)
if len(args) == 2:
z_mean, z_logsd = args
else:
z_mean = z_logsd = torch.zeros((batch_dim, 1))
phi_prior = kwargs.get("phi_prior", 0.1)
z_sd = torch.exp(z_logsd)
phi_sd, phi_logsd = z_sd[:, 0], z_logsd[:, 0]
z_mean, z_sd, z_logsd = z_mean[:, 1:], z_sd[:, 1:], z_logsd[:, 1:]
batch_dim = x.size(0)
if reconstr_loss == "mse":
reconstr_error = -0.5 * torch.sum(
(x_reconstr.view(batch_dim, -1) - x.view(batch_dim, -1))**2, 1).mean()
elif reconstr_loss == "ce":
px_size = np.product(in_dim)
rs = (np.product(in_dim[:2]),)
if len(in_dim) == 3:
rs = rs + (in_dim[-1],)
reconstr_error = -F.binary_cross_entropy_with_logits(
x_reconstr.view(-1, *rs), x.view(-1, *rs)) * px_size
else:
raise NotImplementedError("Reconstruction loss must be 'mse' or 'ce'")
kl_rot = (-phi_logsd + np.log(phi_prior) +
phi_sd**2 / (2 * phi_prior**2) - 0.5)
kl_z = -z_logsd + 0.5 * z_sd**2 + 0.5 * z_mean**2 - 0.5
kl_div = (kl_rot + torch.sum(kl_z, 1)).mean()
return reconstr_error - kl_div


def select_loss(loss: str, nb_classes: int = None):
"""
Selects loss for a semantic segmentation model training
Selects loss for DCNN model training
"""
if loss == 'ce' and nb_classes is None:
raise ValueError("For cross-entropy loss function, you must" +
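
vae_loss and rvae_loss are removed from losses.py here but not dropped: they are reimplemented in the new vi_losses.py (below) and re-exported by the package __init__ shown above, so they can still be imported from atomai.losses_metrics. A minimal sketch of the relocated vae_loss in use; the tensor shapes are illustrative assumptions, not values taken from this diff:

import torch
from atomai.losses_metrics import vae_loss

x = torch.rand(8, 28, 28)                                # input images
x_rec = torch.rand(8, 28, 28)                            # decoder output
z_mean, z_logsd = torch.randn(8, 2), torch.randn(8, 2)   # encoded mean and log-SD
elbo = vae_loss("mse", (28, 28), x, x_rec, z_mean, z_logsd)
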
254 changes: 254 additions & 0 deletions atomai/losses_metrics/vi_losses.py
@@ -0,0 +1,254 @@
"""
vi_losses.py
=========
Custom loss functions for Variational Autoencoders (VAEs)
"""
from typing import Tuple, List, Union, Optional
import numpy as np
import torch
import torch.nn.functional as F


def reconstruction_loss(loss_type: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
logits: bool = True,
) -> torch.Tensor:
"""
Computes reconstruction loss (mse or cross-entropy)
without mean reduction (used in VAE objectives)
"""
batch_dim = x.size(0)
if loss_type == "mse":
reconstr_loss = 0.5 * torch.sum(
(x_reconstr.reshape(batch_dim, -1) - x.reshape(batch_dim, -1))**2, 1)
elif loss_type == "ce":
rs = (np.product(in_dim[:2]),)
if len(in_dim) == 3:
rs = rs + (in_dim[-1],)
xe = (F.binary_cross_entropy_with_logits if
logits else F.binary_cross_entropy)
reconstr_loss = xe(x_reconstr.reshape(-1, *rs), x.reshape(-1, *rs),
reduction='none').sum(-1)
else:
raise NotImplementedError("Reconstruction loss must be 'mse' or 'ce'")
return reconstr_loss


def kld_normal(q_param: Tuple[torch.Tensor],
p_param: Optional[Tuple[torch.Tensor]] = None
) -> torch.Tensor:
"""
Kullback–Leibler (KL) divergence between two normal distributions
"""
mu_1, log_sd_1 = q_param
sd_1 = torch.exp(log_sd_1)
if p_param is None:
# KL divergence b/w normal and standard normal distributions
kl = -log_sd_1 + 0.5 * sd_1**2 + 0.5 * mu_1**2 - 0.5
else:
mu_2, log_sd_2 = p_param
sd_2 = torch.exp(log_sd_2)
# KL divergence b/w two normal distributions
kl = (log_sd_2 - log_sd_1 +
0.5 * (sd_1**2 + (mu_1 - mu_2)**2) / sd_2**2 - 0.5)
return torch.sum(kl, -1)


def kld_discrete(alpha: torch.Tensor):
"""
Calculates the KL divergence between a Gumbel-Softmax distribution
and a uniform categorical distribution.
Args:
alpha:
Parameters of the Gumbel-Softmax distribution.
"""
eps = 1e-12
cat_dim = alpha.size(-1)
h1 = torch.log(alpha + eps)
h2 = np.log(1. / cat_dim + eps)
kld_loss = torch.mean(torch.sum(alpha * (h1 - h2), dim=1), dim=0)
return kld_loss.view(1)


def kld_rot(phi_prior: torch.Tensor, phi_logsd: torch.Tensor) -> torch.Tensor:
"""
Kullback–Leibler (KL) divergence for rotation latent variable
"""
phi_sd = torch.exp(phi_logsd)
kl_rot = (-phi_logsd + np.log(phi_prior) +
phi_sd**2 / (2 * phi_prior**2) - 0.5)
return kl_rot


def vae_loss(recon_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
) -> torch.Tensor:
"""
Calculates ELBO
"""
if len(args) == 2:
q_param = args
else:
raise ValueError(
"Pass mean and SD values of encoded distribution as args")
likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean()
kl_z = kld_normal(q_param).mean()
return likelihood - kl_z


def rvae_loss(recon_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
**kwargs: float) -> torch.Tensor:
"""
Calculates ELBO
"""
if len(args) == 2:
z_mean, z_logsd = args
else:
raise ValueError(
"Pass mean and SD values of encoded distribution as args")
phi_prior = kwargs.get("phi_prior", 0.1)
b1, b2 = kwargs.get("b1", 1), kwargs.get("b2", 1)
phi_logsd = z_logsd[:, 0]
z_mean, z_logsd = z_mean[:, 1:], z_logsd[:, 1:]
likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean()
kl_rot = kld_rot(phi_prior, phi_logsd).mean()
kl_z = kld_normal([z_mean, z_logsd]).mean()
kl_div = (b1*kl_z + b2 * kl_rot)
return likelihood - kl_div


def joint_vae_loss(recon_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
**kwargs: Union[List, int],
) -> torch.Tensor:
"""
Calculates joint ELBO for continuous and discrete variables
"""
if len(args) == 3:
z_mean, z_logsd, alphas = args
else:
raise ValueError(
"Pass continuous (mean, SD) and discrete (alphas) values" +
"of encoded distributions as args")

cont_capacity = kwargs.get("cont_capacity", [0.0, 5.0, 25000, 30])
disc_capacity = kwargs.get("disc_capacity", [0.0, 5.0, 25000, 30])
num_iter = kwargs.get("num_iter", 0)
disc_dims = [a.size(1) for a in alphas]

# Calculate reconstruction loss term
likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean()

# Calculate KL term for continuous latent variables
kl_cont_loss = kld_normal([z_mean, z_logsd]).mean()
# Calculate KL term for discrete latent variables
kl_disc = [kld_discrete(alpha) for alpha in alphas]
kl_disc_loss = torch.sum(torch.cat(kl_disc))

# Apply information capacity terms to continuous and discrete channels
cargs = [kl_cont_loss, kl_disc_loss, cont_capacity,
disc_capacity, disc_dims, num_iter]
cont_capacity_loss, disc_capacity_loss = infocapacity(*cargs)

return likelihood - cont_capacity_loss - disc_capacity_loss


def joint_rvae_loss(recon_loss: str,
in_dim: Tuple[int],
x: torch.Tensor,
x_reconstr: torch.Tensor,
*args: torch.Tensor,
**kwargs: float) -> torch.Tensor:
"""
Calculates joint ELBO for continuous and discrete variables
"""
if len(args) == 3:
z_mean, z_logsd, alphas = args
else:
raise ValueError(
"Pass continuous (mean, SD) and discrete (alphas) values" +
"of encoded distributions as args")

phi_prior = kwargs.get("phi_prior", 0.1)
klrot_cap = kwargs.get("klrot_cap", True)
cont_capacity = kwargs.get("cont_capacity", [0.0, 5.0, 25000, 30])
disc_capacity = kwargs.get("disc_capacity", [0.0, 5.0, 25000, 30])
num_iter = kwargs.get("num_iter", 0)

# Calculate reconstruction loss term
likelihood = -reconstruction_loss(recon_loss, in_dim, x, x_reconstr).mean()

# Calculate KL term for continuous latent variables
phi_logsd = z_logsd[:, 0] # rotation
z_mean, z_logsd = z_mean[:, 1:], z_logsd[:, 1:] # image content
kl_rot = kld_rot(phi_prior, phi_logsd).mean()
kl_z = kld_normal([z_mean, z_logsd]).mean()
if klrot_cap:
kl_cont_loss = kl_z + kl_rot
else: # no capacity limit on KL term associated with rotations
kl_cont_loss = kl_z

# Calculate KL term for discrete latent variables
disc_dims = [a.size(1) for a in alphas]
kl_disc = [kld_discrete(alpha) for alpha in alphas]
kl_disc_loss = torch.sum(torch.cat(kl_disc))

# Apply information capacity terms to continuous and discrete channels
cargs = [kl_cont_loss, kl_disc_loss, cont_capacity,
disc_capacity, disc_dims, num_iter]
cont_capacity_loss, disc_capacity_loss = infocapacity(*cargs)
if not klrot_cap:
cont_capacity_loss = cont_capacity_loss + kl_rot

return likelihood - cont_capacity_loss - disc_capacity_loss


def infocapacity(kl_cont_loss: torch.Tensor,
kl_disc_loss: torch.Tensor,
cont_capacity: List[float],
disc_capacity: List[float],
disc_dims: List[int],
num_iter: int) -> torch.Tensor:
"""
Controls information capacity of the continuous and discrete loss
(based on https://arxiv.org/pdf/1804.00104.pdf &
https://github.com/Schlumberger/joint-vae/blob/master/jointvae/training.py)
"""
# Linearly increase capacity of continuous channels
cont_min, cont_max, cont_num_iters, cont_gamma = cont_capacity
# Increase continuous capacity without exceeding cont_max
cont_cap_current = (cont_max - cont_min) * num_iter
cont_cap_current = cont_cap_current / float(cont_num_iters) + cont_min
cont_cap_current = min(cont_cap_current, cont_max)
# Calculate continuous capacity loss
cont_capacity_loss = cont_gamma*torch.abs(cont_cap_current - kl_cont_loss)

# Linearly increase capacity of discrete channels
disc_min, disc_max, disc_num_iters, disc_gamma = disc_capacity
# Increase discrete capacity without exceeding disc_max or theoretical
# maximum (i.e. sum of log of dimension of each discrete variable)
disc_cap_current = (disc_max - disc_min) * num_iter
disc_cap_current = disc_cap_current / float(disc_num_iters) + disc_min
disc_cap_current = min(disc_cap_current, disc_max)
# Require float conversion here to not end up with numpy float
disc_theory_max = sum([float(np.log(d)) for d in disc_dims])
disc_cap_current = min(disc_cap_current, disc_theory_max)
# Calculate discrete capacity loss
disc_capacity_loss = disc_gamma*torch.abs(disc_cap_current - kl_disc_loss)

return cont_capacity_loss, disc_capacity_loss
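
A minimal sketch of the new joint objective called directly; batch size, latent sizes, and capacity schedules below are illustrative assumptions rather than values taken from this commit:

import torch
from atomai.losses_metrics import joint_vae_loss

x = torch.rand(8, 28, 28)                                 # input images
x_rec = torch.rand(8, 28, 28)                             # reconstructions
z_mean, z_logsd = torch.randn(8, 2), torch.randn(8, 2)    # continuous latents
alphas = [torch.softmax(torch.randn(8, 3), dim=-1)]       # one discrete latent with 3 classes
elbo = joint_vae_loss("mse", (28, 28), x, x_rec,
                      z_mean, z_logsd, alphas,
                      cont_capacity=[0.0, 5.0, 25000, 30],
                      disc_capacity=[0.0, 5.0, 25000, 30],
                      num_iter=500)                       # current training iteration
loss = -elbo                                              # negate the ELBO to obtain a quantity to minimize
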

4 changes: 2 additions & 2 deletions atomai/models/__init__.py
@@ -1,7 +1,7 @@
from .segmentor import Segmentor
from .imspec import ImSpec
from .vae import BaseVAE, VAE, rVAE
from .dgm import BaseVAE, VAE, rVAE, jVAE, jrVAE
from .loaders import load_model, load_ensemble

__all__ = ["Segmentor", "ImSpec", "BaseVAE", "VAE", "rVAE",
"load_model", "load_ensemble"]
"jVAE", "jrVAE", "load_model", "load_ensemble"]
6 changes: 6 additions & 0 deletions atomai/models/dgm/__init__.py
@@ -0,0 +1,6 @@
from .vae import BaseVAE, VAE
from .rvae import rVAE
from .jvae import jVAE
from .jrvae import jrVAE

__all__ = ["BaseVAE", "VAE", "rVAE", "jVAE", "jrVAE"]
