Pythae 0.1.1

@clementchadebec released this 23 Feb 16:06

New features

  • Added the `TrainHistoryCallback` training callback, which records the training and evaluation losses over the course of training, in #71 by @VolodyaCO
```python
>>> from pythae.trainers.training_callbacks import TrainHistoryCallback
>>> # `pipeline`, `train_dataset` and `eval_dataset` are assumed to be defined beforehand
>>> train_history = TrainHistoryCallback()
>>> callbacks = [train_history]
>>> pipeline(
...     train_data=train_dataset,
...     eval_data=eval_dataset,
...     callbacks=callbacks
... )
>>> train_history.history
{
    'train_loss': [58.51896972363562, 42.15931177749049, 40.583426756017346],
    'eval_loss': [43.39408182034827, 41.45351771943888, 39.77221281209569]
}
```
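The following is a minimal, framework-free sketch of what a history-recording callback does. A trainer calls a hook once per epoch, and the callback appends the losses to a dictionary shaped like `train_history.history` above. The `HistoryRecorder` class, its `on_epoch_end` hook and the `best_epoch` helper are hypothetical illustrations, not Pythae's actual callback interface. The loss values are replayed from the example output.

```python
from typing import Dict, List


class HistoryRecorder:
    """Hypothetical stand-in for a history callback: collects per-epoch losses."""

    def __init__(self) -> None:
        self.history: Dict[str, List[float]] = {"train_loss": [], "eval_loss": []}

    def on_epoch_end(self, train_loss: float, eval_loss: float) -> None:
        # Hypothetical hook: a trainer would call this once at the end of each epoch.
        self.history["train_loss"].append(train_loss)
        self.history["eval_loss"].append(eval_loss)


def best_epoch(history: Dict[str, List[float]]) -> int:
    """Return the 1-based epoch with the lowest evaluation loss."""
    eval_losses = history["eval_loss"]
    return min(range(len(eval_losses)), key=eval_losses.__getitem__) + 1


# Simulated training loop replaying the losses from the example above.
recorder = HistoryRecorder()
for tr, ev in zip(
    [58.51896972363562, 42.15931177749049, 40.583426756017346],
    [43.39408182034827, 41.45351771943888, 39.77221281209569],
):
    recorder.on_epoch_end(tr, ev)

print(best_epoch(recorder.history))  # → 3
```

Once the lists are filled, the history is plain Python data. You can post-process it directly, for example to pick the epoch with the lowest evaluation loss, as `best_epoch` does.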
  • Added a `predict` method that encodes and then decodes the input data without computing the loss, in #75 by @soumickmj and @ravih18
```python
>>> out = model.predict(eval_dataset[:3])
>>> out.embedding.shape, out.recon_x.shape
(torch.Size([3, 16]), torch.Size([3, 1, 28, 28]))
```
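A toy, dependency-free model can illustrate the difference between a training pass and `predict`. In this sketch, `forward` reconstructs the input and computes a loss, while `predict` only encodes and decodes. It returns the embedding and the reconstruction, mirroring the `embedding` and `recon_x` fields shown above. `ToyAutoencoder` and `PredictOutput` are hypothetical stand-ins. Real Pythae models operate on PyTorch tensors and return their own output objects.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class PredictOutput:
    # Hypothetical container mirroring the `embedding` / `recon_x` fields.
    embedding: List[float]
    recon_x: List[float]


class ToyAutoencoder:
    """Hypothetical toy model: encode averages consecutive pairs,
    decode repeats each latent value twice. Inputs must have even length."""

    def encode(self, x: List[float]) -> List[float]:
        if len(x) % 2 != 0:
            raise ValueError("input length must be even")
        return [(x[i] + x[i + 1]) / 2 for i in range(0, len(x), 2)]

    def decode(self, z: List[float]) -> List[float]:
        return [v for v in z for _ in range(2)]

    def forward(self, x: List[float]) -> Dict[str, object]:
        # Training-style pass: reconstruct and compute a loss (mean squared error).
        z = self.encode(x)
        recon = self.decode(z)
        loss = sum((a - b) ** 2 for a, b in zip(x, recon)) / len(x)
        return {"embedding": z, "recon_x": recon, "loss": loss}

    def predict(self, x: List[float]) -> PredictOutput:
        # Inference-only pass: encode then decode, no loss computation.
        z = self.encode(x)
        return PredictOutput(embedding=z, recon_x=self.decode(z))


out = ToyAutoencoder().predict([1.0, 3.0, 2.0, 4.0])
print(out.embedding, out.recon_x)  # → [2.0, 3.0] [2.0, 2.0, 3.0, 3.0]
```

Skipping the loss makes `predict` the lighter call when you only need reconstructions or embeddings at inference time.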
  • Added an `embed` method that returns the latent representations (embeddings) of the input data, in #76 by @tbouchik
```python
>>> out = model.embed(eval_dataset[:3].to(device))
>>> out.shape
torch.Size([3, 16])
```
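`embed` returns one latent vector per input row, so a large dataset can be embedded slice by slice and the results concatenated. This keeps memory use bounded. The sketch below shows that pattern in plain Python. `embed_in_batches` and `toy_embed` are hypothetical helpers, and `toy_embed` stands in for a model's `embed` method.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def embed_in_batches(
    embed_fn: Callable[[Sequence[Vector]], List[Vector]],
    data: Sequence[Vector],
    batch_size: int,
) -> List[Vector]:
    """Apply embed_fn to consecutive slices of data and concatenate the results."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    out: List[Vector] = []
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        out.extend(embed_fn(batch))  # one latent vector per input row
    return out


def toy_embed(batch: Sequence[Vector]) -> List[Vector]:
    # Hypothetical encoder: maps each input to a 2-d latent (mean, max).
    return [[sum(x) / len(x), max(x)] for x in batch]


data = [[1.0, 2.0, 3.0], [4.0, 4.0, 4.0], [0.0, 10.0, 2.0], [5.0, 1.0, 0.0], [2.0, 2.0, 8.0]]
z = embed_in_batches(toy_embed, data, batch_size=2)
print(len(z), len(z[0]))  # → 5 2
```

With PyTorch tensors, you would typically run the loop under `torch.no_grad()` and join the per-batch outputs with `torch.cat(..., dim=0)` instead of `list.extend`.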