Probaforms is a Python library of conditional Generative Adversarial Networks, Normalizing Flows, Variational Autoencoders, and other generative models for tabular data. All models share a scikit-learn-like interface, which makes them quick to apply in a variety of science and engineering applications.
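The interface pattern can be illustrated with a toy model. The sketch below is hypothetical and not part of probaforms: `LinearGaussianSampler` models p(x | c) as a Gaussian whose mean is linear in the condition c, with a constant per-feature spread, and it exposes the same `fit(X, C)` / `sample(C)` calls used in the usage example of this README. Returning `self` from `fit` follows the scikit-learn convention; whether probaforms models do the same is an assumption to check in the documentation.

```python
import numpy as np


class LinearGaussianSampler:
    """Hypothetical toy conditional model (not part of probaforms).

    Models p(x | c) as a Gaussian with a mean linear in c and a
    constant per-feature standard deviation.
    """

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit(self, X, C):
        X = np.asarray(X, dtype=float)
        C = np.asarray(C, dtype=float)
        # Design matrix [C, 1] so the conditional mean has a bias term
        A = np.hstack([C, np.ones((len(C), 1))])
        # Least-squares estimate of the conditional mean coefficients
        self.coef_, *_ = np.linalg.lstsq(A, X, rcond=None)
        # Residual standard deviation per output feature
        self.scale_ = (X - A @ self.coef_).std(axis=0)
        return self  # scikit-learn convention: fit returns the estimator

    def sample(self, C):
        C = np.asarray(C, dtype=float)
        rng = np.random.default_rng(self.random_state)
        A = np.hstack([C, np.ones((len(C), 1))])
        mean = A @ self.coef_
        # One new object per row of C: mean plus Gaussian noise
        return mean + rng.standard_normal(mean.shape) * self.scale_


# Usage on synthetic data with a continuous condition
rng = np.random.default_rng(0)
C = rng.uniform(-1, 1, size=(1000, 1))
X = np.hstack([2 * C, -C]) + 0.1 * rng.standard_normal((1000, 2))
X_gen = LinearGaussianSampler(random_state=0).fit(X, C).sample(C)
print(X_gen.shape)  # (1000, 2)
```

A real generative model replaces the linear-Gaussian assumption with a learned, far more flexible conditional density, but the calling pattern stays the same.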
Model | Type | Paper |
---|---|---|
ConditionalNormal | MDN | Bishop CM. Mixture density networks. 1994. |
CVAE | VAE | Kingma DP, Welling M. Auto-encoding variational Bayes. arXiv:1312.6114. ICLR 2014. |
ConditionalWGAN | GAN | Arjovsky M, Chintala S, Bottou L. Wasserstein generative adversarial networks. arXiv:1701.07875. ICML 2017. |
RealNVP | Normalizing Flow | Dinh L, Sohl-Dickstein J, Bengio S. Density estimation using Real NVP. arXiv:1605.08803. ICLR 2017. |
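RealNVP, listed in the table, is built from affine coupling layers (Dinh et al.). The NumPy sketch below shows a single coupling layer: half of the features pass through unchanged and parameterize a scale and shift of the other half, so the layer is exactly invertible and its log-determinant is just a sum. The functions `s` and `t` here are hypothetical fixed random maps standing in for trained neural networks; a conditional flow would also feed the condition C into them, which this sketch omits.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 4       # data dimension (assumed even for a half/half split)
d = D // 2  # number of pass-through features

# Hypothetical stand-ins for the scale and translation networks
W_s = rng.normal(size=(d, D - d))
W_t = rng.normal(size=(d, D - d))


def s(x1):
    return np.tanh(x1 @ W_s)  # bounded log-scale


def t(x1):
    return x1 @ W_t  # translation


def coupling_forward(x):
    """y1 = x1, y2 = x2 * exp(s(x1)) + t(x1); also returns log|det J|."""
    x1, x2 = x[:, :d], x[:, d:]
    log_scale = s(x1)
    y2 = x2 * np.exp(log_scale) + t(x1)
    # The Jacobian is triangular, so log|det J| is the sum of log-scales
    log_det = log_scale.sum(axis=1)
    return np.hstack([x1, y2]), log_det


def coupling_inverse(y):
    """x1 = y1, x2 = (y2 - t(y1)) * exp(-s(y1))."""
    y1, y2 = y[:, :d], y[:, d:]
    x2 = (y2 - t(y1)) * np.exp(-s(y1))
    return np.hstack([y1, x2])


x = rng.normal(size=(5, D))
y, log_det = coupling_forward(x)
x_rec = coupling_inverse(y)
print(np.allclose(x, x_rec))  # True: the layer is exactly invertible
```

A full flow stacks several such layers and swaps which half passes through, so that every feature is eventually transformed.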
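Install the package from PyPI:

```bash
pip install probaforms
```

or install it from source in editable mode:

```bash
git clone https://github.com/hse-cs/probaforms
cd probaforms
pip install -e .
```

or, inside the cloned repository, install it with Poetry:

```bash
poetry install
```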
(See more examples in the documentation.)
The following code snippet generates noisy synthetic data, fits a conditional generative model, samples new objects, and displays the results.
```python
from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from probaforms.models import RealNVP

# generate sample X with conditions C (the moon label of each point)
X, y = make_moons(n_samples=1000, noise=0.1)
C = y.reshape(-1, 1)

# fit a normalizing flow model
model = RealNVP(lr=0.01, n_epochs=100)
model.fit(X, C)

# sample new objects, one per row of C
X_gen = model.sample(C)

# display the results, colored by condition
plt.scatter(X_gen[y==0, 0], X_gen[y==0, 1])
plt.scatter(X_gen[y==1, 0], X_gen[y==1, 1])
plt.show()
```
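Beyond inspecting the scatter plot, a quick sanity check is to compare per-condition summary statistics of the real and generated samples. The helper below is a hypothetical NumPy function, not part of probaforms; it only assumes that `X`, `X_gen`, and `y` are arrays shaped as in the snippet above. Matching means and standard deviations is necessary but not sufficient for the distributions to match, so treat it as a smoke test rather than a quality metric.

```python
import numpy as np


def per_condition_moments(X, X_gen, y):
    """Hypothetical helper: for each discrete condition value in y, return
    the largest absolute gap in feature means and in feature standard
    deviations between real samples X and generated samples X_gen."""
    report = {}
    for value in np.unique(y):
        mask = y == value
        real, fake = X[mask], X_gen[mask]
        report[value] = (
            np.abs(real.mean(axis=0) - fake.mean(axis=0)).max(),
            np.abs(real.std(axis=0) - fake.std(axis=0)).max(),
        )
    return report


# Usage with the arrays from the example above:
# for value, (mean_gap, std_gap) in per_condition_moments(X, X_gen, y).items():
#     print(f"y={value}: mean gap {mean_gap:.3f}, std gap {std_gap:.3f}")
```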
- Home: https://github.com/hse-cs/probaforms
- Documentation: https://hse-cs.github.io/probaforms
- For usage questions, suggestions, and bug reports, please use the GitHub issue page.