Commit

Merge branch 'main' into unify_gaussian_likelihoods
clementchadebec committed Sep 6, 2023
2 parents e0d9d13 + ca55647 commit 9ce2ed4
Showing 13 changed files with 475 additions and 47 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -178,7 +178,7 @@ To launch a model training, you only need to call a `TrainingPipeline` instance.

At the end of training, the best model weights, model configuration and training configuration are stored in a `final_model` folder available in `my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss` (with `my_model` being the `output_dir` argument of the `BaseTrainerConfig`). If you further set the `steps_saving` argument to a certain value, folders named `checkpoint_epoch_k` containing the best model weights, optimizer, scheduler, configuration and training configuration at epoch *k* will also appear in `my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss`.
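
As a quick sketch of the workflow this paragraph describes (not part of the diff): a minimal training run with an explicit `output_dir` and `steps_saving`, assuming the usual `TrainingPipeline`/`BaseTrainerConfig` API; the `VAE` model choice, input dimensions and random data are placeholders.

```python
import numpy as np

from pythae.models import VAE, VAEConfig
from pythae.pipelines import TrainingPipeline
from pythae.trainers import BaseTrainerConfig

# Placeholder data: any (n_samples, *input_dim) array works for the sketch.
train_data = np.random.rand(1000, 1, 28, 28)
eval_data = np.random.rand(100, 1, 28, 28)

training_config = BaseTrainerConfig(
    output_dir="my_model",  # final model lands in my_model/VAE_training_YYYY-MM-DD_hh-mm-ss/final_model
    num_epochs=10,
    steps_saving=5,         # also keep checkpoint_epoch_k folders every 5 epochs
)

model = VAE(VAEConfig(input_dim=(1, 28, 28), latent_dim=16))

pipeline = TrainingPipeline(training_config=training_config, model=model)
pipeline(train_data=train_data, eval_data=eval_data)
```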

-## Lauching a training on benchmark datasets
+## Launching a training on benchmark datasets
We also provide a training script example [here](https://github.com/clementchadebec/benchmark_VAE/tree/main/examples/scripts/training.py) that can be used to train the models on benchmark datasets (mnist, cifar10, celeba ...). The script can be launched with the following command line

@@ -600,7 +600,7 @@ First let's have a look at the reconstructed samples taken from the evaluation set.
----------------------------
### Generation

-Here, we show the generated samples using using each model implemented in the library and different samplers.
+Here, we show the generated samples using each model implemented in the library and different samplers.

| Models | MNIST | CELEBA
|:----------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------:|
3 changes: 2 additions & 1 deletion examples/scripts/configs/cifar10/base_training_config.json
@@ -1,7 +1,8 @@
{
"name": "BaseTrainerConfig",
"output_dir": "my_models_on_cifar",
"batch_size": 100,
"per_device_train_batch_size": 100,
"per_device_eval_batch_size": 100,
"num_epochs": 100,
"learning_rate": 1e-4,
"steps_saving": null,
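
The same renaming recurs in the configs below: the single `batch_size` field is split into `per_device_train_batch_size` and `per_device_eval_batch_size`. A small sketch of reading the updated file back, assuming `BaseTrainerConfig` exposes a `from_json_file` helper as other pythae configs do:

```python
from pythae.trainers import BaseTrainerConfig

# Load the updated JSON training config shown above.
config = BaseTrainerConfig.from_json_file(
    "examples/scripts/configs/cifar10/base_training_config.json"
)

print(config.per_device_train_batch_size)  # 100
print(config.per_device_eval_batch_size)   # 100
```
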
5 changes: 3 additions & 2 deletions examples/scripts/configs/cifar10/beta_vae_config.json
@@ -1,5 +1,6 @@
{
"latent_dim": 16,
"reconstruction_loss": "mse",
"number_components": 50
}
"number_components": 50,
"name": "BetaVAE"
}
@@ -1,7 +1,8 @@
{
"name": "BaseTrainerConfig",
"output_dir": "reproducibility/dsprites",
"batch_size": 1000,
"per_device_train_batch_size": 1000,
"per_device_eval_batch_size": 1000,
"num_epochs": 50,
"learning_rate": 1e-3,
"steps_saving": null,
@@ -1,7 +1,8 @@
{
"name": "BaseTrainerConfig",
"output_dir": "reproducibility/dsprites",
"batch_size": 64,
"per_device_train_batch_size": 64,
"per_device_eval_batch_size": 64,
"num_epochs": 500,
"learning_rate": 1e-3,
"steps_saving": 50,
3 changes: 2 additions & 1 deletion examples/scripts/configs/mnist/base_training_config.json
@@ -1,7 +1,8 @@
{
"name": "BaseTrainerConfig",
"output_dir": "my_models_on_mnist",
"batch_size": 100,
"per_device_train_batch_size": 100,
"per_device_eval_batch_size": 100,
"num_epochs": 100,
"learning_rate": 1e-3,
"steps_saving": null,
98 changes: 98 additions & 0 deletions examples/scripts/data-download.py
@@ -0,0 +1,98 @@
import sys
from torchvision.datasets import MNIST, CelebA, CIFAR10
import argparse
from pathlib import Path
import torch
import numpy as np
from torchvision.transforms import PILToTensor
from tqdm import tqdm


def main():

    parser = argparse.ArgumentParser(
        description="python script to download datasets which are available with torchvision"
    )

    parser.add_argument(
        "-j", "--nthreads", type=int, default=1, help="number of threads to use"
    )
    parser.add_argument(
        "-b", "--batchsize", type=int, default=64, help="batch_size for loading"
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=Path,
        default=Path("."),
        help="the base folder in which to store the output",
    )
    parser.add_argument(
        "dataset",
        nargs="+",
        help="datasets to download (possible values: MNIST, CelebA, CIFAR10)",
    )
    args = parser.parse_args()
    if not "dataset" in args:
        print("dataset argument not found in", args)
        parser.print_help()
        return 1

    tv_datasets = {"mnist": MNIST, "celeba": CelebA, "cifar10": CIFAR10}
    rootdir = args.outdir
    if not rootdir.exists():
        print(f"creating root folder {rootdir}")
        rootdir.mkdir(parents=True)

    for dname in args.dataset:
        if dname.lower() not in tv_datasets.keys():
            print(f"{dname} not available for download yet. skipping.")
            continue

        dfolder = rootdir / dname
        dataset = tv_datasets[dname]
        if "celeba" in dname.lower():
            train_kwarg = {"split": "train"}
            val_kwarg = {"split": "val"}
        else:
            train_kwarg = {"train": True}
            val_kwarg = {"train": False}

        train_data = dataset(
            dfolder, download=True, transform=PILToTensor(), **train_kwarg
        )
        train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=4, shuffle=False, num_workers=args.nthreads
        )

        train_batches = []
        for b, (x, y) in enumerate(tqdm(train_loader)):
            train_batches.append(x.clone().detach().numpy())

        val_data = dataset(dfolder, download=True, transform=PILToTensor(), **val_kwarg)
        val_loader = torch.utils.data.DataLoader(
            val_data, batch_size=4, shuffle=True, num_workers=args.nthreads
        )
        val_batches = []
        for b, (x, y) in enumerate(tqdm(val_loader)):
            val_batches.append(x.clone().detach().numpy())

        train_x = np.concatenate(train_batches)
        np.savez_compressed(dfolder / "train_data.npz", data=train_x)
        print(
            "Wrote ",
            dfolder / "train_data.npz",
            f"(shape {train_x.shape}, {train_x.dtype})",
        )
        val_x = np.concatenate(val_batches)
        np.savez_compressed(dfolder / "eval_data.npz", data=val_x)
        print(
            "Wrote ", dfolder / "eval_data.npz", f"(shape {val_x.shape}, {val_x.dtype})"
        )

    return 0


if __name__ == "__main__":
    rv = main()
    sys.exit(rv)
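
A possible follow-up to the script above (not part of the commit): reloading the compressed arrays it writes. They are stored as raw uint8 images of shape (N, C, H, W), so a rescale to [0, 1] is usually wanted before training; the `mnist/` prefix assumes the script was run as `python data-download.py -o . mnist`.

```python
import numpy as np

# Arrays written by data-download.py are uint8 images of shape (N, C, H, W).
train_data = np.load("mnist/train_data.npz")["data"] / 255.0
eval_data = np.load("mnist/eval_data.npz")["data"] / 255.0

print(train_data.shape, train_data.dtype)  # e.g. (60000, 1, 28, 28) float64
```
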
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
cloudpickle>=2.1.0
imageio
numpy>=1.19
-pydantic>=1.8.2
+pydantic==1.8.2
scikit-learn
scipy>=1.7.1
torch>=1.10.1
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@
"cloudpickle>=2.1.0",
"imageio",
"numpy>=1.19",
"pydantic>=1.8.2",
"pydantic==1.8.2",
"scikit-learn",
"scipy>=1.7.1",
"torch>=1.10.1",
66 changes: 66 additions & 0 deletions src/pythae/models/rhvae/rhvae_model.py
@@ -263,6 +263,72 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:

return output

def predict(self, inputs: torch.Tensor) -> ModelOutput:
    """The input data is encoded and decoded without computing loss.

    Args:
        inputs (torch.Tensor): The input data to be reconstructed, as well as to generate the embedding.

    Returns:
        ModelOutput: An instance of ModelOutput containing reconstruction, raw embedding (output of encoder), and the final embedding (output of metric).
    """
    encoder_output = self.encoder(inputs)
    mu, log_var = encoder_output.embedding, encoder_output.log_covariance

    std = torch.exp(0.5 * log_var)
    z0, _ = self._sample_gauss(mu, std)

    z = z0

    G = self.G(z)
    G_inv = self.G_inv(z)
    L = torch.linalg.cholesky(G)

    G_log_det = -torch.logdet(G_inv)

    gamma = torch.randn_like(z0, device=inputs.device)
    rho = gamma / self.beta_zero_sqrt
    beta_sqrt_old = self.beta_zero_sqrt

    # sample \rho from N(0, G)
    rho = (L @ rho.unsqueeze(-1)).squeeze(-1)

    recon_x = self.decoder(z)["reconstruction"]

    for k in range(self.n_lf):

        # perform leapfrog steps

        # step 1
        rho_ = self._leap_step_1(recon_x, inputs, z, rho, G_inv, G_log_det)

        # step 2
        z = self._leap_step_2(recon_x, inputs, z, rho_, G_inv, G_log_det)

        recon_x = self.decoder(z)["reconstruction"]

        # compute metric value on new z using final metric
        G = self.G(z)
        G_inv = self.G_inv(z)

        G_log_det = -torch.logdet(G_inv)

        # step 3
        rho__ = self._leap_step_3(recon_x, inputs, z, rho_, G_inv, G_log_det)

        # tempering
        beta_sqrt = self._tempering(k + 1, self.n_lf)
        rho = (beta_sqrt_old / beta_sqrt) * rho__
        beta_sqrt_old = beta_sqrt

    output = ModelOutput(
        recon_x=recon_x,
        raw_embedding=encoder_output.embedding,
        embedding=z if self.n_lf > 0 else encoder_output.embedding,
    )

    return output

def _leap_step_1(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
    """
    Resolves first equation of generalized leapfrog integrator
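
To round off the new `predict` method, a hedged usage sketch: the `AutoModel` loading route, the checkpoint path and the input shape are assumptions for illustration, not part of this commit.

```python
import torch

from pythae.models import AutoModel

# Placeholder path to a previously trained RHVAE.
model = AutoModel.load_from_folder(
    "my_model/RHVAE_training_YYYY-MM-DD_hh-mm-ss/final_model"
)

inputs = torch.randn(16, 1, 28, 28)  # a batch shaped like the training data
out = model.predict(inputs)

out.recon_x        # reconstructions after the n_lf leapfrog steps
out.raw_embedding  # encoder output (mu), before the Riemannian leapfrog refinement
out.embedding      # refined latent code (falls back to the encoder output when n_lf == 0)
```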