Skip to content

Commit

Permalink
Small non-API adjustments to code. New AltMaxVar and PDD_GCCA
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed May 17, 2022
1 parent e5cfcea commit 04a0754
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 34 deletions.
6 changes: 3 additions & 3 deletions cca_zoo/deepmodels/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,15 @@ def forward(self, views, mle=True, **kwargs):
def _decode(self, z):
raise NotImplementedError

def recon(self, *args, mle=True):
z = self.forward(*args, mle=mle)
def recon(self, views, **kwargs):
z = self.forward(views, **kwargs)
return self._decode(z)


def detach_all(z):
if isinstance(z, dict):
for k, v in z.items():
v.detach().cpu().numpy()
detach_all(v)
elif isinstance(z, list):
z = [z_.detach().cpu().numpy() for z_ in z]
else:
Expand Down
7 changes: 7 additions & 0 deletions cca_zoo/deepmodels/_dvcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,10 @@ def on_validation_epoch_end(self) -> None:
self.logger.experiment.add_image(
"generated_images_2", grid2, self.current_epoch
)

def recon_uncertainty(self, views, **kwargs):
    """Decode the latent log-variances to visualise reconstruction uncertainty.

    Runs the usual forward pass, then swaps the mean latents for their
    log-variance counterparts before decoding, so the decoder output
    reflects the spread of the latent posterior rather than its mean.

    :param views: the input views, forwarded unchanged to ``self.forward``
    :param kwargs: extra keyword arguments passed through to ``self.forward``
    :return: decoder output produced from the log-variance latents

    NOTE(review): this decodes the raw log-variance, not exp(0.5*logvar)
    (a standard deviation) — confirm that is the intended visualisation.
    """
    latents = self.forward(views, **kwargs)
    # Replace the shared mean with the shared log-variance before decoding.
    latents['shared'] = latents["logvar_shared"]
    # Only models with private encoders expose a private log-variance.
    if self.private_encoders is not None:
        latents['private'] = latents["logvar_private"]
    return self._decode(latents)
2 changes: 0 additions & 2 deletions cca_zoo/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from sklearn.model_selection._search import (
BaseSearchCV as SKBaseSearchCV,
ParameterGrid,
_check_param_grid,
)
from sklearn.model_selection._validation import _fit_and_score, _insert_error_scores
from sklearn.pipeline import Pipeline
Expand Down Expand Up @@ -642,7 +641,6 @@ def __init__(
return_train_score=return_train_score,
)
self.param_grid = param2grid(param_grid)
_check_param_grid(param_grid)

def _run_search(self, evaluate_candidates):
"""Search all candidates in param_grid"""
Expand Down
48 changes: 22 additions & 26 deletions examples/plot_dvcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,37 +8,33 @@
import pytorch_lightning as pl

from cca_zoo.deepmodels import _architectures, DCCAE, DVCCA, SplitAE
from cca_zoo.utils import tsne_label
from examples.utils import example_mnist_data


# %%


def plot_reconstruction(model, dataloader):
for i, batch in enumerate(dataloader):
x, y = batch["views"]
z = model(x[0], y[0], mle=True)
recons = model._decode(z)
fig, ax = plt.subplots(ncols=4)
def plot_reconstruction(model, x, uncertainty=False):
recons = model.recon(x['views'], mle=True)
n_cols=2+uncertainty
fig, ax = plt.subplots(ncols=n_cols)
ax[0].set_title("Original View 1")
ax[1].set_title("Original View 2")
ax[2].set_title("Reconstruction View 1")
ax[3].set_title("Reconstruction View 2")
ax[0].imshow(x[0].detach().numpy().reshape((28, 28)))
ax[1].imshow(y[0].detach().numpy().reshape((28, 28)))
ax[2].imshow(recons[0].detach().numpy().reshape((28, 28)))
ax[3].imshow(recons[1].detach().numpy().reshape((28, 28)))

