diff --git a/.gitignore b/.gitignore index 3bb42d66..fae93184 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,8 @@ examples/notebooks/my_model_with_custom_archi/ examples/notebooks/my_model/ examples/net.py examples/scripts/configs/* +examples/scripts/reproducibility/reproducibility/* + diff --git a/README.md b/README.md index d1de5597..03f3ec0e 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,15 @@ provides the possibility to perform benchmark experiments and comparisons by tra the models with the same autoencoding neural network architecture. The feature *make your own autoencoder* allows you to train any of these models with your own data and own Encoder and Decoder neural networks. It integrates an experiment monitoring tool [wandb](https://wandb.ai/) 🧪 and allows model sharing and loading from the [HuggingFace Hub](https://huggingface.co/models) 🤗 in a few lines of code. +## Quick access: +- [Installation](#installation) +- [Implemented models](#available-models) / [Implemented samplers](#available-samplers) +- [Reproducibility statement](#reproducibility) / [Results flavor](#results) +- [Model training](#launching-a-model-training) / [Data generation](#launching-data-generation) / [Custom network architectures](#define-you-own-autoencoder-architecture) +- [Model sharing with 🤗 Hub](#sharing-your-models-with-the-huggingface-hub-) / [Experiment tracking with `wandb`](#monitoring-your-experiments-with-wandb-) +- [Tutorials](#getting-your-hands-on-the-code) / [Documentation](https://pythae.readthedocs.io/en/latest/) +- [Contributing 🚀](#contributing-) / [Issues 🛠️](#dealing-with-issues-%EF%B8%8F) +- [Citing this repository](#citation) # Installation @@ -81,7 +90,8 @@ VAE with Inverse Autoregressive Flows (VAE_IAF) | [![Open In Colab](https://col | Wasserstein Autoencoder (WAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/wae_training.ipynb) | [link](https://arxiv.org/abs/1711.01558) | [link](https://github.com/tolstikhin/wae) | | Info Variational Autoencoder (INFOVAE_MMD) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/info_vae_training.ipynb) | [link](https://arxiv.org/abs/1706.02262) | | | VAMP Autoencoder (VAMP) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vamp_training.ipynb) | [link](https://arxiv.org/abs/1705.07120) | [link](https://github.com/jmtomczak/vae_vampprior) | -| Hyperspherical VAE (SVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/svae_training.ipynb) | [link](https://arxiv.org/abs/1804.00891) | [link](https://github.com/nicola-decao/s-vae-pytorch) | +| Hyperspherical VAE (SVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/svae_training.ipynb) | [link](https://arxiv.org/abs/1804.00891) | [link](https://github.com/nicola-decao/s-vae-pytorch) +| Poincaré Disk VAE (PoincareVAE) | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/pvae_training.ipynb) | [link](https://arxiv.org/abs/1901.06033) | [link](https://github.com/emilemathieu/pvae) | | Adversarial Autoencoder (Adversarial_AE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/adversarial_ae_training.ipynb) | [link](https://arxiv.org/abs/1511.05644) | Variational Autoencoder GAN (VAEGAN) 🥗 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vaegan_training.ipynb) | [link](https://arxiv.org/abs/1512.09300) | [link](https://github.com/andersbll/autoencoding_beyond_pixels) | Vector Quantized VAE (VQVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vqvae_training.ipynb) | [link](https://arxiv.org/abs/1711.00937) | [link](https://github.com/deepmind/sonnet/blob/v2/sonnet/) @@ -102,6 +112,7 @@ Below is the list of the models currently implemented in the library. | Gaussian mixture (GaussianMixtureSampler) | all models | [link](https://arxiv.org/abs/1903.12436) | [link](https://github.com/ParthaEth/Regularized_autoencoders-RAE-/tree/master/models/rae) | | Two stage VAE sampler (TwoStageVAESampler) | all VAE based models | [link](https://openreview.net/pdf?id=B1e0X3C9tQ) | [link](https://github.com/daib13/TwoStageVAE/) | | Unit sphere uniform sampler (HypersphereUniformSampler) | SVAE | [link](https://arxiv.org/abs/1804.00891) | [link](https://github.com/nicola-decao/s-vae-pytorch) +| Poincaré Disk sampler (PoincareDiskSampler) | PoincareVAE | [link](https://arxiv.org/abs/1901.06033) | [link](https://github.com/emilemathieu/pvae) | VAMP prior sampler (VAMPSampler) | VAMP | [link](https://arxiv.org/abs/1705.07120) | [link](https://github.com/jmtomczak/vae_vampprior) | | Manifold sampler (RHVAESampler) | RHVAE | [link](https://arxiv.org/abs/2105.00026) | [link](https://github.com/clementchadebec/pyraug)| | Masked Autoregressive Flow Sampler (MAFSampler) | all models | [link](https://arxiv.org/abs/1705.07057v4) | [link](https://github.com/gpapamak/maf) | diff --git a/docs/source/models/autoencoders/models.rst b/docs/source/models/autoencoders/models.rst index d2480d38..d7a92630 100644 --- a/docs/source/models/autoencoders/models.rst +++ b/docs/source/models/autoencoders/models.rst @@ -22,6 +22,7 @@ Autoencoders infovae vamp svae + pvae aae vaegan vqvae @@ -53,6 +54,7 @@ Available Models ~pythae.models.INFOVAE_MMD ~pythae.models.VAMP ~pythae.models.SVAE + ~pythae.models.PoincareVAE ~pythae.models.Adversarial_AE ~pythae.models.VAEGAN ~pythae.models.VQVAE diff --git a/docs/source/models/autoencoders/pvae.rst b/docs/source/models/autoencoders/pvae.rst new file mode 100644 index 00000000..c59a2c1a --- /dev/null +++ b/docs/source/models/autoencoders/pvae.rst @@ -0,0 +1,13 @@ +********************************** +PoincareVAE +********************************** + + +.. automodule:: + pythae.models.pvae + +..
autoclass:: pythae.models.PoincareVAEConfig + :members: + +.. autoclass:: pythae.models.PoincareVAE + :members: \ No newline at end of file diff --git a/docs/source/models/pythae.models.rst b/docs/source/models/pythae.models.rst index 2da47100..fcf81b67 100644 --- a/docs/source/models/pythae.models.rst +++ b/docs/source/models/pythae.models.rst @@ -36,6 +36,7 @@ Available Autoencoders ~pythae.models.INFOVAE_MMD ~pythae.models.VAMP ~pythae.models.SVAE + ~pythae.models.PoincareVAE ~pythae.models.Adversarial_AE ~pythae.models.VAEGAN ~pythae.models.VQVAE diff --git a/docs/source/samplers/poincare_disk_sampler.rst b/docs/source/samplers/poincare_disk_sampler.rst new file mode 100644 index 00000000..accd54c7 --- /dev/null +++ b/docs/source/samplers/poincare_disk_sampler.rst @@ -0,0 +1,9 @@ +********************************** +PoincareDiskSampler +********************************** + +.. automodule:: + pythae.samplers.pvae_sampler + +.. autoclass:: pythae.samplers.PoincareDiskSampler + :members: \ No newline at end of file diff --git a/docs/source/samplers/pythae.samplers.rst b/docs/source/samplers/pythae.samplers.rst index 1f08be89..2f7f5994 100644 --- a/docs/source/samplers/pythae.samplers.rst +++ b/docs/source/samplers/pythae.samplers.rst @@ -11,6 +11,7 @@ Samplers gmm_sampler twostage_sampler unit_sphere_unif_sampler + poincare_disk_sampler vamp_sampler rhvae_sampler maf_sampler @@ -28,6 +29,7 @@ Samplers ~pythae.samplers.GaussianMixtureSampler ~pythae.samplers.TwoStageVAESampler ~pythae.samplers.HypersphereUniformSampler + ~pythae.samplers.PoincareDiskSampler ~pythae.samplers.VAMPSampler ~pythae.samplers.RHVAESampler ~pythae.samplers.MAFSampler diff --git a/examples/notebooks/models_training/adversarial_ae_training.ipynb b/examples/notebooks/models_training/adversarial_ae_training.ipynb index a28d80d4..f19ba326 100644 --- a/examples/notebooks/models_training/adversarial_ae_training.ipynb +++ b/examples/notebooks/models_training/adversarial_ae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -253,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -301,7 +304,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/ae_training.ipynb b/examples/notebooks/models_training/ae_training.ipynb index 68cc252c..a569e9ab 100644 --- a/examples/notebooks/models_training/ae_training.ipynb +++ b/examples/notebooks/models_training/ae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -252,7 +255,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { 
@@ -300,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/beta_tc_vae_training.ipynb b/examples/notebooks/models_training/beta_tc_vae_training.ipynb index cf82e9a2..3cbeacdf 100644 --- a/examples/notebooks/models_training/beta_tc_vae_training.ipynb +++ b/examples/notebooks/models_training/beta_tc_vae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -256,7 +259,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -304,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/beta_vae_training.ipynb b/examples/notebooks/models_training/beta_vae_training.ipynb index f1fff78c..42df1b16 100644 --- a/examples/notebooks/models_training/beta_vae_training.ipynb +++ b/examples/notebooks/models_training/beta_vae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -254,7 +257,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -302,7 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb b/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb index 58bb2eaf..dd63129e 100644 --- a/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb +++ b/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -256,7 +259,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -304,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = 
trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/factor_vae_training.ipynb b/examples/notebooks/models_training/factor_vae_training.ipynb index 134c5a76..6d6b1953 100644 --- a/examples/notebooks/models_training/factor_vae_training.ipynb +++ b/examples/notebooks/models_training/factor_vae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -254,7 +257,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -302,7 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -340,7 +343,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/hvae_training.ipynb b/examples/notebooks/models_training/hvae_training.ipynb index 6323b1bd..f8f70e5f 100644 --- a/examples/notebooks/models_training/hvae_training.ipynb +++ b/examples/notebooks/models_training/hvae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -255,7 +258,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -303,7 +306,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -341,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/info_vae_training.ipynb b/examples/notebooks/models_training/info_vae_training.ipynb index 72e1745c..656c6342 100644 --- a/examples/notebooks/models_training/info_vae_training.ipynb +++ b/examples/notebooks/models_training/info_vae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -263,7 +266,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -311,7 +314,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = 
model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -349,7 +352,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/iwae_training.ipynb b/examples/notebooks/models_training/iwae_training.ipynb index cc2f4d81..e7904b21 100644 --- a/examples/notebooks/models_training/iwae_training.ipynb +++ b/examples/notebooks/models_training/iwae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -253,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -301,7 +304,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -339,7 +342,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/ms_ssim_vae_training.ipynb b/examples/notebooks/models_training/ms_ssim_vae_training.ipynb index b6238259..a3a5c63e 100644 --- a/examples/notebooks/models_training/ms_ssim_vae_training.ipynb +++ b/examples/notebooks/models_training/ms_ssim_vae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -254,7 +257,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -302,7 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/pvae_training.ipynb b/examples/notebooks/models_training/pvae_training.ipynb new file mode 100644 index 00000000..a02ec709 --- /dev/null +++ b/examples/notebooks/models_training/pvae_training.ipynb @@ -0,0 +1,441 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install the library\n", + "%pip install pythae" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torchvision.datasets as datasets\n", + "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mnist_trainset = datasets.MNIST(root='../../data', train=True, download=True, transform=None)\n", + "\n", + "train_dataset = mnist_trainset.data[:-10000].reshape(-1, 1, 28, 28) / 255.\n", + "train_targets = mnist_trainset.targets[:-10000]\n", + "eval_dataset = mnist_trainset.data[-10000:].reshape(-1, 1, 28, 28) / 255.\n", + "eval_targets = mnist_trainset.targets[-10000:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pythae.models import PoincareVAE, PoincareVAEConfig\n", + "from pythae.trainers import BaseTrainerConfig\n", + "from pythae.pipelines.training import TrainingPipeline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's define some custom Encoder/Decoder to stick to the paper proposal\n", + "import math\n", + "import numpy as np\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from pythae.models.nn import BaseEncoder, BaseDecoder\n", + "from pythae.models.base.base_utils import ModelOutput\n", + "from pythae.models.pvae.pvae_utils import PoincareBall\n", + "\n", + "class RiemannianLayer(nn.Module):\n", + " def __init__(self, in_features, out_features, manifold, over_param, weight_norm):\n", + " super(RiemannianLayer, self).__init__()\n", + " self.in_features = in_features\n", + " self.out_features = out_features\n", + " self.manifold = manifold\n", + " self._weight = nn.Parameter(torch.Tensor(out_features, in_features))\n", + " self.over_param = over_param\n", + " self.weight_norm = weight_norm\n", + " self._bias = nn.Parameter(torch.Tensor(out_features, 1))\n", + " self.reset_parameters()\n", + "\n", + " @property\n", + " def weight(self):\n", + " return self.manifold.transp0(self.bias, self._weight) # weight \\in T_0 => weight \\in T_bias\n", + "\n", + " @property\n", + " def bias(self):\n", + " if self.over_param:\n", + " return self._bias\n", + " else:\n", + " return self.manifold.expmap0(self._weight * self._bias) # reparameterisation of a point on the manifold\n", + "\n", + " def reset_parameters(self):\n", + " nn.init.kaiming_normal_(self._weight, a=math.sqrt(5))\n", + " fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._weight)\n", + " bound = 4 / math.sqrt(fan_in)\n", + " nn.init.uniform_(self._bias, -bound, bound)\n", + " if self.over_param:\n", + " with torch.no_grad(): self._bias.set_(self.manifold.expmap0(self._bias))\n", + "\n", + "class GeodesicLayer(RiemannianLayer):\n", + " def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False):\n", + " super(GeodesicLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm)\n", + "\n", + " def forward(self, input):\n", + " input = input.unsqueeze(0)\n", + " input = input.unsqueeze(-2).expand(*input.shape[:-(len(input.shape) - 2)], self.out_features, self.in_features)\n", + " res = self.manifold.normdist2plane(input, self.bias, self.weight,\n", + " signed=True, norm=self.weight_norm)\n", + " return res\n", + "\n", + "### Define paper encoder network\n", + "class Encoder(BaseEncoder):\n", + " \"\"\" Usual encoder followed by an exponential map \"\"\"\n", + " def __init__(self, model_config, prior_iso=False):\n", + " super(Encoder, self).__init__()\n", + " self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature)\n", + " self.enc = nn.Sequential(\n", + " 
nn.Linear(np.prod(model_config.input_dim), 600), nn.ReLU(),\n", + " )\n", + " self.fc21 = nn.Linear(600, model_config.latent_dim)\n", + " self.fc22 = nn.Linear(600, model_config.latent_dim if not prior_iso else 1)\n", + "\n", + " def forward(self, x):\n", + " e = self.enc(x.reshape(x.shape[0], -1))\n", + " mu = self.fc21(e)\n", + " mu = self.manifold.expmap0(mu)\n", + " return ModelOutput(\n", + " embedding=mu,\n", + " log_covariance=torch.log(F.softplus(self.fc22(e)) + 1e-5), # expects log_covariance\n", + " log_concentration=torch.log(F.softplus(self.fc22(e)) + 1e-5) # for Riemannian Normal\n", + "\n", + " )\n", + "\n", + "### Define paper decoder network\n", + "class Decoder(BaseDecoder):\n", + " \"\"\" First layer is a Hypergyroplane followed by usual decoder \"\"\"\n", + " def __init__(self, model_config):\n", + " super(Decoder, self).__init__()\n", + " self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature)\n", + " self.input_dim = model_config.input_dim\n", + " self.dec = nn.Sequential(\n", + " GeodesicLayer(model_config.latent_dim, 600, self.manifold),\n", + " nn.ReLU(),\n", + " nn.Linear(600, np.prod(model_config.input_dim)),\n", + " nn.Sigmoid()\n", + " )\n", + "\n", + " def forward(self, z):\n", + " out = self.dec(z).reshape((z.shape[0],) + self.input_dim) # reshape data\n", + " return ModelOutput(\n", + " reconstruction=out\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config = BaseTrainerConfig(\n", + " output_dir='my_model',\n", + " learning_rate=5e-4,\n", + " batch_size=100,\n", + " num_epochs=100, # Change this to train the model a bit more\n", + ")\n", + "\n", + "\n", + "model_config = PoincareVAEConfig(\n", + " input_dim=(1, 28, 28),\n", + " latent_dim=2,\n", + " reconstruction_loss=\"bce\",\n", + " prior_distribution=\"riemannian_normal\",\n", + " posterior_distribution=\"wrapped_normal\",\n", + " curvature=0.7\n", + ")\n", + "\n", + "model = PoincareVAE(\n", + " model_config=model_config,\n", + " encoder=Encoder(model_config), \n", + " decoder=Decoder(model_config) \n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline = TrainingPipeline(\n", + " training_config=config,\n", + " model=model\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pipeline(\n", + " train_data=train_dataset,\n", + " eval_data=eval_dataset\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pythae.models import AutoModel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "last_training = sorted(os.listdir('my_model'))[-1]\n", + "trained_model = AutoModel.load_from_folder(os.path.join('my_model', last_training, 'final_model')).to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize latent space" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt\n", + "\n", + "colors = sns.color_palette('pastel')\n", + "\n", + "fig = plt.figure(figsize=(10,8))\n", + "\n", + "label = eval_targets\n", + "\n", + "torch.manual_seed(42)\n", + "idx = torch.randperm(len(eval_dataset))\n", + "with 
torch.no_grad():\n", + " mu = trained_model.encoder(eval_dataset.to(device)).embedding.detach().cpu()\n", + "plt.scatter(mu[:, 0], mu[:, 1], c=label, cmap=matplotlib.colors.ListedColormap(colors))\n", + "\n", + "cb = plt.colorbar()\n", + "loc = np.arange(0,max(label),max(label)/float(len(colors)))\n", + "cb.set_ticks(loc)\n", + "cb.set_ticklabels([f'{i}' for i in range(10)])\n", + "plt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pythae.samplers import PoincareDiskSampler" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create a Poincaré disk sampler\n", "pvae_sampler = PoincareDiskSampler(\n", " model=trained_model\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sample\n", "gen_data = pvae_sampler.sample(\n", " num_samples=25\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# show results from the Poincaré disk sampler\n", "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", "\n", "for i in range(5):\n", " for j in range(5):\n", " axes[i][j].imshow(gen_data[i*5 +j].cpu().squeeze(0), cmap='gray')\n", " axes[i][j].axis('off')\n", "plt.tight_layout(pad=0.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ... the other samplers work the same" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing reconstructions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# show reconstructions\n", "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", "\n", "for i in range(5):\n", " for j in range(5):\n", " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", " axes[i][j].axis('off')\n", "plt.tight_layout(pad=0.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# show the true data\n", "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", "\n", "for i in range(5):\n", " for j in range(5):\n", " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", " axes[i][j].axis('off')\n", "plt.tight_layout(pad=0.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualizing interpolations" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# show interpolations\n", "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", "\n", "for i in range(5):\n", " for j in range(10):\n", " 
axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "3efa06c4da850a09a4898b773c7e91b0da3286dbbffa369a8099a14a8fa43098" + }, + "kernelspec": { + "display_name": "Python 3.8.11 64-bit ('pythae_dev': conda)", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/notebooks/models_training/rae_gp_training.ipynb b/examples/notebooks/models_training/rae_gp_training.ipynb index 4e8433a0..56525274 100644 --- a/examples/notebooks/models_training/rae_gp_training.ipynb +++ b/examples/notebooks/models_training/rae_gp_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -253,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -301,7 +304,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -339,7 +342,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/rae_l2_training.ipynb b/examples/notebooks/models_training/rae_l2_training.ipynb index e90a62ca..69a7f4ce 100644 --- a/examples/notebooks/models_training/rae_l2_training.ipynb +++ b/examples/notebooks/models_training/rae_l2_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -255,7 +258,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -303,7 +306,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -341,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/rhvae_training.ipynb b/examples/notebooks/models_training/rhvae_training.ipynb index 085c92c2..e645b4d2 100644 --- a/examples/notebooks/models_training/rhvae_training.ipynb +++ b/examples/notebooks/models_training/rhvae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], 
"source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -257,7 +260,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -305,7 +308,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/svae_training.ipynb b/examples/notebooks/models_training/svae_training.ipynb index 67a3770e..71fdbe4a 100644 --- a/examples/notebooks/models_training/svae_training.ipynb +++ b/examples/notebooks/models_training/svae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -194,7 +197,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -242,7 +245,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/vae_iaf_training.ipynb b/examples/notebooks/models_training/vae_iaf_training.ipynb index e017fcb5..6c6f0c19 100644 --- a/examples/notebooks/models_training/vae_iaf_training.ipynb +++ b/examples/notebooks/models_training/vae_iaf_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -254,7 +257,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -302,7 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/vae_lin_nf_training.ipynb b/examples/notebooks/models_training/vae_lin_nf_training.ipynb index 9632579c..892b64e0 100644 --- a/examples/notebooks/models_training/vae_lin_nf_training.ipynb +++ b/examples/notebooks/models_training/vae_lin_nf_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -253,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - 
"reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -301,7 +304,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/vae_training.ipynb b/examples/notebooks/models_training/vae_training.ipynb index 91e16caf..0be27f45 100644 --- a/examples/notebooks/models_training/vae_training.ipynb +++ b/examples/notebooks/models_training/vae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -252,7 +255,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -300,7 +303,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -338,7 +341,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/vaegan_training.ipynb b/examples/notebooks/models_training/vaegan_training.ipynb index 1daa665d..0ab05e3b 100644 --- a/examples/notebooks/models_training/vaegan_training.ipynb +++ b/examples/notebooks/models_training/vaegan_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -258,7 +261,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -306,7 +309,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -344,7 +347,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/vamp_training.ipynb b/examples/notebooks/models_training/vamp_training.ipynb index ebfbbb11..1a9f76fd 100644 --- a/examples/notebooks/models_training/vamp_training.ipynb +++ b/examples/notebooks/models_training/vamp_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -253,7 +256,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -301,7 +304,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { @@ -339,7 +342,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.8.12" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/vqvae_training.ipynb b/examples/notebooks/models_training/vqvae_training.ipynb index 910bd449..cc94d109 100644 --- a/examples/notebooks/models_training/vqvae_training.ipynb +++ b/examples/notebooks/models_training/vqvae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -193,7 +196,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -241,7 +244,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/models_training/wae_training.ipynb b/examples/notebooks/models_training/wae_training.ipynb index 830ea999..03ea6657 100644 --- a/examples/notebooks/models_training/wae_training.ipynb +++ b/examples/notebooks/models_training/wae_training.ipynb @@ -16,8 +16,11 @@ "metadata": {}, "outputs": [], "source": [ + "import torch\n", "import torchvision.datasets as datasets\n", "\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", "%load_ext autoreload\n", "%autoreload 2" ] @@ -255,7 +258,7 @@ "metadata": {}, "outputs": [], "source": [ - "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + "reconstructions = trained_model.reconstruct(eval_dataset[:25].to(device)).detach().cpu()" ] }, { @@ -303,7 +306,7 @@ "metadata": {}, "outputs": [], "source": [ - "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + "interpolations = trained_model.interpolate(eval_dataset[:5].to(device), eval_dataset[5:10].to(device), granularity=10).detach().cpu()" ] }, { diff --git a/examples/notebooks/overview_notebook.ipynb b/examples/notebooks/overview_notebook.ipynb index c9828104..73cb1c0f 100644 --- a/examples/notebooks/overview_notebook.ipynb +++ b/examples/notebooks/overview_notebook.ipynb @@ -27,6 +27,7 @@ "| InfoVAE: Information Maximizing Variational Autoencoders | Info Variational Autoencoder (INFOVAE_MMD) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/info_vae_training.ipynb) | [link](https://arxiv.org/abs/1706.02262) | |\n", "| VAE with a VampPrior | VAMP Autoencoder (VAMP) | [![Open 
In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vamp_training.ipynb) | [link](https://arxiv.org/abs/1705.07120) | [link](https://github.com/jmtomczak/vae_vampprior) |\n", "| Hyperspherical Variational Auto-Encoders | Hyperspherical VAE (SVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/svae_training.ipynb) | [link](https://arxiv.org/abs/1804.00891) | [link](https://github.com/nicola-decao/s-vae-pytorch) |\n", + "| Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders | Poincaré Disk VAE (PoincareVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/pvae_training.ipynb) | [link](https://arxiv.org/abs/1901.06033) | [link](https://github.com/emilemathieu/pvae) |\n", "| Adversarial Autoencoders | Adversarial Autoencoder (Adversarial_AE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/adversarial_ae_training.ipynb) | [link](https://arxiv.org/abs/1511.05644)\n", "| Autoencoding beyond pixels using a learned similarity metric | Variational Autoencoder GAN (VAEGAN) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vaegan_training.ipynb) | [link](https://arxiv.org/abs/1512.09300) | [link](https://github.com/andersbll/autoencoding_beyond_pixels)\n", "| Neural Discrete Representation Learning | Vector Quantized VAE (VQVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/vqvae_training.ipynb) | [link](https://arxiv.org/abs/1711.00937) | [link](https://github.com/deepmind/sonnet/blob/v2/sonnet)\n", @@ -35,11 +36,6 @@ "| From Variational to Deterministic Autoencoders | Regularized AE with gradient penalty (RAE_GP) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/rae_gp_training.ipynb) | [link](https://arxiv.org/abs/1903.12436) | [link](https://github.com/ParthaEth/Regularized_autoencoders-RAE-/tree/master/) |\n", "| Data Augmentation in High Dimensional Low Sample Size Setting Using a Geometry-Based Variational Autoencoder | Riemannian Hamiltonian VAE (RHVAE) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/clementchadebec/benchmark_VAE/blob/main/examples/notebooks/models_training/rhvae_training.ipynb) | [link](https://arxiv.org/abs/2105.00026) | [link](https://github.com/clementchadebec/pyraug)|\n" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] } ], "metadata": { diff --git a/examples/scripts/reproducibility/README.md b/examples/scripts/reproducibility/README.md index fca1c08c..3b4adff8 100644 --- a/examples/scripts/reproducibility/README.md +++ 
b/examples/scripts/reproducibility/README.md @@ -47,15 +47,17 @@ where `train_data` and `eval_data` now have the shape (n_img x im_channel x height x width) Below are gathered the results we were able to reproduce -| Model | Dataset | Metric | Obtained value | Reference value | Trained model -|:---:|:---:|:---:|:---:|:---:|:---:| -| VAE | Binary MNIST | NLL (200 IS) | 89.78 (0.01) | 89.9 | [link](https://huggingface.co/clementchadebec/reproduced_vae) -| VAMP (K=500) | Binary MNIST | NLL (5000 IS) | 85.79 (0.00) | 85.57 | [link](https://huggingface.co/clementchadebec/reproduced_vamp) -| SVAE | Dyn. Binarized MNIST | NLL (500 IS) | 93.27 (0.69) | 93.16 (0.31) | [link](https://huggingface.co/clementchadebec/reproduced_svae) -| IWAE (n_samples=50) | Binary MNIST | NLL (5000 IS) | 86.82 (0.01) | 87.1 | [link](https://huggingface.co/clementchadebec/reproduced_iwae) -| HVAE (n_lf=4) | Binary MNIST | NLL (1000 IS) | 86.21 (0.01) | 86.40 | [link](https://huggingface.co/clementchadebec/reproduced_hvae) -| BetaTCVAE | DSPRITES | Modified ELBO/ELBO (after 50 epochs) | 710.41/85.54 | 712.26/86.40 | [link](https://huggingface.co/clementchadebec/reproduced_beta_tc_vae) -| RAE_L2 | MNIST | FID | 9.1 | 9.9 | [link](https://huggingface.co/clementchadebec/reproduced_rae_l2) -| RAE_GP | MNIST | FID | 9.7 | 9.4 | [link](https://huggingface.co/clementchadebec/reproduced_rae_gp) -| WAE | CELEBA 64 | FID | 56.5 | 55 | [link](https://huggingface.co/clementchadebec/reproduced_wae) -| AAE | CELEBA 64 | FID | 43.3 | 42 | [link](https://huggingface.co/clementchadebec/reproduced_aae) +| Model | Dataset | Metric | Obtained value | Reference value | Reference (paper/code) | Trained model +|:---:|:---:|:---:|:---:|:---:|:---:|:---:| +| VAE | Binary MNIST | NLL (200 IS) | 89.78 (0.01) | 89.9 | [paper](https://arxiv.org/abs/1505.05770) | [link](https://huggingface.co/clementchadebec/reproduced_vae) +| VAMP (K=500) | Binary MNIST | NLL (5000 IS) | 85.79 (0.00) | 85.57 | [paper](https://arxiv.org/abs/1705.07120) | [link](https://huggingface.co/clementchadebec/reproduced_vamp) +| SVAE | Dyn. 
Binarized MNIST | NLL (500 IS) | 93.27 (0.69) | 93.16 (0.31) | [code](https://github.com/nicola-decao/s-vae-pytorch) | [link](https://huggingface.co/clementchadebec/reproduced_svae) | +| PoincareVAE (Wrapped) | MNIST | NLL (500 IS) | 101.97 (0.01) | 101.47 (0.01) | [code](https://github.com/emilemathieu/pvae) | [link](https://huggingface.co/clementchadebec/reproduced_wrapped_poincare_vae) +| IWAE (n_samples=50) | Binary MNIST | NLL (5000 IS) | 86.82 (0.01) | 87.1 | [paper](https://arxiv.org/abs/1509.00519) | [link](https://huggingface.co/clementchadebec/reproduced_iwae) +| HVAE (n_lf=4) | Binary MNIST | NLL (1000 IS) | 86.21 (0.01) | 86.40 | [paper](https://arxiv.org/abs/1410.6460) | [link](https://huggingface.co/clementchadebec/reproduced_hvae) +| BetaTCVAE | DSPRITES | Modified ELBO/ELBO (after 50 epochs) | 710.41/85.54 | 712.26/86.40 | [code](https://github.com/rtqichen/beta-tcvae) | [link](https://huggingface.co/clementchadebec/reproduced_beta_tc_vae) +| RAE_L2 | MNIST | FID | 9.1 | 9.9 | [code](https://github.com/ParthaEth/Regularized_autoencoders-RAE-) | [link](https://huggingface.co/clementchadebec/reproduced_rae_l2) +| RAE_GP | MNIST | FID | 9.7 | 9.4 | [code](https://github.com/ParthaEth/Regularized_autoencoders-RAE-) | [link](https://huggingface.co/clementchadebec/reproduced_rae_gp) +| WAE | CELEBA 64 | FID | 56.5 | 55 | [paper](https://arxiv.org/abs/1711.01558) | [link](https://huggingface.co/clementchadebec/reproduced_wae) +| AAE | CELEBA 64 | FID | 43.3 | 42 | [paper](https://arxiv.org/abs/1711.01558) | [link](https://huggingface.co/clementchadebec/reproduced_aae) + diff --git a/examples/scripts/reproducibility/configs/mnist/pvae/base_training_config.json b/examples/scripts/reproducibility/configs/mnist/pvae/base_training_config.json new file mode 100644 index 00000000..f068e8c0 --- /dev/null +++ b/examples/scripts/reproducibility/configs/mnist/pvae/base_training_config.json @@ -0,0 +1,10 @@ +{ + "name": "BaseTrainerConfig", + "output_dir": "reproducibility/mnist", + "batch_size": 128, + "num_epochs": 80, + "learning_rate": 5e-4, + "steps_saving": 100, + "steps_predict": null, + "no_cuda": false +} diff --git a/examples/scripts/reproducibility/configs/mnist/pvae/pvae_config.json b/examples/scripts/reproducibility/configs/mnist/pvae/pvae_config.json new file mode 100644 index 00000000..6a010425 --- /dev/null +++ b/examples/scripts/reproducibility/configs/mnist/pvae/pvae_config.json @@ -0,0 +1,8 @@ +{ + "name": "PoincareVAEConfig", + "latent_dim": 10, + "reconstruction_loss": "bce", + "prior_distribution": "wrapped_normal", + "posterior_distribution": "wrapped_normal", + "curvature": 0.7 +} diff --git a/examples/scripts/reproducibility/pvae.py b/examples/scripts/reproducibility/pvae.py new file mode 100644 index 00000000..a3e0a14e --- /dev/null +++ b/examples/scripts/reproducibility/pvae.py @@ -0,0 +1,233 @@ +import argparse +import logging +import os +import math +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +from pythae.data.preprocessors import DataProcessor +from pythae.models import PoincareVAE, PoincareVAEConfig +from pythae.models.pvae.pvae_utils import PoincareBall +from pythae.models import AutoModel +from pythae.trainers import BaseTrainerConfig, BaseTrainer + +from pythae.models.nn import BaseEncoder, BaseDecoder +import torch.nn as nn +from pythae.models.base.base_utils import ModelOutput + + +logger = logging.getLogger(__name__) +console = logging.StreamHandler() +logger.addHandler(console)
+logger.setLevel(logging.INFO) + +PATH = os.path.dirname(os.path.abspath(__file__)) + +ap = argparse.ArgumentParser() + +device = 'cuda' if torch.cuda.is_available() else 'cpu' + + +ap.add_argument( + "--model_config", + help="path to model config file (expected json file)", + default=None, +) +ap.add_argument( + "--training_config", + help="path to training config_file (expected json file)", + default=os.path.join(PATH, "configs/base_training_config.json"), +) + +args = ap.parse_args() + +class RiemannianLayer(nn.Module): + def __init__(self, in_features, out_features, manifold, over_param, weight_norm): + super(RiemannianLayer, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.manifold = manifold + self._weight = nn.Parameter(torch.Tensor(out_features, in_features)) + self.over_param = over_param + self.weight_norm = weight_norm + self._bias = nn.Parameter(torch.Tensor(out_features, 1)) + self.reset_parameters() + + @property + def weight(self): + return self.manifold.transp0(self.bias, self._weight) # weight \in T_0 => weight \in T_bias + + @property + def bias(self): + if self.over_param: + return self._bias + else: + return self.manifold.expmap0(self._weight * self._bias) # reparameterisation of a point on the manifold + + def reset_parameters(self): + nn.init.kaiming_normal_(self._weight, a=math.sqrt(5)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._weight) + bound = 4 / math.sqrt(fan_in) + nn.init.uniform_(self._bias, -bound, bound) + if self.over_param: + with torch.no_grad(): self._bias.set_(self.manifold.expmap0(self._bias)) + +class GeodesicLayer(RiemannianLayer): + def __init__(self, in_features, out_features, manifold, over_param=False, weight_norm=False): + super(GeodesicLayer, self).__init__(in_features, out_features, manifold, over_param, weight_norm) + + def forward(self, input): + input = input.unsqueeze(0) + input = input.unsqueeze(-2).expand(*input.shape[:-(len(input.shape) - 2)], self.out_features, self.in_features) + res = self.manifold.normdist2plane(input, self.bias, self.weight, + signed=True, norm=self.weight_norm) + return res + +### Define paper encoder network +class Encoder(BaseEncoder): + """ Usual encoder followed by an exponential map """ + def __init__(self, model_config, prior_iso): + super(Encoder, self).__init__() + self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature) + self.enc = nn.Sequential( + nn.Linear(np.prod(model_config.input_dim), 600), nn.ReLU(), + ) + self.fc21 = nn.Linear(600, model_config.latent_dim) + self.fc22 = nn.Linear(600, model_config.latent_dim if not prior_iso else 1) + + def forward(self, x): + e = self.enc(x.reshape(x.shape[0], -1)) + mu = self.fc21(e) + mu = self.manifold.expmap0(mu) + return ModelOutput( + embedding=mu, + log_covariance=torch.log(F.softplus(self.fc22(e)) + 1e-5), # expects log_covariance + log_concentration=torch.log(F.softplus(self.fc22(e)) + 1e-5) # for Riemannian Normal + + ) + +### Define paper decoder network +class Decoder(BaseDecoder): + """ First layer is a Hypergyroplane followed by usual decoder """ + def __init__(self, model_config): + super(Decoder, self).__init__() + self.manifold = PoincareBall(dim=model_config.latent_dim, c=model_config.curvature) + self.input_dim = model_config.input_dim + self.dec = nn.Sequential( + GeodesicLayer(model_config.latent_dim, 600, self.manifold), + nn.ReLU(), + nn.Linear(600, np.prod(model_config.input_dim)), + nn.Sigmoid() + ) + + def forward(self, z): + out = 
self.dec(z).reshape((z.shape[0],) + self.input_dim) # reshape data + return ModelOutput( + reconstruction=out + ) + + +def main(args): + + ### Load data + train_data = torch.tensor( + np.load(os.path.join(PATH, "data/mnist", "train_data.npz"))[ + "data" + ] + / 255.0 + ).clamp(1e-5, 1-1e-5) + eval_data = torch.tensor( + np.load(os.path.join(PATH, "data/mnist", "eval_data.npz"))["data"] + / 255.0 + ).clamp(1e-5, 1-1e-5) + + train_data = torch.cat((train_data, eval_data)) + + test_data = torch.tensor( + np.load(os.path.join(PATH, "data/mnist", "test_data.npz"))["data"] + / 255.0 + ).clamp(1e-5, 1-1e-5) + + data_input_dim = tuple(train_data.shape[1:]) + + + if args.model_config is not None: + model_config = PoincareVAEConfig.from_json_file(args.model_config) + + else: + model_config = PoincareVAEConfig() + + model_config.input_dim = data_input_dim + + + + model = PoincareVAE( + model_config=model_config, + encoder=Encoder(model_config, prior_iso=True), + decoder=Decoder(model_config), + ) + + ### Set training config + training_config = BaseTrainerConfig.from_json_file(args.training_config) + + ### Process data + data_processor = DataProcessor() + logger.info("Preprocessing train data...") + train_data = data_processor.process_data(torch.bernoulli(train_data)) + train_dataset = data_processor.to_dataset(train_data) + + logger.info("Preprocessing eval data...\n") + eval_data = data_processor.process_data(torch.bernoulli(eval_data)) + eval_dataset = data_processor.to_dataset(eval_data) + + ### Optimizer + optimizer = torch.optim.Adam(model.parameters(), lr=training_config.learning_rate) + + ### Scheduler + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, milestones=[10000000], gamma=10**(-1/3), verbose=True + ) + + seed = 123 + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + logger.info("Using Base Trainer\n") + trainer = BaseTrainer( + model=model, + train_dataset=train_dataset, + eval_dataset=None,  # eval data was merged into the train set above + training_config=training_config, + optimizer=optimizer, + scheduler=scheduler, + callbacks=None, + ) + + ### Launch training + trainer.train() + + trained_model = AutoModel.load_from_folder(os.path.join(training_config.output_dir, f'{trainer.model.model_name}_training_{trainer._training_signature}', 'final_model')).to(device) + + test_data = test_data.to(device).type(torch.float) + + ### Compute NLL + with torch.no_grad(): + nll = [] + for i in range(5): + nll_i = trained_model.get_nll(test_data, n_samples=500, batch_size=500) + logger.info(f"Round {i+1} nll: {nll_i}") + nll.append(nll_i) + + logger.info( + f'\nmean_nll: {np.mean(nll)}' + ) + logger.info( + f'\nstd_nll: {np.std(nll)}' + ) + +if __name__ == "__main__": + + main(args) diff --git a/src/pythae/models/__init__.py b/src/pythae/models/__init__.py index 9dc04bf2..e35b578f 100755 --- a/src/pythae/models/__init__.py +++ b/src/pythae/models/__init__.py @@ -26,6 +26,7 @@ from .info_vae import INFOVAE_MMD, INFOVAE_MMD_Config from .iwae import IWAE, IWAEConfig from .msssim_vae import MSSSIM_VAE, MSSSIM_VAEConfig +from .pvae import PoincareVAE, PoincareVAEConfig from .rae_gp import RAE_GP, RAE_GP_Config from .rae_l2 import RAE_L2, RAE_L2_Config from .rhvae import RHVAE, RHVAEConfig @@ -84,4 +85,6 @@ "VAE_LinNF_Config", "VAE_IAF", "VAE_IAF_Config", + "PoincareVAE", + "PoincareVAEConfig", ] diff --git a/src/pythae/models/auto_model/auto_config.py b/src/pythae/models/auto_model/auto_config.py index 00ba1592..b44221e5 100644 --- a/src/pythae/models/auto_model/auto_config.py +++
b/src/pythae/models/auto_model/auto_config.py @@ -160,6 +160,11 @@ def from_json_file(cls, json_path): model_config = PixelCNNConfig.from_json_file(json_path) + elif config_name == "PoincareVAEConfig": + from ..pvae import PoincareVAEConfig + + model_config = PoincareVAEConfig.from_json_file(json_path) + else: raise NameError( "Cannot reload automatically the model configuration... " diff --git a/src/pythae/models/auto_model/auto_model.py b/src/pythae/models/auto_model/auto_model.py index 8ac6788e..f1aad154 100644 --- a/src/pythae/models/auto_model/auto_model.py +++ b/src/pythae/models/auto_model/auto_model.py @@ -174,6 +174,11 @@ def load_from_folder(cls, dir_path: str): model = PixelCNN.load_from_folder(dir_path=dir_path) + elif model_name == "PoincareVAEConfig": + from ..pvae import PoincareVAE + + model = PoincareVAE.load_from_folder(dir_path=dir_path) + else: raise NameError( "Cannot reload automatically the model... " @@ -411,6 +416,13 @@ def load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) + elif model_name == "PoincareVAEConfig": + from ..pvae import PoincareVAE + + model = PoincareVAE.load_from_hf_hub( + hf_hub_path=hf_hub_path, allow_pickle=allow_pickle + ) + else: raise NameError( "Cannot reload automatically the model... " diff --git a/src/pythae/models/pvae/__init__.py b/src/pythae/models/pvae/__init__.py new file mode 100644 index 00000000..be1a2370 --- /dev/null +++ b/src/pythae/models/pvae/__init__.py @@ -0,0 +1,20 @@ +"""This module is the implementation of a Poincaré Disk Variational Autoencoder +(https://arxiv.org/abs/1901.06033). + +Available samplers +------------------- + +.. autosummary:: + ~pythae.samplers.PoincareDiskSampler + ~pythae.samplers.NormalSampler + ~pythae.samplers.GaussianMixtureSampler + ~pythae.samplers.TwoStageVAESampler + ~pythae.samplers.MAFSampler + ~pythae.samplers.IAFSampler + :nosignatures: +""" + +from .pvae_config import PoincareVAEConfig +from .pvae_model import PoincareVAE + +__all__ = ["PoincareVAE", "PoincareVAEConfig"] diff --git a/src/pythae/models/pvae/pvae_config.py b/src/pythae/models/pvae/pvae_config.py new file mode 100644 index 00000000..d72432bc --- /dev/null +++ b/src/pythae/models/pvae/pvae_config.py @@ -0,0 +1,28 @@ +from pydantic.dataclasses import dataclass +from typing_extensions import Literal + +from ..vae import VAEConfig + + +@dataclass +class PoincareVAEConfig(VAEConfig): + """Poincaré VAE config class. + + Parameters: + input_dim (tuple): The input_data dimension. + latent_dim (int): The latent space dimension. Default: None. + reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse' + prior_distribution (str): The distribution to use as prior + ["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal" + posterior_distribution (str): The distribution to use as posterior + ["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal" + curvature (int): The curvature of the manifold. 
Default: 1 + """ + + prior_distribution: Literal[ + "wrapped_normal", "riemannian_normal" + ] = "wrapped_normal" + posterior_distribution: Literal[ + "wrapped_normal", "riemannian_normal" + ] = "wrapped_normal" + curvature: float = 1 diff --git a/src/pythae/models/pvae/pvae_model.py b/src/pythae/models/pvae/pvae_model.py new file mode 100644 index 00000000..189f8677 --- /dev/null +++ b/src/pythae/models/pvae/pvae_model.py @@ -0,0 +1,286 @@ +import warnings +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...data.datasets import BaseDataset +from ..base.base_utils import ModelOutput +from ..nn import BaseDecoder, BaseEncoder +from ..nn.default_architectures import Encoder_SVAE_MLP, Encoder_VAE_MLP +from ..vae import VAE +from .pvae_config import PoincareVAEConfig +from .pvae_utils import PoincareBall, RiemannianNormal, WrappedNormal + + +class PoincareVAE(VAE): + """Poincaré Variational Autoencoder model. + + Args: + model_config (PoincareVAEConfig): The Poincaré Variational Autoencoder configuration + setting the main parameters of the model. + + encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module`) which + plays the role of encoder. This argument allows you to use your own neural networks + architectures if desired. If None is provided, a simple Multi Layer Perceptron + (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. + + decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module`) which + plays the role of decoder. This argument allows you to use your own neural networks + architectures if desired. If None is provided, a simple Multi Layer Perceptron + (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. + + .. note:: + For high dimensional data we advise you to provide your own network architectures. With the + provided MLP you may end up with a ``MemoryError``. + """ + + def __init__( + self, + model_config: PoincareVAEConfig, + encoder: Optional[BaseEncoder] = None, + decoder: Optional[BaseDecoder] = None, + ): + + VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) + + self.model_name = "PoincareVAE" + + self.latent_manifold = PoincareBall( + dim=model_config.latent_dim, c=model_config.curvature + ) + + if model_config.prior_distribution == "riemannian_normal": + self.prior = RiemannianNormal + else: + self.prior = WrappedNormal + + if model_config.posterior_distribution == "riemannian_normal": + warnings.warn( + "Careful, this model expects the encoder to give a one dimensional " + "`log_concentration` tensor for the Riemannian normal distribution. " + "Make sure the encoder actually outputs this."
+ ) + self.posterior = RiemannianNormal + else: + self.posterior = WrappedNormal + + if encoder is None: + if model_config.posterior_distribution == "riemannian_normal": + encoder = Encoder_SVAE_MLP(model_config) + else: + encoder = Encoder_VAE_MLP(model_config) + self.model_config.uses_default_encoder = True + + else: + self.model_config.uses_default_encoder = False + + self.set_encoder(encoder) + + self._pz_mu = nn.Parameter( + torch.zeros(1, model_config.latent_dim), requires_grad=False + ) + self._pz_logvar = nn.Parameter(torch.zeros(1, 1), requires_grad=False) + + def forward(self, inputs: BaseDataset, **kwargs): + """ + The Poincaré VAE model + + Args: + inputs (BaseDataset): The training dataset with labels + + Returns: + ModelOutput: An instance of ModelOutput containing all the relevant parameters + + """ + + x = inputs["data"] + + encoder_output = self.encoder(x) + + if self.model_config.posterior_distribution == "riemannian_normal": + mu, log_var = encoder_output.embedding, encoder_output.log_concentration + else: + mu, log_var = encoder_output.embedding, encoder_output.log_covariance + + std = torch.exp(0.5 * log_var) + + qz_x = self.posterior(loc=mu, scale=std, manifold=self.latent_manifold) + z = qz_x.rsample(torch.Size([1])) + + recon_x = self.decoder(z.squeeze(0))["reconstruction"] + + loss, recon_loss, kld = self.loss_function(recon_x, x, z, qz_x) + + output = ModelOutput( + reconstruction_loss=recon_loss, + reg_loss=kld, + loss=loss, + recon_x=recon_x, + z=z.squeeze(0), + ) + + return output + + def loss_function(self, recon_x, x, z, qz_x): + + if self.model_config.reconstruction_loss == "mse": + + recon_loss = F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + + elif self.model_config.reconstruction_loss == "bce": + + recon_loss = F.binary_cross_entropy( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + + pz = self.prior( + loc=self._pz_mu, scale=self._pz_logvar.exp(), manifold=self.latent_manifold + ) + + KLD = (qz_x.log_prob(z) - pz.log_prob(z)).sum(-1).squeeze(0) + + return (recon_loss + KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0) + + def interpolate( + self, + starting_inputs: torch.Tensor, + ending_inputs: torch.Tensor, + granularity: int = 10, + ): + """This function performs a geodesic interpolation in the Poincaré disk of the autoencoder + from starting inputs to ending inputs. It returns the interpolation trajectories. + + Args: + starting_inputs (torch.Tensor): The starting inputs in the interpolation of shape + [B x input_dim] + ending_inputs (torch.Tensor): The ending inputs in the interpolation of shape + [B x input_dim] + granularity (int): The granularity of the interpolation. + + Returns: + torch.Tensor: A tensor of shape [B x granularity x input_dim] containing the + interpolation trajectories. + """ + assert starting_inputs.shape[0] == ending_inputs.shape[0], ( + "The number of starting_inputs should equal the number of ending_inputs. Got " + f"{starting_inputs.shape[0]} samples for starting_inputs and {ending_inputs.shape[0]} " + "for ending_inputs."
+ ) + + starting_z = self.encoder(starting_inputs).embedding + ending_z = self.encoder(ending_inputs).embedding + t = torch.linspace(0, 1, granularity).to(starting_inputs.device) + + inter_geo = torch.zeros( + starting_inputs.shape[0], granularity, starting_z.shape[-1] + ).to(starting_z.device) + + for i, t_i in enumerate(t): + z_i = self.latent_manifold.geodesic(t_i, starting_z, ending_z) + inter_geo[:, i, :] = z_i + + decoded_geo = self.decoder( + inter_geo.reshape( + (starting_z.shape[0] * t.shape[0],) + (starting_z.shape[1:]) + ) + ).reconstruction.reshape( + ( + starting_inputs.shape[0], + t.shape[0], + ) + + (starting_inputs.shape[1:]) + ) + return decoded_geo + + def get_nll(self, data, n_samples=1, batch_size=100): + """ + Function computing the estimated negative log-likelihood of the model. It uses the + importance sampling method with the approximate posterior distribution. This may take a while. + + Args: + data (torch.Tensor): The input data from which the log-likelihood should be estimated. + Data must be of shape [Batch x n_channels x ...] + n_samples (int): The number of importance samples to use for estimation + batch_size (int): The batch size to use to avoid memory issues + """ + + if n_samples <= batch_size: + n_full_batch = 1 + else: + n_full_batch = n_samples // batch_size + n_samples = batch_size + + log_p = [] + + for i in range(len(data)): + x = data[i].unsqueeze(0) + + log_p_x = [] + + for j in range(n_full_batch): + + x_rep = torch.cat(batch_size * [x]) + + encoder_output = self.encoder(x_rep) + if self.model_config.posterior_distribution == "riemannian_normal": + mu, log_var = ( + encoder_output.embedding, + encoder_output.log_concentration, + ) + else: + mu, log_var = ( + encoder_output.embedding, + encoder_output.log_covariance, + ) + + std = torch.exp(0.5 * log_var) + + qz_x = self.posterior(loc=mu, scale=std, manifold=self.latent_manifold) + z = qz_x.rsample(torch.Size([1])) + + pz = self.prior( + loc=self._pz_mu, + scale=self._pz_logvar.exp(), + manifold=self.latent_manifold, + ) + + log_q_z_given_x = qz_x.log_prob(z).sum(-1).squeeze(0) + log_p_z = pz.log_prob(z).sum(-1).squeeze(0) + + recon_x = self.decoder(z.squeeze(0))["reconstruction"] + + if self.model_config.reconstruction_loss == "mse": + + log_p_x_given_z = -0.5 * F.mse_loss( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) - torch.tensor( + [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] + ).to( + data.device + ) # decoding distribution is assumed unit variance N(mu, I) + + elif self.model_config.reconstruction_loss == "bce": + + log_p_x_given_z = -F.binary_cross_entropy( + recon_x.reshape(x_rep.shape[0], -1), + x_rep.reshape(x_rep.shape[0], -1), + reduction="none", + ).sum(dim=-1) + + log_p_x.append(log_p_x_given_z + log_p_z - log_q_z_given_x) + + log_p_x = torch.cat(log_p_x) + + log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) + return np.mean(log_p) diff --git a/src/pythae/models/pvae/pvae_utils.py b/src/pythae/models/pvae/pvae_utils.py new file mode 100644 index 00000000..a836da72 --- /dev/null +++ b/src/pythae/models/pvae/pvae_utils.py @@ -0,0 +1,1028 @@ +"""Distributions and manifold taken from + (https://github.com/emilemathieu/pvae/blob/master/pvae) and + (https://github.com/geoopt/geoopt/blob/master/geoopt/manifolds) + +""" +import math +from numbers import Number +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.distributions as dist +from torch.autograd import Function,
grad +from torch.distributions import Normal +from torch.distributions.utils import _standard_normal, broadcast_all +from torch.nn import functional as F + +MIN_NORM = 1e-15 +BALL_EPS = {torch.float32: 4e-3, torch.float64: 1e-5} + + +def log_sum_exp_signs(value, signs, dim=0, keepdim=False): + m, _ = torch.max(value, dim=dim, keepdim=True) + value0 = value - m + if keepdim is False: + m = m.squeeze(dim) + return m + torch.log(torch.sum(signs * torch.exp(value0), dim=dim, keepdim=keepdim)) + + +def rexpand(A, *dimensions): + """Expand tensor, adding new dimensions on right.""" + return A.view(A.shape + (1,) * len(dimensions)).expand(A.shape + tuple(dimensions)) + + +def logsinh(x): + # torch.log(sinh(x)) + return x + torch.log(1 - torch.exp(-2 * x)) - math.log(2) + + +def tanh(x): ## OK + return x.clamp(-15, 15).tanh() + + +def arsinh(x: torch.Tensor): ## OK + return (x + torch.sqrt(1 + x.pow(2))).clamp_min(MIN_NORM).log().to(x.dtype) + + +def artanh(x: torch.Tensor): ## OK + x = x.clamp(-1 + 1e-5, 1 - 1e-5) + return (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5) + + +def _lambda_x(x, c, keepdim: bool = False, dim: int = -1): ## OK + return 2 / (1 - c * x.pow(2).sum(dim=dim, keepdim=keepdim)).clamp_min(MIN_NORM) + + +def _mobius_add(x, y, c, dim=-1): ## OK + x2 = x.pow(2).sum(dim=dim, keepdim=True) + y2 = y.pow(2).sum(dim=dim, keepdim=True) + xy = (x * y).sum(dim=dim, keepdim=True) + num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y + denom = 1 + 2 * c * xy + c ** 2 * x2 * y2 + return num / denom.clamp_min(MIN_NORM) + + +def _mobius_scalar_mul(r, x, c, dim: int = -1): ## OK + x_norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM) + sqrt_c = c ** 0.5 + res_c = tanh(r * artanh(sqrt_c * x_norm)) * x / (x_norm * sqrt_c) + return res_c + + +def _project(x, c, dim: int = -1, eps: float = None): ## OK + norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(MIN_NORM) + if eps is None: + eps = BALL_EPS[x.dtype] + maxnorm = (1 - eps) / (c ** 0.5) + cond = norm > maxnorm + projected = x / norm * maxnorm + return torch.where(cond, projected, x) + + +def _gyration(u, v, w, c, dim: int = -1): ## OK + u2 = u.pow(2).sum(dim=dim, keepdim=True) + v2 = v.pow(2).sum(dim=dim, keepdim=True) + uv = (u * v).sum(dim=dim, keepdim=True) + uw = (u * w).sum(dim=dim, keepdim=True) + vw = (v * w).sum(dim=dim, keepdim=True) + c2 = c ** 2 + a = -c2 * uw * v2 + c * vw + 2 * c2 * uv * vw + b = -c2 * vw * u2 - c * uw + d = 1 + 2 * c * uv + c2 * u2 * v2 + return w + 2 * (a * u + b * v) / d.clamp_min(MIN_NORM) + + +class PoincareBall: + def __init__(self, dim, c=1.0): + self.c = c + self.dim = dim + + @property + def coord_dim(self): + return int(self.dim) + + @property + def zero(self): + return torch.zeros(1, self.dim).to(self.device) + + # def norm(self, x: torch.Tensor, u: torch.Tensor, *, keepdim=False, dim=-1 + # ) -> torch.Tensor: ## OK + # return _lambda_x(x, c=self.c, keepdim=keepdim, dim=dim) * u.norm( + # dim=dim, keepdim=keepdim, p=2 + # ) + + def dist( ## OK + self, x: torch.Tensor, y: torch.Tensor, *, keepdim=False, dim=-1 + ) -> torch.Tensor: ## OK + sqrt_c = self.c ** 0.5 + dist_c = artanh( + sqrt_c + * _mobius_add(-x, y, self.c, dim=dim).norm(dim=dim, p=2, keepdim=keepdim) + ) + return dist_c * 2 / sqrt_c + + def lambda_x( + self, x: torch.Tensor, *, dim=-1, keepdim=False + ) -> torch.Tensor: ## OK + return _lambda_x(x, c=self.c, dim=dim, keepdim=keepdim) + + def mobius_add( + self, x: torch.Tensor, y: torch.Tensor, *, dim=-1, project=True + ) -> torch.Tensor: ## OK + res = _mobius_add(x, y, 
c=self.c, dim=dim) + if project: + return _project(res, c=self.c, dim=dim) + else: + return res + + def logmap0(self, x: torch.Tensor, y: torch.Tensor, *, dim=-1) -> torch.Tensor: + sqrt_c = self.c ** 0.5 + y_norm = y.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) + return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm) + + def logmap( + self, x: torch.Tensor, y: torch.Tensor, *, dim=-1 + ) -> torch.Tensor: ## OK + sub = _mobius_add(-x, y, self.c, dim=dim) + sub_norm = sub.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) + lam = _lambda_x(x, self.c, keepdim=True, dim=dim) + sqrt_c = self.c ** 0.5 + return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm + + def transp0(self, y: torch.Tensor, v: torch.Tensor, *, dim=-1) -> torch.Tensor: + return v * (1 - self.c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min( + MIN_NORM + ) + + def transp( + self, x: torch.Tensor, y: torch.Tensor, v: torch.Tensor, *, dim=-1 + ): ## OK + return ( + _gyration(y, -x, v, self.c, dim=dim) + * _lambda_x(x, self.c, keepdim=True, dim=dim) + / _lambda_x(y, self.c, keepdim=True, dim=dim) + ) + + def logdetexp(self, x, y, is_vector=False, keepdim=False): ## OK + d = ( + self.norm(x, y, keepdim=keepdim) + if is_vector + else self.dist(x, y, keepdim=keepdim) + ) + return (self.dim - 1) * ( + torch.sinh(math.sqrt(self.c) * d) / math.sqrt(self.c) / d + ).log() + + def expmap0(self, u, dim: int = -1): + sqrt_c = self.c ** 0.5 + u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) + gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm) + return gamma_1 + + def expmap(self, x, u, dim: int = -1): + sqrt_c = self.c ** 0.5 + u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) + second_term = ( + tanh(sqrt_c / 2 * _lambda_x(x, self.c, keepdim=True, dim=dim) * u_norm) + * u + / (sqrt_c * u_norm) + ) + gamma_1 = _mobius_add(x, second_term, self.c, dim=dim) + return gamma_1 + + def expmap_polar(self, x, u, r, dim: int = -1): ## OK + sqrt_c = self.c ** 0.5 + u_norm = u.norm(dim=dim, p=2, keepdim=True).clamp_min(MIN_NORM) + second_term = ( + tanh(torch.tensor([sqrt_c]).to(x.device) / 2 * r) * u / (sqrt_c * u_norm) + ) + gamma_1 = self.mobius_add(x, second_term, dim=dim) + return gamma_1 + + def geodesic(self, t, x, y, dim: int = -1): ## OK + v = _mobius_add(-x, y, self.c, dim=dim) + tv = _mobius_scalar_mul(t, v, self.c, dim=dim) + gamma_t = _mobius_add(x, tv, self.c, dim=dim) + return gamma_t + + def normdist2plane( + self, + x, + a, + p, + keepdim: bool = False, + signed: bool = False, + dim: int = -1, + norm: bool = False, + ): + c = self.c + sqrt_c = c ** 0.5 + diff = self.mobius_add(-p, x, dim=dim) + diff_norm2 = diff.pow(2).sum(dim=dim, keepdim=keepdim).clamp_min(MIN_NORM) + sc_diff_a = (diff * a).sum(dim=dim, keepdim=keepdim) + if not signed: + sc_diff_a = sc_diff_a.abs() + a_norm = a.norm(dim=dim, keepdim=keepdim, p=2).clamp_min(MIN_NORM) + num = 2 * sqrt_c * sc_diff_a + denom = (1 - c * diff_norm2) * a_norm + res = arsinh(num / denom.clamp_min(MIN_NORM)) / sqrt_c + if norm: + res = res * a_norm # * self.lambda_x(a, dim=dim, keepdim=keepdim) + return res + + def _check_point_on_manifold(self, x, *, atol=1e-5, rtol=1e-5): + px = _project(x, c=self.c) + ok = torch.allclose(x, px, atol=atol, rtol=rtol) + if not ok: + reason = "'x' norm lies out of the bounds [-1/sqrt(c)+eps, 1/sqrt(c)-eps]" + else: + reason = None + return ok, reason + + def _check_vector_on_tangent( + self, x: torch.Tensor, u: torch.Tensor, *, atol=1e-5, rtol=1e-5, dim=-1 + ) -> Tuple[bool, Optional[str]]: + return 
True, None + + +class WrappedNormal(dist.Distribution): ## OK + """Wrapped Normal distribution""" + + arg_constraints = {"loc": dist.constraints.real, "scale": dist.constraints.positive} + support = dist.constraints.real + has_rsample = True + _mean_carrier_measure = 0 + + @property + def scale(self): + return F.softplus(self._scale) if self.softplus else self._scale + + def __init__(self, loc, scale, manifold, validate_args=None, softplus=False): + self.dtype = loc.dtype + self.softplus = softplus + self.loc, self._scale = broadcast_all(loc, scale) + self.manifold = manifold + self.manifold._check_point_on_manifold(self.loc) + self.device = loc.device + if isinstance(loc, Number) and isinstance(scale, Number): + batch_shape, event_shape = torch.Size(), torch.Size() + else: + batch_shape = self.loc.shape[:-1] + event_shape = torch.Size([self.manifold.dim]) + super(WrappedNormal, self).__init__( + batch_shape, event_shape, validate_args=validate_args + ) + + def sample(self, shape=torch.Size()): ## OK + with torch.no_grad(): + return self.rsample(shape) + + def rsample(self, sample_shape=torch.Size()): ## OK + shape = self._extended_shape(sample_shape) + v = self.scale * _standard_normal( + shape, dtype=self.loc.dtype, device=self.loc.device + ) + self.manifold._check_vector_on_tangent( + torch.zeros(1, self.manifold.dim).to(v.device), v + ) + v = v / self.manifold.lambda_x( + torch.zeros(1, self.manifold.dim).to(v.device), keepdim=True + ) + u = self.manifold.transp( + torch.zeros(1, self.manifold.dim).to(v.device), self.loc, v + ) + z = self.manifold.expmap(self.loc, u) + return z + + def log_prob(self, x): ## OK + shape = x.shape + loc = self.loc.unsqueeze(0).expand( + x.shape[0], *self.batch_shape, self.manifold.coord_dim + ) + if len(shape) < len(loc.shape): + x = x.unsqueeze(1) + v = self.manifold.logmap(loc, x) + v = self.manifold.transp(loc, torch.zeros(1, self.manifold.dim).to(v.device), v) + u = v * self.manifold.lambda_x( + torch.zeros(1, self.manifold.dim).to(v.device), keepdim=True + ) + norm_pdf = ( + Normal(torch.zeros_like(self.scale), self.scale) + .log_prob(u) + .sum(-1, keepdim=True) + ) + logdetexp = self.manifold.logdetexp(loc, x, keepdim=True) + result = norm_pdf - logdetexp + return result + + +infty = torch.tensor(float("Inf")) + + +def diff(x): + return x[:, 1:] - x[:, :-1] + + +class ARS: + """ + This class implements the Adaptive Rejection Sampling technique of Gilks and Wild '92. + Where possible, naming convention has been borrowed from this paper. + The PDF must be log-concave. + Currently does not exploit lower hull described in paper- which is fine for drawing + only small amount of samples at a time. + """ + + def __init__( + self, + logpdf, + grad_logpdf, + device, + xi, + lb=-infty, + ub=infty, + use_lower=False, + ns=50, + **fargs, + ): + """ + initialize the upper (and if needed lower) hulls with the specified params + + Parameters + ========== + f: function that computes log(f(u,...)), for given u, where f(u) is proportional to the + density we want to sample from + fprima: d/du log(f(u,...)) + xi: ordered vector of starting points in wich log(f(u,...) 
is defined + to initialize the hulls + use_lower: True means the lower sqeezing will be used; which is more efficient + for drawing large numbers of samples + + + lb: lower bound of the domain + ub: upper bound of the domain + ns: maximum number of points defining the hulls + fargs: arguments for f and fprima + """ + self.device = device + + self.lb = lb + self.ub = ub + + self.logpdf = logpdf + self.grad_logpdf = grad_logpdf + self.fargs = fargs + + # set limit on how many points to maintain on hull + self.ns = ns + self.xi = xi.to( + self.device + ) # initialize x, the vector of absicassae at which the function h has been evaluated + self.B, self.K = self.xi.size() # hull size + self.h = torch.zeros(self.B, ns).to(self.device) + self.hprime = torch.zeros(self.B, ns).to(self.device) + self.x = torch.zeros(self.B, ns).to(self.device) + self.h[:, : self.K] = self.logpdf(self.xi, **self.fargs) + self.hprime[:, : self.K] = self.grad_logpdf(self.xi, **self.fargs) + self.x[:, : self.K] = self.xi + # Avoid under/overflow errors. the envelope and pdf are only + # proportional to the true pdf, so can choose any constant of proportionality. + self.offset = self.h.max(-1)[0].view(-1, 1) + self.h = self.h - self.offset + + # Derivative at first point in xi must be > 0 + # Derivative at last point in xi must be < 0 + if not (self.hprime[:, 0] > 0).all(): + raise IOError("initial anchor points must span mode of PDF (left)") + if not (self.hprime[:, self.K - 1] < 0).all(): + raise IOError("initial anchor points must span mode of PDF (right)") + self.insert() + + def sample(self, shape=torch.Size()): + """ + Draw N samples and update upper and lower hulls accordingly + """ + shape = shape if isinstance(shape, torch.Size) else torch.Size([shape]) + samples = torch.ones(self.B, *shape).to(self.device) + bool_mask = (torch.ones(self.B, *shape) == 1).to(self.device) + count = 0 + while bool_mask.sum() != 0: + count += 1 + xt, i = self.sampleUpper(shape) + ht = self.logpdf(xt, **self.fargs) + # hprimet = self.grad_logpdf(xt, **self.fargs) + ht = ht - self.offset + ut = self.h.gather(1, i) + (xt - self.x.gather(1, i)) * self.hprime.gather( + 1, i + ) + + # Accept sample? 
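As an editorial aside on the `# Accept sample?` step that follows: ARS proposes from the piecewise-exponential upper hull u(x) ≥ h(x) and accepts a draw with probability exp(h(x) − u(x)). A self-contained toy version of that rule (illustration only, not part of the patch; the Gaussian target and Laplace-shaped envelope are stand-ins for the hull machinery above):

```python
import math
import torch

# Toy version of the accept/reject rule used in ARS.sample: propose x from a
# density proportional to exp(u(x)) with u(x) >= h(x) everywhere, then accept
# with probability exp(h(x) - u(x)).
def rejection_sample(h, u, propose, n):
    out = []
    while len(out) < n:
        x = propose()
        if torch.rand(()) < torch.exp(h(x) - u(x)):
            out.append(x)
    return torch.stack(out)

# Stand-in densities: sample N(0, 1) under a looser Laplace-shaped envelope.
h = lambda x: -0.5 * x ** 2 - 0.5 * math.log(2 * math.pi)  # target log-density
u = lambda x: -x.abs() + 0.5                               # dominates h for all x
propose = lambda: torch.distributions.Laplace(0.0, 1.0).sample()

samples = rejection_sample(h, u, propose, 1000)  # empirically ~N(0, 1)
```

ARS additionally rebuilds the hull from the evaluated points so the envelope tightens as sampling proceeds; the toy above keeps the envelope fixed.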
+ u = torch.rand(shape).to(self.device) + accept = u < torch.exp(ht - ut) + reject = ~accept + samples[bool_mask * accept] = xt[bool_mask * accept] + bool_mask[bool_mask * accept] = reject[bool_mask * accept] + # Update hull with new function evaluations + # if self.K < self.ns: + # nb_insert = self.ns - self.K + # self.insert(nb_insert, xt[:, :nb_insert], ht[:, :nb_insert], hprimet[:, :nb_insert]) + + return samples.t().unsqueeze(-1) + + def insert(self, nbnew=0, xnew=None, hnew=None, hprimenew=None): + """ + Update hulls with new point(s) if none given, just recalculate hull from existing x,h,hprime + #""" + self.z = torch.zeros(self.B, self.K + 1).to(self.device) + self.z[:, 0] = self.lb + self.z[:, self.K] = self.ub + self.z[:, 1 : self.K] = ( + diff(self.h[:, : self.K]) + - diff(self.x[:, : self.K] * self.hprime[:, : self.K]) + ) / -diff(self.hprime[:, : self.K]) + idx = [0] + list(range(self.K)) + self.u = self.h[:, idx] + self.hprime[:, idx] * (self.z - self.x[:, idx]) + + self.s = diff(torch.exp(self.u)) / self.hprime[:, : self.K] + self.s[self.hprime[:, : self.K] == 0.0] = 0.0 # should be 0 when gradient is 0 + self.cs = torch.cat( + (torch.zeros(self.B, 1).to(self.device), torch.cumsum(self.s, dim=-1)), + dim=-1, + ) + self.cu = self.cs[:, -1] + + def sampleUpper(self, shape=torch.Size()): + """ + Return a single value randomly sampled from the upper hull and index of segment + """ + + u = torch.rand(self.B, *shape).to(self.device) + i = (self.cs / self.cu.unsqueeze(-1)).unsqueeze(-1) <= u.unsqueeze(1).expand( + *self.cs.shape, *shape + ) + idx = i.sum(1) - 1 + + xt = self.x.gather(1, idx) + ( + -self.h.gather(1, idx) + + torch.log( + self.hprime.gather(1, idx) + * (self.cu.unsqueeze(-1) * u - self.cs.gather(1, idx)) + + torch.exp(self.u.gather(1, idx)) + ) + ) / self.hprime.gather(1, idx) + + return xt, idx + + +def cdf_r(value, scale, c, dim): + value = value.double() + scale = scale.double() + c = np.double(c) + + if dim == 2: + return ( + 1 + / torch.erf(math.sqrt(c) * scale / math.sqrt(2)) + * 0.5 + * ( + 2 * torch.erf(math.sqrt(c) * scale / math.sqrt(2)) + + torch.erf( + (value - math.sqrt(c) * scale.pow(2)) / math.sqrt(2) / scale + ) + - torch.erf( + (math.sqrt(c) * scale.pow(2) + value) / math.sqrt(2) / scale + ) + ) + ) + else: + device = value.device + + k_float = rexpand(torch.arange(dim), *value.size()).double().to(device) + dim = torch.tensor(dim).to(device).double() + + s1 = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log( + torch.erf( + (value - (dim - 1 - 2 * k_float) * math.sqrt(c) * scale.pow(2)) + / scale + / math.sqrt(2) + ) + + torch.erf( + (dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2) + ) + ) + ) + s2 = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log1p( + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + ) + + signs = ( + torch.tensor([1.0, -1.0]) + .double() + .to(device) + .repeat(((int(dim) + 1) // 2) * 2)[: int(dim)] + ) + signs = rexpand(signs, *value.size()) + + S1 = log_sum_exp_signs(s1, signs, dim=0) + S2 = log_sum_exp_signs(s2, signs, dim=0) + + output = torch.exp(S1 - S2) + zero_value_idx = value == 0.0 + output[zero_value_idx] = 0.0 + return output.float() + + +def grad_cdf_value_scale(value, scale, c, dim): + device = value.device + + dim = torch.tensor(int(dim)).to(device).double() + + 
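Both `cdf_r` above and the gradient terms assembled next lean on `log_sum_exp_signs` to evaluate alternating-sign sums whose individual terms overflow a naive `exp`. A standalone check of why the max-shift matters (illustration only, not part of the patch):

```python
import torch

# Signed log-sum-exp trick used by log_sum_exp_signs: evaluate
# log(sum_k s_k * exp(v_k)) by shifting with m = max(v) first.
v = torch.tensor([100.0, 99.0, 98.5])
s = torch.tensor([1.0, -1.0, 1.0])

naive = torch.log((s * torch.exp(v)).sum())           # nan in float32: exp overflows to inf
m = v.max()
stable = m + torch.log((s * torch.exp(v - m)).sum())  # finite: ~99.84

print(naive, stable)
```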
signs = ( + torch.tensor([1.0, -1.0]) + .double() + .to(device) + .repeat(((int(dim) + 1) // 2) * 2)[: int(dim)] + ) + signs = rexpand(signs, *value.size()) + k_float = rexpand(torch.arange(dim), *value.size()).double().to(device) + + log_arg1 = ( + (dim - 1 - 2 * k_float).pow(2) + * c + * scale + * ( + torch.erf( + (value - (dim - 1 - 2 * k_float) * math.sqrt(c) * scale.pow(2)) + / scale + / math.sqrt(2) + ) + + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + ) + + log_arg2 = math.sqrt(2 / math.pi) * ( + (dim - 1 - 2 * k_float) + * math.sqrt(c) + * torch.exp(-(dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2) + - ( + (value / scale.pow(2) + (dim - 1 - 2 * k_float) * math.sqrt(c)) + * torch.exp( + -(value - (dim - 1 - 2 * k_float) * math.sqrt(c) * scale.pow(2)).pow(2) + / (2 * scale.pow(2)) + ) + ) + ) + + log_arg = log_arg1 + log_arg2 + sign_log_arg = torch.sign(log_arg) + + s = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log(sign_log_arg * log_arg) + ) + + log_grad_sum_sigma = log_sum_exp_signs(s, signs * sign_log_arg, dim=0) + grad_sum_sigma = torch.sum(signs * sign_log_arg * torch.exp(s), dim=0) + + s1 = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log( + torch.erf( + (value - (dim - 1 - 2 * k_float) * math.sqrt(c) * scale.pow(2)) + / scale + / math.sqrt(2) + ) + + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + ) + + S1 = log_sum_exp_signs(s1, signs, dim=0) + grad_log_cdf_scale = grad_sum_sigma / S1.exp() + log_unormalised_prob = ( + -value.pow(2) / (2 * scale.pow(2)) + + (dim - 1) * logsinh(math.sqrt(c) * value) + - (dim - 1) / 2 * math.log(c) + ) + + with torch.autograd.enable_grad(): + scale = scale.float() + logZ = _log_normalizer_closed_grad.apply(scale, c, dim) + grad_logZ_scale = grad(logZ, scale, grad_outputs=torch.ones_like(scale)) + + grad_log_cdf_scale = -grad_logZ_scale[0] + 1 / scale + grad_log_cdf_scale.float() + cdf = ( + cdf_r(value.double(), scale.double(), np.double(c), int(dim)).float().squeeze(0) + ) + grad_scale = cdf * grad_log_cdf_scale + + grad_value = (log_unormalised_prob.float() - logZ).exp() + return grad_value, grad_scale + + +class _log_normalizer_closed_grad(Function): + @staticmethod + def forward(ctx, scale, c, dim): + scale = scale.double() + c = np.double(c) + ctx.scale = scale.clone().detach() + ctx.c = torch.tensor([c]).to(scale.device) + ctx.dim = dim + + device = scale.device + output = ( + 0.5 * (math.log(math.pi) - math.log(2)) + + scale.log() + - (int(dim) - 1) * (math.log(c) / 2 + math.log(2)) + ) + dim = torch.tensor(int(dim)).to(device).double() + + k_float = rexpand(torch.arange(int(dim)), *scale.size()).double().to(device) + s = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log1p( + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + ) + signs = ( + torch.tensor([1.0, -1.0]) + .double() + .to(device) + .repeat(((int(dim) + 1) // 2) * 2)[: int(dim)] + ) + signs = rexpand(signs, *scale.size()) + ctx.log_sum_term = log_sum_exp_signs(s, signs, dim=0) + output = output + ctx.log_sum_term + + return output.float() + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + + device = grad_input.device + scale 
= ctx.scale + c = ctx.c + dim = torch.tensor(int(ctx.dim)).to(device).double() + + k_float = rexpand(torch.arange(int(dim)), *scale.size()).double().to(device) + signs = ( + torch.tensor([1.0, -1.0]) + .double() + .to(device) + .repeat(((int(dim) + 1) // 2) * 2)[: int(dim)] + ) + signs = rexpand(signs, *scale.size()) + + log_arg = (dim - 1 - 2 * k_float).pow(2) * c * scale * ( + 1 + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + torch.exp( + -(dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + ) * 2 / math.sqrt( + math.pi + ) * ( + dim - 1 - 2 * k_float + ) * math.sqrt( + c + ) / math.sqrt( + 2 + ) + log_arg_signs = torch.sign(log_arg) + s = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log(log_arg_signs * log_arg) + ) + log_grad_sum_sigma = log_sum_exp_signs(s, log_arg_signs * signs, dim=0) + + grad_scale = torch.exp(log_grad_sum_sigma - ctx.log_sum_term) + grad_scale = 1 / ctx.scale + grad_scale + + grad_scale = ( + (grad_input * grad_scale.float()).view(-1, *grad_input.shape).sum(0) + ) + return (grad_scale, None, None) + + +class impl_rsample(Function): + @staticmethod + def forward(ctx, value, scale, c, dim): + ctx.scale = scale.clone().detach().double().requires_grad_(True) + ctx.value = value.clone().detach().double().requires_grad_(True) + ctx.c = torch.tensor([c]).to(scale.device).double().requires_grad_(True) + ctx.dim = dim + return value + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + grad_cdf_value, grad_cdf_scale = grad_cdf_value_scale( + ctx.value, ctx.scale, ctx.c, ctx.dim + ) + assert not torch.isnan(grad_cdf_value).any() + assert not torch.isnan(grad_cdf_scale).any() + grad_value_scale = -(grad_cdf_value).pow(-1) * grad_cdf_scale.expand( + grad_input.shape + ) + grad_scale = ( + (grad_input * grad_value_scale).view(-1, *grad_cdf_scale.shape).sum(0) + ) + # grad_value_c = -(grad_cdf_value).pow(-1) * grad_cdf_c.expand(grad_input.shape) + # grad_c = (grad_input * grad_value_c).view(-1, *grad_cdf_c.shape).sum(0) + return (None, grad_scale, None, None) + + +class HyperbolicRadius(dist.Distribution): + support = dist.constraints.positive + has_rsample = True + + def __init__(self, dim, c, scale, ars=True, validate_args=None): + self.dim = dim + self.c = c + self.scale = scale + self.device = scale.device + self.ars = ars + if isinstance(scale, Number): + batch_shape = torch.Size() + else: + batch_shape = self.scale.size() + self.log_normalizer = self._log_normalizer() + if ( + torch.isnan(self.log_normalizer).any() + or torch.isinf(self.log_normalizer).any() + ): + print( + "nan or inf in log_normalizer", + torch.cat((self.log_normalizer, self.scale), dim=1), + ) + raise + super(HyperbolicRadius, self).__init__(batch_shape) + + def rsample(self, sample_shape=torch.Size()): + value = self.sample(sample_shape) + return impl_rsample.apply(value, self.scale, self.c, self.dim) + + def sample(self, sample_shape=torch.Size()): + if sample_shape == torch.Size(): + sample_shape = torch.Size([1]) + with torch.no_grad(): + mean = self.mean + stddev = self.stddev + if torch.isnan(stddev).any(): + stddev[torch.isnan(stddev)] = self.scale[torch.isnan(stddev)] + if torch.isnan(mean).any(): + mean[torch.isnan(mean)] = ( + (self.dim - 1) * self.scale.pow(2) * math.sqrt(self.c) + )[torch.isnan(mean)] + steps = torch.linspace(0.1, 3, 10).to(self.device) + steps = torch.cat((-steps.flip(0), steps)) + xi = [mean + s * 
torch.min(stddev, 0.95 * mean / 3) for s in steps] + xi = torch.cat(xi, dim=1) + ars = ARS( + self.log_prob, self.grad_log_prob, self.device, xi=xi, ns=20, lb=0 + ) + value = ars.sample(sample_shape) + return value + + def log_prob(self, value): + res = ( + -value.pow(2) / (2 * self.scale.pow(2)) + + (self.dim - 1) * logsinh(math.sqrt(self.c) * value) + - (self.dim - 1) / 2 * math.log(self.c) + - self.log_normalizer + ) # .expand(value.shape) + assert not torch.isnan(res).any() + return res + + def grad_log_prob(self, value): + res = -value / self.scale.pow(2) + (self.dim - 1) * math.sqrt( + self.c + ) * torch.cosh(math.sqrt(self.c) * value) / torch.sinh( + math.sqrt(self.c) * value + ) + return res + + def cdf(self, value): + return cdf_r(value, self.scale, self.c, self.dim) + + @property + def mean(self): + c = np.double(self.c) + scale = self.scale.double() + dim = torch.tensor(int(self.dim)).double().to(self.device) + signs = ( + torch.tensor([1.0, -1.0]) + .double() + .to(self.device) + .repeat(((self.dim + 1) // 2) * 2)[: self.dim] + .unsqueeze(-1) + .unsqueeze(-1) + .expand(self.dim, *self.scale.size()) + ) + + k_float = ( + rexpand(torch.arange(self.dim), *self.scale.size()).double().to(self.device) + ) + s2 = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log1p( + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + ) + S2 = log_sum_exp_signs(s2, signs, dim=0) + + log_arg = (dim - 1 - 2 * k_float) * math.sqrt(c) * scale.pow(2) * ( + 1 + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + torch.exp( + -(dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + ) * scale * math.sqrt( + 2 / math.pi + ) + log_arg_signs = torch.sign(log_arg) + s1 = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log(log_arg_signs * log_arg) + ) + S1 = log_sum_exp_signs(s1, signs * log_arg_signs, dim=0) + + output = torch.exp(S1 - S2) + return output.float() + + @property + def variance(self): + c = np.double(self.c) + scale = self.scale.double() + dim = torch.tensor(int(self.dim)).double().to(self.device) + signs = ( + torch.tensor([1.0, -1.0]) + .double() + .to(self.device) + .repeat(((int(dim) + 1) // 2) * 2)[: int(dim)] + .unsqueeze(-1) + .unsqueeze(-1) + .expand(int(dim), *self.scale.size()) + ) + + k_float = ( + rexpand(torch.arange(self.dim), *self.scale.size()).double().to(self.device) + ) + s2 = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + torch.log1p( + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + ) + S2 = log_sum_exp_signs(s2, signs, dim=0) + + log_arg = (1 + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2)) * ( + 1 + torch.erf((dim - 1 - 2 * k_float) * math.sqrt(c) * scale / math.sqrt(2)) + ) + (dim - 1 - 2 * k_float) * math.sqrt(c) * torch.exp( + -(dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + ) * scale * math.sqrt( + 2 / math.pi + ) + log_arg_signs = torch.sign(log_arg) + s1 = ( + torch.lgamma(dim) + - torch.lgamma(k_float + 1) + - torch.lgamma(dim - k_float) + + (dim - 1 - 2 * k_float).pow(2) * c * scale.pow(2) / 2 + + 2 * scale.log() + + torch.log(log_arg_signs * log_arg) + ) + S1 = log_sum_exp_signs(s1, signs * log_arg_signs, dim=0) + + output = torch.exp(S1 - S2) + output = output.float() 
- self.mean.pow(2) + return output + + @property + def stddev(self): + return self.variance.sqrt() + + def _log_normalizer(self): + return _log_normalizer_closed_grad.apply(self.scale, self.c, self.dim) + + +class HypersphericalUniform(dist.Distribution): + """Taken from + https://github.com/nicola-decao/s-vae-pytorch/blob/master/hyperspherical_vae/distributions/von_mises_fisher.py + """ + + support = dist.constraints.real + has_rsample = False + _mean_carrier_measure = 0 + + @property + def dim(self): + return self._dim + + def __init__(self, dim, device="cpu", validate_args=None): + super(HypersphericalUniform, self).__init__( + torch.Size([dim]), validate_args=validate_args + ) + self._dim = dim + self._device = device + + def sample(self, shape=torch.Size()): + with torch.no_grad(): + return self.rsample(shape) + + def rsample(self, sample_shape=torch.Size()): + shape = torch.Size([*sample_shape, self._dim + 1]) + output = _standard_normal(shape, dtype=torch.float, device=self._device) + + return output / output.norm(dim=-1, keepdim=True) + + def entropy(self): + return self._log_surface_area() + + def log_prob(self, x): + return -torch.ones(x.shape[:-1]).to(self._device) * self._log_normalizer() + + def _log_normalizer(self): + return self._log_surface_area().to(self._device) + + def _log_surface_area(self): + return ( + math.log(2) + + ((self._dim + 1) / 2) * math.log(math.pi) + - torch.lgamma(torch.Tensor([(self._dim + 1) / 2])) + ) + + +class RiemannianNormal(dist.Distribution): ## OK + # arg_constraints = {'loc': dist.constraints.interval(-1, 1), 'scale': dist.constraints.positive} + support = dist.constraints.interval(-1, 1) + has_rsample = True + + @property + def mean(self): + return self.loc + + def __init__(self, loc, scale, manifold, validate_args=None): + assert not (torch.isnan(loc).any() or torch.isnan(scale).any()) + self.manifold = manifold + self.loc = loc + self.manifold._check_point_on_manifold(self.loc) + self.scale = scale.clamp(min=0.1, max=7.0) + self.radius = HyperbolicRadius(manifold.dim, manifold.c, self.scale) + self.direction = HypersphericalUniform(manifold.dim - 1, device=loc.device) + if isinstance(loc, Number) and isinstance(scale, Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super(RiemannianNormal, self).__init__(batch_shape, validate_args=validate_args) + + def sample(self, shape=torch.Size()): + with torch.no_grad(): + return self.rsample(shape) + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + alpha = self.direction.sample(torch.Size([*shape[:-1]])) + radius = self.radius.rsample(sample_shape) + # u = radius * alpha / self.manifold.lambda_x(self.loc, keepdim=True) + # res = self.manifold.expmap(self.loc, u) + res = self.manifold.expmap_polar(self.loc, alpha, radius) + return res + + def log_prob(self, value): ## OK + loc = self.loc.expand(value.shape) + radius_sq = self.manifold.dist(loc, value, keepdim=True).pow(2) + res = ( + -radius_sq / 2 / self.scale.pow(2) + - self.direction._log_normalizer() + - self.radius.log_normalizer + ) + return res diff --git a/src/pythae/pipelines/generation.py b/src/pythae/pipelines/generation.py index 5d1b1ac0..97bfd3b9 100644 --- a/src/pythae/pipelines/generation.py +++ b/src/pythae/pipelines/generation.py @@ -70,6 +70,9 @@ def __init__( model=model, sampler_config=sampler_config ) + elif sampler_config.name == "PoincareDiskSamplerConfig": + sampler = PoincareDiskSampler(model=model, sampler_config=sampler_config) + + else: + raise
NotImplementedError( + "Unrecognized sampler config name... Check that the sampler_config name " diff --git a/src/pythae/samplers/__init__.py b/src/pythae/samplers/__init__.py index 0c0b0398..9ba2674d 100644 --- a/src/pythae/samplers/__init__.py +++ b/src/pythae/samplers/__init__.py @@ -23,6 +23,7 @@ from .manifold_sampler import RHVAESampler, RHVAESamplerConfig from .normal_sampling import NormalSampler, NormalSamplerConfig from .pixelcnn_sampler import PixelCNNSampler, PixelCNNSamplerConfig +from .pvae_sampler import PoincareDiskSampler, PoincareDiskSamplerConfig from .two_stage_vae_sampler import TwoStageVAESampler, TwoStageVAESamplerConfig from .vamp_sampler import VAMPSampler, VAMPSamplerConfig @@ -47,4 +48,6 @@ "IAFSamplerConfig", "PixelCNNSampler", "PixelCNNSamplerConfig", + "PoincareDiskSampler", + "PoincareDiskSamplerConfig", ] diff --git a/src/pythae/samplers/pvae_sampler/__init__.py b/src/pythae/samplers/pvae_sampler/__init__.py new file mode 100644 index 00000000..200c9cd1 --- /dev/null +++ b/src/pythae/samplers/pvae_sampler/__init__.py @@ -0,0 +1,15 @@ +"""Implementation of the sampling scheme from a Wrapped Normal or Riemannian Normal +distribution on the Poincaré Disk as proposed in (https://arxiv.org/abs/1901.06033). + +Available models: +------------------ + +.. autosummary:: + ~pythae.models.PoincareVAE + :nosignatures: +""" + +from .pvae_sampler import PoincareDiskSampler +from .pvae_sampler_config import PoincareDiskSamplerConfig + +__all__ = ["PoincareDiskSampler", "PoincareDiskSamplerConfig"] diff --git a/src/pythae/samplers/pvae_sampler/pvae_sampler.py b/src/pythae/samplers/pvae_sampler/pvae_sampler.py new file mode 100644 index 00000000..8eb45b17 --- /dev/null +++ b/src/pythae/samplers/pvae_sampler/pvae_sampler.py @@ -0,0 +1,102 @@ +import torch + +from ...models import PoincareVAE +from ..base import BaseSampler +from .pvae_sampler_config import PoincareDiskSamplerConfig + + +class PoincareDiskSampler(BaseSampler): + """Sampling from the Poincaré Disk using either a Wrapped Normal or a Riemannian Normal + distribution. + + Args: + model (PoincareVAE): The Poincaré VAE model to sample from. + sampler_config (BaseSamplerConfig): An instance of BaseSamplerConfig in which the sampler's + parameters are made available. If None a default configuration is used. Default: None + + """ + + def __init__( + self, model: PoincareVAE, sampler_config: PoincareDiskSamplerConfig = None + ): + + assert isinstance( + model, PoincareVAE + ), "This sampler is only suitable for the PoincareVAE model" + + if sampler_config is None: + sampler_config = PoincareDiskSamplerConfig() + + BaseSampler.__init__(self, model=model, sampler_config=sampler_config) + + self.gen_distribution = self.model.prior( + loc=self.model._pz_mu, + scale=self.model._pz_logvar.exp(), + manifold=self.model.latent_manifold, + ) + + def sample( + self, + num_samples: int = 1, + batch_size: int = 500, + output_dir: str = None, + return_gen: bool = True, + save_sampler_config: bool = False, + ) -> torch.Tensor: + """Main sampling function of the sampler. + + Args: + num_samples (int): The number of samples to generate + batch_size (int): The batch size to use during sampling + output_dir (str): The directory where the images will be saved. If it does not exist, + the folder is created. If None, the images are not saved. Default: None. + return_gen (bool): Whether the sampler should directly return a tensor of generated + data. Default: True. + save_sampler_config (bool): Whether to save the sampler config.
It is saved in + output_dir + + Returns: + ~torch.Tensor: The generated images + """ + full_batch_nbr = int(num_samples / batch_size) + last_batch_samples_nbr = num_samples % batch_size + + x_gen_list = [] + + for i in range(full_batch_nbr): + + z = self.gen_distribution.rsample(torch.Size([batch_size])).reshape( + batch_size, -1 + ) + x_gen = self.model.decoder(z)["reconstruction"].detach() + + if output_dir is not None: + for j in range(batch_size): + self.save_img( + x_gen[j], output_dir, "%08d.png" % int(batch_size * i + j) + ) + + x_gen_list.append(x_gen) + + if last_batch_samples_nbr > 0: + + z = self.gen_distribution.rsample( + torch.Size([last_batch_samples_nbr]) + ).reshape(last_batch_samples_nbr, -1) + x_gen = self.model.decoder(z)["reconstruction"].detach() + + if output_dir is not None: + for j in range(last_batch_samples_nbr): + self.save_img( + x_gen[j], + output_dir, + "%08d.png" % int(batch_size * full_batch_nbr + j), + ) + + x_gen_list.append(x_gen) + + if save_sampler_config: + self.save(output_dir) + + if return_gen: + return torch.cat(x_gen_list, dim=0) diff --git a/src/pythae/samplers/pvae_sampler/pvae_sampler_config.py b/src/pythae/samplers/pvae_sampler/pvae_sampler_config.py new file mode 100644 index 00000000..c4466241 --- /dev/null +++ b/src/pythae/samplers/pvae_sampler/pvae_sampler_config.py @@ -0,0 +1,12 @@ +from pydantic.dataclasses import dataclass + +from ..base import BaseSamplerConfig + + +@dataclass +class PoincareDiskSamplerConfig(BaseSamplerConfig): + """This is the Poincare Disk prior sampler configuration instance deriving from + :class:`BaseSamplerConfig`. + """ + + pass diff --git a/tests/test_PoincareVAE.py b/tests/test_PoincareVAE.py new file mode 100644 index 00000000..9c086d79 --- /dev/null +++ b/tests/test_PoincareVAE.py @@ -0,0 +1,863 @@ +import os +from copy import deepcopy + +import pytest +from sklearn import manifold +import torch +from torch.optim import Adam + +from pythae.customexception import BadInheritanceError +from pythae.models.base.base_utils import ModelOutput +from pythae.models import PoincareVAE, PoincareVAEConfig, AutoModel +from pythae.models.pvae.pvae_utils import PoincareBall +from pythae.samplers import PoincareDiskSamplerConfig, NormalSamplerConfig, GaussianMixtureSamplerConfig, MAFSamplerConfig, TwoStageVAESamplerConfig, IAFSamplerConfig +from pythae.trainers import BaseTrainer, BaseTrainerConfig +from pythae.pipelines import TrainingPipeline, GenerationPipeline +from tests.data.custom_architectures import ( + Decoder_AE_Conv, + Encoder_VAE_Conv, + Encoder_SVAE_Conv, + NetBadInheritance, +) + +PATH = os.path.dirname(os.path.abspath(__file__)) + + +@pytest.fixture(params=[PoincareVAEConfig(), PoincareVAEConfig(latent_dim=5, prior_distribution="wrapped_normal", posterior_distribution="riemannian_normal", curvature=0.5)]) +def model_configs_no_input_dim(request): + return request.param + + +@pytest.fixture( + params=[ + PoincareVAEConfig(input_dim=(1, 28, 28), latent_dim=2, reconstruction_loss="bce", prior_distribution="wrapped_normal", posterior_distribution="wrapped_normal", curvature=0.7), + PoincareVAEConfig(input_dim=(1, 28), latent_dim=5, prior_distribution="riemannian_normal", posterior_distribution="riemannian_normal", curvature=0.8), + ] +) +def model_configs(request): + return request.param + + +@pytest.fixture +def custom_encoder(model_configs): + if model_configs.posterior_distribution == "riemannian_normal": + return Encoder_SVAE_Conv(model_configs) + return Encoder_VAE_Conv(model_configs) + + +@pytest.fixture 
+def custom_decoder(model_configs): + return Decoder_AE_Conv(model_configs) + + +class Test_Model_Building: + @pytest.fixture() + def bad_net(self): + return NetBadInheritance() + + def test_build_model(self, model_configs): + model = PoincareVAE(model_configs) + assert all( + [ + model.input_dim == model_configs.input_dim, + model.latent_dim == model_configs.latent_dim, + ] + ) + + def test_raises_bad_inheritance(self, model_configs, bad_net): + with pytest.raises(BadInheritanceError): + model = PoincareVAE(model_configs, encoder=bad_net) + + with pytest.raises(BadInheritanceError): + model = PoincareVAE(model_configs, decoder=bad_net) + + def test_raises_no_input_dim( + self, model_configs_no_input_dim, custom_encoder, custom_decoder + ): + with pytest.raises(AttributeError): + model = PoincareVAE(model_configs_no_input_dim) + + with pytest.raises(AttributeError): + model = PoincareVAE(model_configs_no_input_dim, encoder=custom_encoder) + + with pytest.raises(AttributeError): + model = PoincareVAE(model_configs_no_input_dim, decoder=custom_decoder) + + model = PoincareVAE( + model_configs_no_input_dim, encoder=custom_encoder, decoder=custom_decoder + ) + + def test_build_custom_arch(self, model_configs, custom_encoder, custom_decoder): + + model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) + + assert model.encoder == custom_encoder + assert not model.model_config.uses_default_encoder + assert model.decoder == custom_decoder + assert not model.model_config.uses_default_decoder + + model = PoincareVAE(model_configs, encoder=custom_encoder) + + assert model.encoder == custom_encoder + assert not model.model_config.uses_default_encoder + assert model.model_config.uses_default_decoder + + model = PoincareVAE(model_configs, decoder=custom_decoder) + + assert model.model_config.uses_default_encoder + assert model.decoder == custom_decoder + assert not model.model_config.uses_default_decoder + + def test_misc_manifold_func(self): + + manifold = PoincareBall(dim=2, c=0.7) + x = torch.randn(10, 2) + y = torch.randn(10, 2) + manifold.logmap0(x, y) + manifold.expmap0(x) + manifold.transp0(x, y) + manifold.normdist2plane(x, x, x) + manifold.normdist2plane(x, x, x, signed=True, norm=True) + + + +class Test_Model_Saving: + def test_default_model_saving(self, tmpdir, model_configs): + + tmpdir.mkdir("dummy_folder") + dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") + + model = PoincareVAE(model_configs) + + model.state_dict()["encoder.layers.0.0.weight"][0] = 0 + + model.save(dir_path=dir_path) + + assert set(os.listdir(dir_path)) == set(["model_config.json", "model.pt", "environment.json"]) + + # reload model + model_rec = AutoModel.load_from_folder(dir_path) + + # check configs are the same + assert model_rec.model_config.__dict__ == model.model_config.__dict__ + + assert all( + [ + torch.equal(model_rec.state_dict()[key], model.state_dict()[key]) + for key in model.state_dict().keys() + ] + ) + + def test_custom_encoder_model_saving(self, tmpdir, model_configs, custom_encoder): + + tmpdir.mkdir("dummy_folder") + dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") + + model = PoincareVAE(model_configs, encoder=custom_encoder) + + model.state_dict()["encoder.layers.0.0.weight"][0] = 0 + + model.save(dir_path=dir_path) + + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "encoder.pkl", "environment.json"] + ) + + # reload model + model_rec = AutoModel.load_from_folder(dir_path) + + # check configs are the same + assert 
model_rec.model_config.__dict__ == model.model_config.__dict__ + + assert all( + [ + torch.equal(model_rec.state_dict()[key], model.state_dict()[key]) + for key in model.state_dict().keys() + ] + ) + + def test_custom_decoder_model_saving(self, tmpdir, model_configs, custom_decoder): + + tmpdir.mkdir("dummy_folder") + dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") + + model = PoincareVAE(model_configs, decoder=custom_decoder) + + model.state_dict()["encoder.layers.0.0.weight"][0] = 0 + + model.save(dir_path=dir_path) + + assert set(os.listdir(dir_path)) == set( + ["model_config.json", "model.pt", "decoder.pkl", "environment.json"] + ) + + # reload model + model_rec = AutoModel.load_from_folder(dir_path) + + # check configs are the same + assert model_rec.model_config.__dict__ == model.model_config.__dict__ + + assert all( + [ + torch.equal(model_rec.state_dict()[key], model.state_dict()[key]) + for key in model.state_dict().keys() + ] + ) + + def test_full_custom_model_saving( + self, tmpdir, model_configs, custom_encoder, custom_decoder + ): + + tmpdir.mkdir("dummy_folder") + dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") + + model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) + + model.state_dict()["encoder.layers.0.0.weight"][0] = 0 + + model.save(dir_path=dir_path) + + assert set(os.listdir(dir_path)) == set( + [ + "model_config.json", + "model.pt", + "encoder.pkl", + "decoder.pkl", + "environment.json" + ] + ) + + # reload model + model_rec = AutoModel.load_from_folder(dir_path) + + # check configs are the same + assert model_rec.model_config.__dict__ == model.model_config.__dict__ + + assert all( + [ + torch.equal(model_rec.state_dict()[key], model.state_dict()[key]) + for key in model.state_dict().keys() + ] + ) + + def test_raises_missing_files( + self, tmpdir, model_configs, custom_encoder, custom_decoder + ): + + tmpdir.mkdir("dummy_folder") + dir_path = dir_path = os.path.join(tmpdir, "dummy_folder") + + model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) + + model.state_dict()["encoder.layers.0.0.weight"][0] = 0 + + model.save(dir_path=dir_path) + + os.remove(os.path.join(dir_path, "decoder.pkl")) + + # check raises decoder.pkl is missing + with pytest.raises(FileNotFoundError): + model_rec = AutoModel.load_from_folder(dir_path) + + os.remove(os.path.join(dir_path, "encoder.pkl")) + + # check raises encoder.pkl is missing + with pytest.raises(FileNotFoundError): + model_rec = AutoModel.load_from_folder(dir_path) + + os.remove(os.path.join(dir_path, "model.pt")) + + # check raises encoder.pkl is missing + with pytest.raises(FileNotFoundError): + model_rec = AutoModel.load_from_folder(dir_path) + + os.remove(os.path.join(dir_path, "model_config.json")) + + # check raises encoder.pkl is missing + with pytest.raises(FileNotFoundError): + model_rec = AutoModel.load_from_folder(dir_path) + + +class Test_Model_forward: + @pytest.fixture + def demo_data(self): + data = torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ] + return data # This is an extract of 3 data from MNIST (unnormalized) used to test custom architecture + + @pytest.fixture + def vae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data["data"][0].shape) + return PoincareVAE(model_configs) + + def test_model_train_output(self, vae, demo_data): + + vae.train() + + out = vae(demo_data) + + assert isinstance(out, ModelOutput) + + assert set(["reconstruction_loss", "reg_loss", "loss", 
"recon_x", "z"]) == set( + out.keys() + ) + + assert out.z.shape[0] == demo_data["data"].shape[0] + assert out.recon_x.shape == demo_data["data"].shape + +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return PoincareVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return PoincareVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + + +class Test_NLL_Compute: + @pytest.fixture + def demo_data(self): + data = torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ] + return data # This is an extract of 3 data from MNIST (unnormalized) used to test custom architecture + + @pytest.fixture + def vae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data["data"][0].shape) + return PoincareVAE(model_configs) + + @pytest.fixture(params=[(20, 10), (11, 22)]) + def nll_params(self, request): + return request.param + + def test_nll_compute(self, vae, demo_data, nll_params): + nll = vae.get_nll( + data=demo_data["data"], n_samples=nll_params[0], batch_size=nll_params[1] + ) + + assert isinstance(nll, float) + assert nll < 0 + + +@pytest.mark.slow +class Test_VAE_Training: + @pytest.fixture + def train_dataset(self): + return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")) + + @pytest.fixture( + params=[BaseTrainerConfig(num_epochs=3, steps_saving=2, learning_rate=1e-5)] + ) + def training_configs(self, tmpdir, request): + tmpdir.mkdir("dummy_folder") + dir_path = os.path.join(tmpdir, "dummy_folder") + request.param.output_dir = dir_path + return request.param + + @pytest.fixture( + params=[ + torch.rand(1), + torch.rand(1), + torch.rand(1), + torch.rand(1), + torch.rand(1), + ] + ) + def vae(self, model_configs, custom_encoder, custom_decoder, request): + # randomized + + alpha = request.param + + if alpha < 0.25: + model = PoincareVAE(model_configs) + + elif 0.25 <= alpha < 0.5: + model = PoincareVAE(model_configs, encoder=custom_encoder) + + elif 0.5 <= alpha < 0.75: + model = PoincareVAE(model_configs, decoder=custom_decoder) + + else: + model = PoincareVAE(model_configs, encoder=custom_encoder, decoder=custom_decoder) + + return model + + @pytest.fixture(params=[Adam]) + def optimizers(self, request, vae, training_configs): + if request.param is not None: + optimizer = request.param( + vae.parameters(), lr=training_configs.learning_rate + ) + + else: + 
+
+
+@pytest.mark.slow
+class Test_VAE_Training:
+    @pytest.fixture
+    def train_dataset(self):
+        return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))
+
+    @pytest.fixture(
+        params=[BaseTrainerConfig(num_epochs=3, steps_saving=2, learning_rate=1e-5)]
+    )
+    def training_configs(self, tmpdir, request):
+        tmpdir.mkdir("dummy_folder")
+        dir_path = os.path.join(tmpdir, "dummy_folder")
+        request.param.output_dir = dir_path
+        return request.param
+
+    @pytest.fixture(
+        params=[
+            torch.rand(1),
+            torch.rand(1),
+            torch.rand(1),
+            torch.rand(1),
+            torch.rand(1),
+        ]
+    )
+    def vae(self, model_configs, custom_encoder, custom_decoder, request):
+        # randomized choice between default and custom architectures
+
+        alpha = request.param
+
+        if alpha < 0.25:
+            model = PoincareVAE(model_configs)
+
+        elif 0.25 <= alpha < 0.5:
+            model = PoincareVAE(model_configs, encoder=custom_encoder)
+
+        elif 0.5 <= alpha < 0.75:
+            model = PoincareVAE(model_configs, decoder=custom_decoder)
+
+        else:
+            model = PoincareVAE(
+                model_configs, encoder=custom_encoder, decoder=custom_decoder
+            )
+
+        return model
+
+    @pytest.fixture(params=[Adam])
+    def optimizers(self, request, vae, training_configs):
+        if request.param is not None:
+            optimizer = request.param(
+                vae.parameters(), lr=training_configs.learning_rate
+            )
+
+        else:
+            optimizer = None
+
+        return optimizer
+
+    def test_vae_train_step(self, vae, train_dataset, training_configs, optimizers):
+        trainer = BaseTrainer(
+            model=vae,
+            train_dataset=train_dataset,
+            training_config=training_configs,
+            optimizer=optimizers,
+        )
+
+        start_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        step_1_loss = trainer.train_step(epoch=1)
+
+        step_1_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        # check that weights were updated
+        assert not all(
+            [
+                torch.equal(start_model_state_dict[key], step_1_model_state_dict[key])
+                for key in start_model_state_dict.keys()
+            ]
+        )
+
+    def test_vae_eval_step(self, vae, train_dataset, training_configs, optimizers):
+        trainer = BaseTrainer(
+            model=vae,
+            train_dataset=train_dataset,
+            eval_dataset=train_dataset,
+            training_config=training_configs,
+            optimizer=optimizers,
+        )
+
+        start_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        step_1_loss = trainer.eval_step(epoch=1)
+
+        step_1_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        # check that weights were not updated
+        assert all(
+            [
+                torch.equal(start_model_state_dict[key], step_1_model_state_dict[key])
+                for key in start_model_state_dict.keys()
+            ]
+        )
+
+    def test_vae_predict_step(self, vae, train_dataset, training_configs, optimizers):
+        trainer = BaseTrainer(
+            model=vae,
+            train_dataset=train_dataset,
+            eval_dataset=train_dataset,
+            training_config=training_configs,
+            optimizer=optimizers,
+        )
+
+        start_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        inputs, recon, generated = trainer.predict(trainer.model)
+
+        step_1_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        # check that weights were not updated
+        assert all(
+            [
+                torch.equal(start_model_state_dict[key], step_1_model_state_dict[key])
+                for key in start_model_state_dict.keys()
+            ]
+        )
+
+        assert torch.equal(inputs.cpu(), train_dataset.data.cpu())
+        assert recon.shape == inputs.shape
+        assert generated.shape == inputs.shape
+
+    def test_vae_main_train_loop(
+        self, tmpdir, vae, train_dataset, training_configs, optimizers
+    ):
+
+        trainer = BaseTrainer(
+            model=vae,
+            train_dataset=train_dataset,
+            eval_dataset=train_dataset,
+            training_config=training_configs,
+            optimizer=optimizers,
+        )
+
+        start_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        trainer.train()
+
+        step_1_model_state_dict = deepcopy(trainer.model.state_dict())
+
+        # check that weights were updated
+        assert not all(
+            [
+                torch.equal(start_model_state_dict[key], step_1_model_state_dict[key])
+                for key in start_model_state_dict.keys()
+            ]
+        )
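+
+    # Trainer usage pattern (illustrative recap of what the steps above rely
+    # on), e.g.:
+    #
+    #     trainer = BaseTrainer(model=model, train_dataset=data,
+    #                           training_config=BaseTrainerConfig(num_epochs=3))
+    #     trainer.train()
+    #
+    # `train_step` must change the weights; `eval_step` and `predict` must not.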
+
+    def test_checkpoint_saving(
+        self, tmpdir, vae, train_dataset, training_configs, optimizers
+    ):
+
+        dir_path = training_configs.output_dir
+
+        trainer = BaseTrainer(
+            model=vae,
+            train_dataset=train_dataset,
+            training_config=training_configs,
+            optimizer=optimizers,
+        )
+
+        # Make a training step
+        step_1_loss = trainer.train_step(epoch=1)
+
+        model = deepcopy(trainer.model)
+        optimizer = deepcopy(trainer.optimizer)
+
+        trainer.save_checkpoint(dir_path=dir_path, epoch=0, model=model)
+
+        checkpoint_dir = os.path.join(dir_path, "checkpoint_epoch_0")
+
+        assert os.path.isdir(checkpoint_dir)
+
+        files_list = os.listdir(checkpoint_dir)
+
+        assert set(["model.pt", "optimizer.pt", "training_config.json"]).issubset(
+            set(files_list)
+        )
+
+        # check pickled custom decoder
+        if not vae.model_config.uses_default_decoder:
+            assert "decoder.pkl" in files_list
+
+        else:
+            assert "decoder.pkl" not in files_list
+
+        # check pickled custom encoder
+        if not vae.model_config.uses_default_encoder:
+            assert "encoder.pkl" in files_list
+
+        else:
+            assert "encoder.pkl" not in files_list
+
+        model_rec_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))[
+            "model_state_dict"
+        ]
+
+        assert all(
+            [
+                torch.equal(
+                    model_rec_state_dict[key].cpu(), model.state_dict()[key].cpu()
+                )
+                for key in model.state_dict().keys()
+            ]
+        )
+
+        # check reload full model
+        model_rec = AutoModel.load_from_folder(os.path.join(checkpoint_dir))
+
+        assert all(
+            [
+                torch.equal(
+                    model_rec.state_dict()[key].cpu(), model.state_dict()[key].cpu()
+                )
+                for key in model.state_dict().keys()
+            ]
+        )
+
+        assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu())
+        assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu())
+
+        optim_rec_state_dict = torch.load(os.path.join(checkpoint_dir, "optimizer.pt"))
+
+        assert all(
+            [
+                dict_rec == dict_optimizer
+                for (dict_rec, dict_optimizer) in zip(
+                    optim_rec_state_dict["param_groups"],
+                    optimizer.state_dict()["param_groups"],
+                )
+            ]
+        )
+
+        assert all(
+            [
+                dict_rec == dict_optimizer
+                for (dict_rec, dict_optimizer) in zip(
+                    optim_rec_state_dict["state"], optimizer.state_dict()["state"]
+                )
+            ]
+        )
+
+    def test_checkpoint_saving_during_training(
+        self, tmpdir, vae, train_dataset, training_configs, optimizers
+    ):
+
+        target_saving_epoch = training_configs.steps_saving
+
+        dir_path = training_configs.output_dir
+
+        trainer = BaseTrainer(
+            model=vae,
+            train_dataset=train_dataset,
+            training_config=training_configs,
+            optimizer=optimizers,
+        )
+
+        model = deepcopy(trainer.model)
+
+        trainer.train()
+
+        training_dir = os.path.join(
+            dir_path, f"PoincareVAE_training_{trainer._training_signature}"
+        )
+        assert os.path.isdir(training_dir)
+
+        checkpoint_dir = os.path.join(
+            training_dir, f"checkpoint_epoch_{target_saving_epoch}"
+        )
+
+        assert os.path.isdir(checkpoint_dir)
+
+        files_list = os.listdir(checkpoint_dir)
+
+        # check files
+        assert set(["model.pt", "optimizer.pt", "training_config.json"]).issubset(
+            set(files_list)
+        )
+
+        # check pickled custom decoder
+        if not vae.model_config.uses_default_decoder:
+            assert "decoder.pkl" in files_list
+
+        else:
+            assert "decoder.pkl" not in files_list
+
+        # check pickled custom encoder
+        if not vae.model_config.uses_default_encoder:
+            assert "encoder.pkl" in files_list
+
+        else:
+            assert "encoder.pkl" not in files_list
+
+        model_rec_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))[
+            "model_state_dict"
+        ]
+
+        # the checkpointed weights should differ from the initial ones
+        assert not all(
+            [
+                torch.equal(model_rec_state_dict[key], model.state_dict()[key])
+                for key in model.state_dict().keys()
+            ]
+        )
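+
+    # Output layout checked by the surrounding tests (summarized for reference):
+    #
+    #     <output_dir>/PoincareVAE_training_<signature>/
+    #         checkpoint_epoch_<n>/  # model.pt, optimizer.pt, training_config.json
+    #         final_model/           # model.pt, model_config.json, training_config.json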
+
+    def test_final_model_saving(
+        self, tmpdir, vae, train_dataset, training_configs, optimizers
+    ):
+
+        dir_path = training_configs.output_dir
+
+        trainer = BaseTrainer(
+            model=vae,
+            train_dataset=train_dataset,
+            training_config=training_configs,
+            optimizer=optimizers,
+        )
+
+        trainer.train()
+
+        model = deepcopy(trainer._best_model)
+
+        training_dir = os.path.join(
+            dir_path, f"PoincareVAE_training_{trainer._training_signature}"
+        )
+        assert os.path.isdir(training_dir)
+
+        final_dir = os.path.join(training_dir, "final_model")
+        assert os.path.isdir(final_dir)
+
+        files_list = os.listdir(final_dir)
+
+        assert set(["model.pt", "model_config.json", "training_config.json"]).issubset(
+            set(files_list)
+        )
+
+        # check pickled custom decoder
+        if not vae.model_config.uses_default_decoder:
+            assert "decoder.pkl" in files_list
+
+        else:
+            assert "decoder.pkl" not in files_list
+
+        # check pickled custom encoder
+        if not vae.model_config.uses_default_encoder:
+            assert "encoder.pkl" in files_list
+
+        else:
+            assert "encoder.pkl" not in files_list
+
+        # check reload full model
+        model_rec = AutoModel.load_from_folder(os.path.join(final_dir))
+
+        assert all(
+            [
+                torch.equal(
+                    model_rec.state_dict()[key].cpu(), model.state_dict()[key].cpu()
+                )
+                for key in model.state_dict().keys()
+            ]
+        )
+
+        assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu())
+        assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu())
+
+    def test_vae_training_pipeline(self, tmpdir, vae, train_dataset, training_configs):
+
+        dir_path = training_configs.output_dir
+
+        # build pipeline
+        pipeline = TrainingPipeline(model=vae, training_config=training_configs)
+
+        assert pipeline.training_config.__dict__ == training_configs.__dict__
+
+        # Launch Pipeline
+        pipeline(
+            train_data=train_dataset.data,  # gives tensor to pipeline
+            eval_data=train_dataset.data,  # gives tensor to pipeline
+        )
+
+        model = deepcopy(pipeline.trainer._best_model)
+
+        training_dir = os.path.join(
+            dir_path, f"PoincareVAE_training_{pipeline.trainer._training_signature}"
+        )
+        assert os.path.isdir(training_dir)
+
+        final_dir = os.path.join(training_dir, "final_model")
+        assert os.path.isdir(final_dir)
+
+        files_list = os.listdir(final_dir)
+
+        assert set(["model.pt", "model_config.json", "training_config.json"]).issubset(
+            set(files_list)
+        )
+
+        # check pickled custom decoder
+        if not vae.model_config.uses_default_decoder:
+            assert "decoder.pkl" in files_list
+
+        else:
+            assert "decoder.pkl" not in files_list
+
+        # check pickled custom encoder
+        if not vae.model_config.uses_default_encoder:
+            assert "encoder.pkl" in files_list
+
+        else:
+            assert "encoder.pkl" not in files_list
+
+        # check reload full model
+        model_rec = AutoModel.load_from_folder(os.path.join(final_dir))
+
+        assert all(
+            [
+                torch.equal(
+                    model_rec.state_dict()[key].cpu(), model.state_dict()[key].cpu()
+                )
+                for key in model.state_dict().keys()
+            ]
+        )
+
+        assert type(model_rec.encoder.cpu()) == type(model.encoder.cpu())
+        assert type(model_rec.decoder.cpu()) == type(model.decoder.cpu())
+
+
+class Test_VAE_Generation:
+    @pytest.fixture
+    def train_data(self):
+        return torch.load(
+            os.path.join(PATH, "data/mnist_clean_train_dataset_sample")
+        ).data
+
+    @pytest.fixture(
+        params=[
+            PoincareVAEConfig(
+                input_dim=(1, 28, 28),
+                latent_dim=7,
+                prior_distribution="wrapped_normal",
+                curvature=0.2,
+            ),
+            PoincareVAEConfig(
+                input_dim=(1, 28, 28),
+                latent_dim=2,
+                prior_distribution="riemannian_normal",
+                curvature=0.7,
+            ),
+        ]
+    )
+    def ae_config(self, request):
+        return request.param
+
+    @pytest.fixture
+    def ae_model(self, ae_config):
+        return PoincareVAE(ae_config)
+
+    @pytest.fixture(
+        params=[
+            PoincareDiskSamplerConfig(),
+            NormalSamplerConfig(),
+            GaussianMixtureSamplerConfig(),
+            MAFSamplerConfig(),
+            IAFSamplerConfig(),
+            TwoStageVAESamplerConfig(),
+        ]
+    )
+    def sampler_configs(self, request):
+        return request.param
+
+    def test_fits_in_generation_pipeline(self, ae_model, sampler_configs, train_data):
+        pipeline = GenerationPipeline(model=ae_model, sampler_config=sampler_configs)
+        gen_data = pipeline(
+            num_samples=11,
+            batch_size=7,
+            output_dir=None,
+            return_gen=True,
+            train_data=train_data,
+            eval_data=train_data,
+            training_config=BaseTrainerConfig(num_epochs=1),
+        )
+
+        assert gen_data.shape[0] == 11
\ No newline at end of file
diff --git a/tests/test_VAE.py b/tests/test_VAE.py
index e0521f20..a244b837 100644
--- a/tests/test_VAE.py
+++ b/tests/test_VAE.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 
-from torch.optim import SGD, Adadelta, Adagrad, Adam, RMSprop
+from torch.optim import Adam
 
 from pythae.customexception import BadInheritanceError
 from pythae.models.base.base_utils import ModelOutput
diff --git a/tests/test_adversarial_trainer.py b/tests/test_adversarial_trainer.py
index e689b449..78ab8e4b 100644
--- a/tests/test_adversarial_trainer.py
+++ b/tests/test_adversarial_trainer.py
@@ -87,7 +87,7 @@ class Test_Set_Training_config:
         params=[
             AdversarialTrainerConfig(autoencoder_optim_decay=0),
             AdversarialTrainerConfig(
-                batch_size=10, learning_rate=1e-5, autoencoder_optim_decay=0
+                batch_size=10, learning_rate=1e-3, autoencoder_optim_decay=0
             ),
         ]
     )
@@ -269,7 +269,7 @@ def test_set_custom_scheduler(
 
 @pytest.mark.slow
 class Test_Main_Training:
-    @pytest.fixture(params=[AdversarialTrainerConfig(num_epochs=3)])
+    @pytest.fixture(params=[AdversarialTrainerConfig(num_epochs=3, learning_rate=1e-3)])
     def training_configs(self, tmpdir, request):
         tmpdir.mkdir("dummy_folder")
         dir_path = os.path.join(tmpdir, "dummy_folder")
diff --git a/tests/test_pvae_sampler.py b/tests/test_pvae_sampler.py
new file mode 100644
index 00000000..1333feac
--- /dev/null
+++ b/tests/test_pvae_sampler.py
@@ -0,0 +1,180 @@
+import os
+
+import numpy as np
+import pytest
+import torch
+
+from pythae.models import PoincareVAE, PoincareVAEConfig
+from pythae.samplers import (
+    PoincareDiskSampler,
+    PoincareDiskSamplerConfig,
+    NormalSampler,
+    NormalSamplerConfig,
+)
+from pythae.pipelines import GenerationPipeline
+
+PATH = os.path.dirname(os.path.abspath(__file__))
+
+
+@pytest.fixture
+def dummy_data():
+    # 3 images from MNIST used to simulate generated ones
+    return torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample")).data
+
+
+@pytest.fixture(
+    params=[
+        PoincareVAE(
+            PoincareVAEConfig(
+                input_dim=(1, 28, 28),
+                latent_dim=7,
+                prior_distribution="wrapped_normal",
+                curvature=0.2,
+            )
+        ),
+        PoincareVAE(
+            PoincareVAEConfig(
+                input_dim=(1, 28, 28),
+                latent_dim=2,
+                prior_distribution="riemannian_normal",
+                curvature=0.7,
+            )
+        ),
+    ]
+)
+def model(request):
+    return request.param
+
+
+@pytest.fixture(
+    params=[
+        PoincareDiskSamplerConfig(),
+        None,
+    ]
+)
+def sampler_config(request):
+    return request.param
+
+
+@pytest.fixture()
+def sampler(model, sampler_config):
+    return PoincareDiskSampler(model=model, sampler_config=sampler_config)
+
+
+@pytest.fixture(params=[(4, 2), (5, 5), (2, 3)])
+def num_sample_and_batch_size(request):
+    return request.param
+
+
+class Test_PoincareSampler_Saving:
+    def test_save_config(self, tmpdir, sampler):
+
+        tmpdir.mkdir("dummy_folder")
+        dir_path = os.path.join(tmpdir, "dummy_folder")
+
+        sampler.save(dir_path)
+
+        sampler_config_file = os.path.join(dir_path, "sampler_config.json")
+
+        assert os.path.isfile(sampler_config_file)
+
+        generation_config_rec = PoincareDiskSamplerConfig.from_json_file(
+            sampler_config_file
+        )
+
+        assert generation_config_rec.__dict__ == sampler.sampler_config.__dict__
+
+
+class Test_PoincareSampler_Sampling:
+    def test_return_sampling(
+        self, model, dummy_data, sampler, num_sample_and_batch_size
+    ):
+
+        num_samples, batch_size = (
+            num_sample_and_batch_size[0],
+            num_sample_and_batch_size[1],
+        )
+
+        sampler.fit(train_data=dummy_data)
+
+        gen_samples = sampler.sample(
+            num_samples=num_samples, batch_size=batch_size, return_gen=True
+        )
+
+        assert gen_samples.shape[0] == num_samples
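+
+    # Sampler usage pattern (illustrative recap of the calls above): fit the
+    # sampler on training data, then sample:
+    #
+    #     sampler = PoincareDiskSampler(model=model, sampler_config=None)
+    #     sampler.fit(train_data=train_data)
+    #     gen = sampler.sample(num_samples=10, batch_size=5, return_gen=True)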
+
+    def test_save_sampling(
+        self, tmpdir, dummy_data, model, sampler, num_sample_and_batch_size
+    ):
+
+        dir_path = os.path.join(tmpdir, "dummy_folder")
+        num_samples, batch_size = (
+            num_sample_and_batch_size[0],
+            num_sample_and_batch_size[1],
+        )
+
+        sampler.fit(train_data=dummy_data)
+
+        gen_samples = sampler.sample(
+            num_samples=num_samples,
+            batch_size=batch_size,
+            output_dir=dir_path,
+            return_gen=True,
+        )
+
+        assert gen_samples.shape[0] == num_samples
+        assert len(os.listdir(dir_path)) == num_samples
+
+    def test_save_sampling_and_sampler_config(
+        self, tmpdir, dummy_data, model, sampler, num_sample_and_batch_size
+    ):
+
+        dir_path = os.path.join(tmpdir, "dummy_folder")
+        num_samples, batch_size = (
+            num_sample_and_batch_size[0],
+            num_sample_and_batch_size[1],
+        )
+
+        sampler.fit(train_data=dummy_data)
+
+        gen_samples = sampler.sample(
+            num_samples=num_samples,
+            batch_size=batch_size,
+            output_dir=dir_path,
+            return_gen=True,
+            save_sampler_config=True,
+        )
+
+        assert gen_samples.shape[0] == num_samples
+        assert len(os.listdir(dir_path)) == num_samples + 1
+        assert "sampler_config.json" in os.listdir(dir_path)
+
+    def test_generation_pipeline(
+        self, tmpdir, dummy_data, model, sampler_config, num_sample_and_batch_size
+    ):
+
+        dir_path = os.path.join(tmpdir, "dummy_folder1")
+        num_samples, batch_size = (
+            num_sample_and_batch_size[0],
+            num_sample_and_batch_size[1],
+        )
+
+        pipe = GenerationPipeline(model=model, sampler_config=None)
+
+        assert isinstance(pipe.sampler, NormalSampler)
+        assert pipe.sampler.sampler_config == NormalSamplerConfig()
+
+        gen_data = pipe(
+            num_samples=num_samples,
+            batch_size=batch_size,
+            output_dir=dir_path,
+            return_gen=True,
+            save_sampler_config=True,
+            train_data=dummy_data,
+            eval_data=None,
+        )
+
+        assert tuple(gen_data.shape) == (num_samples,) + tuple(
+            model.model_config.input_dim
+        )
+        assert len(os.listdir(dir_path)) == num_samples + 1
+        assert "sampler_config.json" in os.listdir(dir_path)
+
+        dir_path = os.path.join(tmpdir, "dummy_folder2")
+
+        pipe = GenerationPipeline(model=model, sampler_config=sampler_config)
+
+        if sampler_config is None:
+            assert isinstance(pipe.sampler, NormalSampler)
+
+        else:
+            assert isinstance(pipe.sampler, PoincareDiskSampler)
+            assert pipe.sampler.sampler_config == sampler_config
+
+        gen_data = pipe(
+            num_samples=num_samples,
+            batch_size=batch_size,
+            output_dir=dir_path,
+            return_gen=False,
+            save_sampler_config=False,
+            train_data=dummy_data,
+            eval_data=dummy_data,
+        )
+
+        assert gen_data is None
+        assert "sampler_config.json" not in os.listdir(dir_path)
+        assert len(os.listdir(dir_path)) == num_samples