This repo contains code for quickly and easily implementing multimodal variational autoencoders (VAEs).
To train models from the command line, list the available options:

```
$ python main.py --help
```

A VAE can also be assembled and trained directly in Python:
```python
import torch
import torch.nn as nn

from src.encoders_decoders import GatherLayer, NetworkList, SplitLinearLayer
from src.likelihoods import GroupedLikelihood, BernoulliLikelihood
from src.objectives import StandardElbo
from src.priors import StandardGaussianPrior
from src.variational_posteriors import DiagonalGaussianPosterior
from src.variational_strategies import GaussianPoeStrategy

# Make a VAE with two modalities, both with 392 dimensions, and a
# 20-dimensional latent space. The VAE is simply a collection of different
# pieces, with each piece subclassing `torch.nn.Module`.
latent_dim, m_dim = 20, 392
vae = nn.ModuleDict({
    'encoder': NetworkList(
        nn.ModuleList([
            nn.Sequential(
                nn.Linear(m_dim, 200),
                nn.ReLU(),
                SplitLinearLayer(200, (latent_dim, latent_dim)),
            ),
            nn.Sequential(
                nn.Linear(m_dim, 200),
                nn.ReLU(),
                SplitLinearLayer(200, (latent_dim, latent_dim)),
            ),
        ])
    ),
    'variational_strategy': GaussianPoeStrategy(),
    'variational_posterior': DiagonalGaussianPosterior(),
    'decoder': nn.Sequential(
        nn.Linear(latent_dim, 200),
        nn.ReLU(),
        SplitLinearLayer(200, (m_dim, m_dim)),
        GatherLayer(),
    ),
    'likelihood': GroupedLikelihood(
        BernoulliLikelihood(),
        BernoulliLikelihood(),
    ),
    'prior': StandardGaussianPrior(),
})
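
# (Inferences from the class names, not spelled out here: each encoder's
# `SplitLinearLayer` presumably emits a mean and a log-variance for its
# modality, `GaussianPoeStrategy` combines these as a product of Gaussian
# experts, and `GatherLayer` collects the decoder's two 392-dim output heads,
# one per `BernoulliLikelihood` in the `GroupedLikelihood`. See
# `src/encoders_decoders.py` and the other `src/` modules for the details.)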
# Feed the VAE to an objective. The objective determines how data is routed
# through the various VAE pieces to determine a loss. Objectives also subclass
# `torch.nn.Module`.
objective = StandardElbo(vae)

# Train the VAE like any other PyTorch model.
loader = make_dataloader(...)
optimizer = torch.optim.Adam(objective.parameters())
for epoch in range(100):
    for batch in loader:
        objective.zero_grad()
        loss = objective(batch)
        loss.backward()
        optimizer.step()
```
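
The `make_dataloader` call above is left to the user. Below is a minimal sketch, assuming each batch is a tuple of per-modality tensors (the dataset classes in `src/datasets/` define the format the training code actually expects). Here the two 392-dimensional modalities are the top and bottom halves of binarized MNIST digits (392 = 784/2):

```python
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets

def make_dataloader(batch_size=256):
    # Load MNIST and binarize it, matching the Bernoulli likelihoods above.
    mnist = datasets.MNIST('data', train=True, download=True)
    imgs = (mnist.data.view(-1, 784).float() / 255.0 > 0.5).float()
    # Split each 784-dim digit into two 392-dim modalities.
    top, bottom = imgs[:, :392], imgs[:, 392:]
    # Each batch is then a (top, bottom) tuple: one tensor per modality.
    return DataLoader(TensorDataset(top, bottom), batch_size=batch_size,
                      shuffle=True)
```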
Several existing multimodal VAEs can be recovered as particular combinations of these pieces, selected with command-line flags (see the example invocation after this list):

- MVAE:
  `--variational-strategy=gaussian_poe --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=mvae_elbo`
- MMVAE:
  `--variational-strategy=gaussian_moe --variational-posterior=diag_gaussian_mixture --prior=standard_gaussian --objective=mmvae_elbo`
- s-VAE (originally a single-modality VAE):
  `--variational-strategy=vmf_poe --variational-posterior=vmf_product --prior=uniform_hyperspherical --objective=elbo`
- MIWAE:
  `--unstructured-encoder=True --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=elbo`
- partial VAE (TO DO):
  `--variational-strategy=permutation_invariant --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=elbo`
- VAEVAE?
- MoPoE VAE?
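
For example, to train an MVAE (the exact invocation is inferred from the flags above; run `python main.py --help` to confirm):

```
$ python main.py --variational-strategy=gaussian_poe --variational-posterior=diag_gaussian --prior=standard_gaussian --objective=mvae_elbo
```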
To apply these models to new data, you will need to write your own dataset. Check out `src/datasets/` for some examples of how to do this. To use the existing training framework, you will also have to modify `DATASET_MAP` and `MODEL_MAP` in `src/param_maps.py`.
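
A hypothetical sketch of that last step, assuming `DATASET_MAP` is a plain dict mapping command-line names to dataset classes (check `src/param_maps.py` for its actual structure; `MyDataset` is a placeholder for your own class, not repo code):

```python
# src/param_maps.py (sketch -- the names below are placeholders)
from src.datasets.my_dataset import MyDataset  # your new dataset class

DATASET_MAP = {
    # ... existing entries ...
    'my_dataset': MyDataset,
}
```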
Dependencies:

- Python 3 (3.6+)
- PyTorch (1.6+)
- Python Fire (only used in `main.py`)
Related repos:

- MVAE repo: uses a product-of-experts strategy for combining evidence across modalities.
- MMVAE repo: uses a mixture-of-experts strategy for combining evidence across modalities.
- Hyperspherical VAE repo: a VAE with a latent space defined on an n-sphere and von Mises-Fisher-distributed approximate posteriors.
To do:

- Validation set for early stopping
- Implement STL gradients?
- Student experts?
- Compare network architectures with those in other papers
- partial-VAE implementation
- Add a documentation markdown file
- Implement jackknife variational inference?
- AR-ELBO for vMF
- Double-check that unstructured recognition models work
- Is there an easy way for Encoder- and DecoderModalityEmbeddings to share parameters?
- Test the vMF KL divergence
- Why is MVAE performing poorly?