TensorFlow implementation of a Conditional GAN trained on the MNIST dataset

matusstas/cGAN

Conditional GAN

IF YOU FIND THIS REPOSITORY HELPFUL, PLEASE CONSIDER STARRING IT.

TensorFlow implementation of a Conditional GAN (cGAN) that generates realistic images of handwritten digits for a requested digit class. The model was trained on the MNIST training set, which contains 60,000 grayscale 28×28 images across 10 digit classes (roughly 6,000 per class). Training ran for 2,000 epochs and took approximately 3 hours on an NVIDIA A100 40GB GPU. A demo is available on Hugging Face, and the training run is logged on Weights & Biases.
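A conditional GAN feeds the class label to the generator alongside its random noise, which is what lets you ask for a specific digit. The sketch below shows one common way to build a single generator input: a Gaussian noise vector concatenated with a one-hot encoding of the desired digit. The latent size of 100, the one-hot concatenation, and the helper name `make_generator_input` are assumptions for illustration; the repository's `build_generator` may use an embedding layer or a different noise dimension instead.

```python
import random

NUM_CLASSES = 10   # digits 0-9 in MNIST
LATENT_DIM = 100   # assumed noise size; not taken from the repository


def one_hot(label: int, num_classes: int = NUM_CLASSES) -> list:
    """Return a one-hot vector with 1.0 at position `label`."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label must be in [0, {num_classes}), got {label}")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def make_generator_input(label: int, rng: random.Random,
                         latent_dim: int = LATENT_DIM) -> list:
    """Hypothetical helper: Gaussian noise followed by the one-hot class label."""
    noise = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
    return noise + one_hot(label)


rng = random.Random(42)
x = make_generator_input(7, rng)
print(len(x))    # 110
print(x[-10:])   # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]
```

Keeping the noise fixed and changing only the label is a common way to check that conditioning works: the digit identity should change while the overall stroke style stays similar.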

2500 generated handwritten digits
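The 2,500 samples in the figure can be tiled as a 50 × 50 grid. The sketch below works through that layout under two assumptions the caption does not confirm: the grid is square and class-balanced (labels cycle through 0-9, giving 250 of each digit), and each generated image is 28 × 28 pixels like MNIST, so the full canvas is 50 × 28 = 1400 pixels per side. `grid_layout` is a hypothetical helper, not part of the repository.

```python
import math
from collections import Counter

NUM_CLASSES = 10
IMG_SIZE = 28  # MNIST images are 28 x 28 pixels


def grid_layout(n_samples: int, num_classes: int = NUM_CLASSES):
    """Return (side, labels, canvas_px) for a square grid of n_samples images.

    Labels cycle through 0..num_classes-1, so every class gets the same
    count when n_samples is a multiple of num_classes.
    """
    side = math.isqrt(n_samples)
    if side * side != n_samples:
        raise ValueError("n_samples must be a perfect square")
    labels = [i % num_classes for i in range(n_samples)]
    canvas_px = side * IMG_SIZE
    return side, labels, canvas_px


side, labels, canvas_px = grid_layout(2500)
print(side, canvas_px)      # 50 1400
print(Counter(labels)[3])   # 250
```

Because 50 is a multiple of 10, cycling labels over the flattened index puts the same digit in every cell of a column, so each column shows a single class.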

Load the pretrained generator in HDF5 format

from tensorflow.keras.models import load_model
generator = load_model("cgan.h5")
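`load_model` returns the generator, but the README does not document its output range. Assuming the generator ends in a `tanh` activation (typical for DCGAN-style generators, not confirmed here), pixel values lie in [-1, 1] and need rescaling to 8-bit grey levels before saving or plotting. `to_pixel` and `to_uint8_image` are hypothetical helpers; if the generator uses a sigmoid instead, multiply by 255 without the +1 shift.

```python
def to_pixel(value: float) -> int:
    """Map an assumed tanh output in [-1, 1] to an 8-bit grey level in [0, 255]."""
    value = max(-1.0, min(1.0, value))  # guard against slight overshoot
    return int(round((value + 1.0) * 127.5))


def to_uint8_image(image: list) -> list:
    """Convert a 2D generator output (rows of floats) to 0-255 integers."""
    return [[to_pixel(v) for v in row] for row in image]


print(to_uint8_image([[-1.0, 0.0, 1.0]]))  # [[0, 128, 255]]
```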

Load model weights

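from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

# build_generator, build_discriminator and GAN come from this repository's code

# Initialize optimizers (the discriminator's learning rate is 10x smaller)
opt_g = Adam(learning_rate=0.0001, beta_1=0.5)
opt_d = Adam(learning_rate=0.00001, beta_1=0.5)

# Initialize losses
loss_g = BinaryCrossentropy()
loss_d = BinaryCrossentropy()

# Rebuild the same architecture so the saved weights can be mapped onto it
generator = build_generator()
discriminator = build_discriminator()
gan = GAN(generator, discriminator)
gan.compile(opt_g, opt_d, loss_g, loss_d)

# Restore weights from the TensorFlow checkpoint
gan.load_weights("./checkpoints/my_checkpoint")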
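The snippet above passes two `BinaryCrossentropy` losses to the repository's custom `GAN.compile`, but the README does not show the training step that uses them. As a hedged illustration, the pure-Python sketch below computes the standard GAN objectives such a step usually implements: the discriminator is scored against label 1 for real images and 0 for generated ones, and the generator uses the non-saturating loss (label 1 on generated images). The labelling convention, summing the two discriminator terms, and the function names are assumptions, not code from the repository. The clipping constant is similar to the clipping Keras applies to probabilities when `from_logits=False`. In a cGAN, `d_real` and `d_fake` would be the discriminator's probabilities for (image, label) pairs; the loss arithmetic itself is unchanged.

```python
import math

EPS = 1e-7  # clip probabilities away from 0 and 1 to keep log() finite


def bce(y_true: list, y_pred: list) -> float:
    """Mean binary cross-entropy over a batch of predicted probabilities."""
    if not y_true or len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and equal length")
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, EPS), 1.0 - EPS)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(y_true)


def discriminator_loss(d_real: list, d_fake: list) -> float:
    """Assumed convention: real images are labelled 1, generated images 0."""
    return bce([1.0] * len(d_real), d_real) + bce([0.0] * len(d_fake), d_fake)


def generator_loss(d_fake: list) -> float:
    """Non-saturating loss: the generator wants D to call its samples real."""
    return bce([1.0] * len(d_fake), d_fake)


# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
print(round(discriminator_loss([0.5, 0.5], [0.5, 0.5]), 4))  # 1.3863
print(round(generator_loss([0.5, 0.5]), 4))                   # 0.6931
```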

Weights & Biases

Training history

cGAN evolution