Skip to content

Commit

Permalink
Small changes to #100
Browse files Browse the repository at this point in the history
  • Loading branch information
clementchadebec committed Aug 19, 2023
1 parent 384adb9 commit c767af7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
5 changes: 2 additions & 3 deletions src/pythae/pipelines/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(
if training_config is None:
if model.model_name == "RAE_L2":
training_config = CoupledOptimizerTrainerConfig(
encoder_optim_decay=0,
decoder_optim_decay=model.model_config.reg_weight,
encoder_optimizer_params={"weight_decay": 0},
decoder_optimizer_params={"weight_decay": model.model_config.reg_weight},
)

elif (
Expand Down Expand Up @@ -216,7 +216,6 @@ def __call__(
model=self.model,
train_dataset=train_dataloader or train_dataset,
eval_dataset=eval_dataloader or eval_dataset,
eval_dataloader=eval_dataloader,
training_config=self.training_config,
callbacks=callbacks,
)
Expand Down
8 changes: 8 additions & 0 deletions src/pythae/trainers/base_trainer/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,20 @@ def __init__(
# Define the loaders
if isinstance(train_dataset, DataLoader):
train_loader = train_dataset
logger.warn(
"Using the provided train dataloader! Carefull this may overwrite some "
"parameters provided in your training config."
)
else:
train_loader = self.get_train_dataloader(train_dataset)

if eval_dataset is not None:
if isinstance(eval_dataset, DataLoader):
eval_loader = eval_dataset
logger.warn(
"Using the provided eval dataloader! Carefull this may overwrite some "
"parameters provided in your training config."
)
else:
eval_loader = self.get_eval_dataloader(eval_dataset)
else:
Expand Down
26 changes: 21 additions & 5 deletions tests/test_pipeline_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pythae.customexception import DatasetError
from pythae.data.datasets import DatasetOutput
from pythae.models import VAE, FactorVAE, FactorVAEConfig, VAEConfig
from pythae.models import VAE, VAEConfig, Adversarial_AE, Adversarial_AE_Config, RAE_L2, RAE_L2_Config, VAEGAN, VAEGANConfig
from pythae.pipelines import *
from pythae.samplers import NormalSampler, NormalSamplerConfig
from pythae.trainers import BaseTrainerConfig
Expand Down Expand Up @@ -69,26 +69,39 @@ def custom_no_len_train_dataset(self):
return CustomWrongOutputDataset(
os.path.join(PATH, "data/mnist_clean_train_dataset_sample")
)

@pytest.fixture(
params=[
(VAE, VAEConfig),
(Adversarial_AE, Adversarial_AE_Config),
(RAE_L2, RAE_L2_Config),
(VAEGAN, VAEGANConfig)
]
)
def model(self, request, train_dataset):
model = request.param[0](request.param[1](input_dim=tuple(train_dataset.data[0].shape), latent_dim=2))

return model

@pytest.fixture
def train_dataloader(self, custom_train_dataset):
return DataLoader(dataset=custom_train_dataset, batch_size=32)

@pytest.fixture
def training_pipeline(self, train_dataset):
def training_pipeline(self, model, train_dataset):
vae_config = VAEConfig(
input_dim=tuple(train_dataset.data[0].shape), latent_dim=2
)
vae = VAE(vae_config)
pipe = TrainingPipeline(model=vae)
pipe = TrainingPipeline(model=model)
return pipe

def test_base_pipeline(self):
with pytest.raises(NotImplementedError):
pipe = Pipeline()
pipe()

def test_training_pipeline(self, tmpdir, training_pipeline, train_dataset):
def test_training_pipeline(self, tmpdir, training_pipeline, train_dataset, model):

with pytest.raises(AssertionError):
pipeline = TrainingPipeline(
Expand All @@ -100,7 +113,10 @@ def test_training_pipeline(self, tmpdir, training_pipeline, train_dataset):
training_pipeline.training_config.output_dir = dir_path
training_pipeline.training_config.num_epochs = 1
training_pipeline(train_dataset.data)
assert isinstance(training_pipeline.model, VAE)
assert isinstance(training_pipeline.model, model.__class__)

if model.__class__ == RAE_L2:
assert training_pipeline.trainer.decoder_optimizer.state_dict()['param_groups'][0]['weight_decay'] == model.model_config.reg_weight

def test_training_pipeline_wrong_output_dataset(
self,
Expand Down

0 comments on commit c767af7

Please sign in to comment.