
Commit

Merge branch 'dev' into origin/main
nmichlo committed Jun 9, 2022
2 parents d3aa2d3 + 1f56bed commit c8756e9
Showing 5 changed files with 67 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test.yml
@@ -16,7 +16,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest] # [ubuntu-latest, windows-latest, macos-latest]
- python-version: ["3.8", "3.9"]
+ python-version: ["3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v2
12 changes: 6 additions & 6 deletions disent/frameworks/vae/_weaklysupervised__adavae.py
@@ -218,7 +218,7 @@ def compute_average_gvae_std(d0_posterior: Normal, d1_posterior: Normal) -> Norm
assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}'
# averages
ave_std = 0.5 * (d0_posterior.stddev + d1_posterior.stddev)
- ave_mean = 0.5 * (d1_posterior.mean + d1_posterior.mean)
+ ave_mean = 0.5 * (d0_posterior.mean + d1_posterior.mean)
# done!
return Normal(loc=ave_mean, scale=ave_std)

@@ -235,7 +235,7 @@ def compute_average_gvae(d0_posterior: Normal, d1_posterior: Normal) -> Normal:
assert isinstance(d1_posterior, Normal), f'posterior distributions must be {Normal.__name__} distributions, got: {type(d1_posterior)}'
# averages
ave_var = 0.5 * (d0_posterior.variance + d1_posterior.variance)
- ave_mean = 0.5 * (d1_posterior.mean + d1_posterior.mean)
+ ave_mean = 0.5 * (d0_posterior.mean + d1_posterior.mean)
# done!
return Normal(loc=ave_mean, scale=torch.sqrt(ave_var))

@@ -323,10 +323,10 @@ def hook_intercept_ds(self, ds_posterior: Sequence[Distribution], ds_prior: Sequ
ave_std = (0.5 * d0_posterior.variance + 0.5 * d1_posterior.variance) ** 0.5

# [4.b] select shared or original values based on mask
- z0_mean = torch.where(share_mask, d0_posterior.loc, ave_mean)
- z1_mean = torch.where(share_mask, d1_posterior.loc, ave_mean)
- z0_std = torch.where(share_mask, d0_posterior.scale, ave_std)
- z1_std = torch.where(share_mask, d1_posterior.scale, ave_std)
+ z0_mean = torch.where(share_mask, ave_mean, d0_posterior.loc)
+ z1_mean = torch.where(share_mask, ave_mean, d1_posterior.loc)
+ z0_std = torch.where(share_mask, ave_std, d0_posterior.scale)
+ z1_std = torch.where(share_mask, ave_std, d1_posterior.scale)

# construct distributions
ave_d0_posterior = Normal(loc=z0_mean, scale=z0_std)
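Note on the two fixes in this file: the old compute_average_gvae_std / compute_average_gvae lines averaged d1_posterior.mean with itself, so the "average" collapsed to just d1_posterior.mean, whereas the corrected lines average the means of both posteriors. In hook_intercept_ds the torch.where arguments are swapped because torch.where(condition, a, b) selects a where the condition is True, so the averaged statistics must be passed first for the dimensions that share_mask marks as shared. A minimal sketch of that selection step, using made-up toy tensors rather than the framework's actual latents:

import torch

# hypothetical per-dimension posterior means for the two observations
d0_loc = torch.tensor([0.0, 2.0, 4.0])
d1_loc = torch.tensor([1.0, 3.0, 5.0])
ave_mean = 0.5 * (d0_loc + d1_loc)               # element-wise average of BOTH posteriors
share_mask = torch.tensor([True, False, True])   # dims estimated to be shared

# torch.where(cond, a, b) picks a where cond is True, otherwise b:
# shared dims receive the average, unshared dims keep their own mean
z0_mean = torch.where(share_mask, ave_mean, d0_loc)
print(z0_mean)   # tensor([0.5000, 2.0000, 4.5000])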
14 changes: 7 additions & 7 deletions disent/util/seeds.py
@@ -24,8 +24,6 @@

import contextlib
import logging
- import random
- import numpy as np


log = logging.getLogger(__name__)
@@ -44,8 +42,10 @@ def seed(long=777):
log.warning(f'[SEEDING]: no seed was specified. Seeding skipped!')
return
# seed python
+ import random
random.seed(long)
# seed numpy
+ import numpy as np
np.random.seed(long)
# seed torch - it can be slow to import
try:
@@ -60,34 +60,34 @@ def seed(long=777):


class TempNumpySeed(contextlib.ContextDecorator):
- def __init__(self, seed=None, offset=0):
+ def __init__(self, seed: int = None):
# check and normalize seed
if seed is not None:
try:
seed = int(seed)
except:
- raise ValueError(f'{seed=} is not int-like!')
- # offset seed
- if seed is not None:
- seed += offset
+ raise ValueError(f'seed={seed} is not int-like!')
# save values
self._seed = seed
self._state = None

def __enter__(self):
if self._seed is not None:
+ import numpy as np
self._state = np.random.get_state()
np.random.seed(self._seed)

def __exit__(self, *args, **kwargs):
if self._seed is not None:
+ import numpy as np
np.random.set_state(self._state)
self._state = None

def _recreate_cm(self):
# TODO: do we need to override this?
return self


# ========================================================================= #
# END #
# ========================================================================= #
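A quick usage sketch for the updated TempNumpySeed (the offset argument is gone; this example is illustrative and not part of the diff). On entry it saves numpy's global RNG state and seeds it, and on exit it restores the saved state; because it subclasses contextlib.ContextDecorator it can also wrap a function:

import numpy as np

from disent.util.seeds import TempNumpySeed

# as a context manager: numpy's global RNG is only reseeded inside the block
with TempNumpySeed(777):
    a = np.random.rand(3)
with TempNumpySeed(777):
    b = np.random.rand(3)
assert np.allclose(a, b)   # same temporary seed, same draws

# as a decorator: each call runs under the temporary seed, then state is restored
@TempNumpySeed(123)
def sample():
    return np.random.rand(2)

assert np.allclose(sample(), sample())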
3 changes: 2 additions & 1 deletion setup.py
@@ -48,7 +48,7 @@
author="Nathan Juraj Michlo",
author_email="NathanJMichlo@gmail.com",

- version="0.5.0",
+ version="0.5.1",
python_requires=">=3.8", # we make use of standard library features only in 3.8
packages=setuptools.find_packages(),

@@ -64,6 +64,7 @@
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Intended Audience :: Science/Research",
],
)
51 changes: 51 additions & 0 deletions tests/test_frameworks.py
@@ -28,13 +28,15 @@

import pytest
import pytorch_lightning as pl
+ import torch
from torch.utils.data import DataLoader

from disent.dataset import DisentDataset
from disent.dataset.data import XYObjectData
from disent.dataset.sampling import GroundTruthSingleSampler
from disent.dataset.sampling import GroundTruthPairSampler
from disent.dataset.sampling import GroundTruthTripleSampler
+ from disent.dataset.sampling import RandomSampler
from disent.frameworks.ae import *
from disent.frameworks.vae import *
from disent.model import AutoEncoder
@@ -46,6 +48,8 @@
# ========================================================================= #
# TEST FRAMEWORKS #
# ========================================================================= #
+ from disent.util.seeds import seed
+ from disent.util.seeds import TempNumpySeed
from docs.examples.extend_experiment.code.weaklysupervised__si_adavae import SwappedInputAdaVae
from docs.examples.extend_experiment.code.weaklysupervised__si_betavae import SwappedInputBetaVae

@@ -166,6 +170,53 @@ def test_framework_config_defaults():
)


def test_ada_vae_similarity():

seed(42)

data = XYObjectData()
dataset = DisentDataset(data, sampler=RandomSampler(num_samples=2), transform=ToImgTensorF32())
dataloader = DataLoader(dataset, num_workers=0, batch_size=3)

model = AutoEncoder(
encoder=EncoderLinear(x_shape=data.x_shape, z_size=25, z_multiplier=2),
decoder=DecoderLinear(x_shape=data.x_shape, z_size=25, z_multiplier=1),
)

adavae0 = AdaGVaeMinimal(model=model, cfg=AdaGVaeMinimal.cfg())
adavae1 = AdaVae(model=model, cfg=AdaVae.cfg())
adavae2 = AdaVae(model=model, cfg=AdaVae.cfg(
ada_average_mode='gvae',
ada_thresh_mode='symmetric_kl',
ada_thresh_ratio=0.5,
))

batch = next(iter(dataloader))

# TODO: add a TempNumpySeed equivalent for torch
seed(777)
result0a = adavae0.do_training_step(batch, 0)
seed(777)
result0b = adavae0.do_training_step(batch, 0)
assert torch.allclose(result0a, result0b), f'{result0a} does not match {result0b}'

seed(777)
result1a = adavae1.do_training_step(batch, 0)
seed(777)
result1b = adavae1.do_training_step(batch, 0)
assert torch.allclose(result1a, result1b), f'{result1a} does not match {result1b}'

seed(777)
result2a = adavae2.do_training_step(batch, 0)
seed(777)
result2b = adavae2.do_training_step(batch, 0)
assert torch.allclose(result2a, result2b), f'{result2a} does not match {result2b}'

# check similar
assert torch.allclose(result0a, result1a), f'{result0a} does not match {result1a}'
assert torch.allclose(result1a, result2a), f'{result1a} does not match {result2a}'


# ========================================================================= #
# END #
# ========================================================================= #