ax[1].set_title("Mean View 1")
ax[0].imshow(x['views'][0].detach().numpy().reshape((28, 28)))
ax[1].imshow(recons[0].detach().numpy().reshape((28, 28)))
if uncertainty:
ax[2].set_title("Std View 1")
uncertainty_recons = model.recon_uncertainty(x['views'])
ax[2].imshow(uncertainty_recons[0].detach().numpy().reshape((28, 28)))

LATENT_DIMS = 2
EPOCHS = 10
EPOCHS = 1
N_TRAIN = 500
N_VAL = 100
lr = 0.0001
dropout = 0.1
layer_sizes = (1024, 1024, 1024)

train_loader, val_loader, train_labels = example_mnist_data(N_TRAIN, N_VAL)
train_loader, val_loader, train_labels = example_mnist_data(N_TRAIN, N_VAL, type='noisy')

# %%
# Deep Variational CCA
Expand Down Expand Up @@ -68,8 +64,8 @@ def plot_reconstruction(model, dataloader):
flush_logs_every_n_steps=1,
)
trainer.fit(dvcca, train_loader, val_loader)
dvcca.plot_latent_label(train_loader)
plot_reconstruction(dvcca, train_loader)
tsne_label(dvcca.transform(train_loader)['shared'], train_labels)
plot_reconstruction(dvcca, train_loader.dataset[0], uncertainty=True)
plt.suptitle("DVCCA")
plt.show()

Expand Down Expand Up @@ -115,7 +111,8 @@ def plot_reconstruction(model, dataloader):
flush_logs_every_n_steps=1,
)
trainer.fit(dvccap, train_loader, val_loader)
plot_reconstruction(dvccap, train_loader)
tsne_label(dvccap.transform(train_loader)['shared'], train_labels)
plot_reconstruction(dvccap, train_loader.dataset[0], uncertainty=True)
plt.suptitle("DVCCA Private")
plt.show()

Expand Down Expand Up @@ -148,9 +145,8 @@ def plot_reconstruction(model, dataloader):
flush_logs_every_n_steps=1,
)
trainer.fit(dccae, train_loader, val_loader)
dccae.plot_latent_label(train_loader)
plt.suptitle("DCCAE")
plot_reconstruction(dccae, train_loader)
tsne_label(dccae.transform(train_loader)[0], train_labels)
plot_reconstruction(dccae, train_loader.dataset[0])
plt.suptitle("DCCAE")
plt.show()

Expand Down Expand Up @@ -179,7 +175,7 @@ def plot_reconstruction(model, dataloader):
flush_logs_every_n_steps=1,
)
trainer.fit(splitae, train_loader, val_loader)
plt.suptitle("SplitAE")
plot_reconstruction(splitae, train_loader)
tsne_label(splitae.transform(train_loader), train_labels)
plot_reconstruction(splitae, train_loader.dataset[0])
plt.suptitle("SplitAE")
plt.show()
9 changes: 6 additions & 3 deletions examples/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from multiviewdata.torchdatasets import SplitMNIST
from multiviewdata.torchdatasets import SplitMNIST, NoisyMNIST
from torch.utils.data import Subset
import numpy as np

from cca_zoo.deepmodels import get_dataloaders


def example_mnist_data(n_train, n_val, batch_size=50, val_batch_size=10):
train_dataset = SplitMNIST(root="", mnist_type="MNIST", train=True, download=True)
def example_mnist_data(n_train, n_val, batch_size=50, val_batch_size=10, type='split'):
if type=='split':
train_dataset = SplitMNIST(root="", mnist_type="MNIST", train=True, download=True)
else:
train_dataset = NoisyMNIST(root="", mnist_type="MNIST", train=True, download=True)
val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val))
train_dataset = Subset(train_dataset, np.arange(n_train))
train_loader, val_loader = get_dataloaders(
Expand Down

0 comments on commit 04a0754

Please sign in to comment.