Text-to-Image_GAN

About

This project explores the use of Conditional Generative Adversarial Networks (CGANs) for continual learning in image generation from textual descriptions. The model is first trained on the Oxford 102 Flowers dataset to generate flower images from text; it then retains the learned parameters from the flowers task and applies this knowledge to generate bird images from the Caltech-UCSD Birds dataset. We use four different text encoders and two different deep convolutional CGAN models. The trained models demonstrate similar performance across the different embedding sizes produced by the text encoders and across GAN architectures.

Dataset links: Flower Dataset.

Figure: example of a generated image and the corresponding input text.

Requirements

Install the dependencies listed in requirements.txt:

  • torch~=2.2.0.dev20230915
  • numpy~=1.24.3
  • tqdm~=4.65.0
  • h5py~=3.7.0
  • Pillow~=9.4.0
  • PyYAML~=6.0
  • matplotlib~=3.7.1
  • sentence_transformers

Implementation details

The main implementation is based on the Generative Adversarial Text-to-Image Synthesis paper [1], with code adapted from this repository [2], which is itself a reimplementation of the original paper with some modifications.

We used four different text encoders in this project:

  • The first comes from the pre-computed text embeddings generated and used by the authors of [1].
  • The other three are fine-tuned versions of DistilRoBERTa, MPNet, and MiniLM (see the sketch after this list).
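
A minimal sketch, assuming the sentence_transformers package from requirements.txt, of how captions can be embedded with these models; the model names follow the embeddings options listed under Usage, the fine-tuning step used in this project is not shown, and the example captions are illustrative only.

```python
# Minimal sketch: embedding captions with sentence_transformers. The model
# names match the `embeddings` options in trainer_GAN.py; the fine-tuning
# step used in this project is not shown, and the captions are illustrative.
from sentence_transformers import SentenceTransformer

captions = [
    "this flower has thin white petals and a bright yellow center",
    "a small bird with a red breast and a short pointed beak",
]

encoder = SentenceTransformer("all-MiniLM-L12-v2")  # or all-mpnet-base-v2 / all-distilroberta-v1
embeddings = encoder.encode(captions)               # array of shape (num_captions, embed_dim)
print(embeddings.shape)
```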

We use three DCGAN architectures as our models (vanilla GAN, CGAN, and WGAN); they can all be found in the models folder.

Datasets

We used both the Caltech-UCSD Birds 200 and Oxford 102 Flowers datasets as part of our training for continual learning.

We used the script Dataset.py to convert them to HDF5 format; a quick way to inspect the resulting files is sketched at the end of this section.

The original text embeddings developed by the authors of [1] were also used as one of our embedding methods; they can be found here: text embeddings.
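
A minimal sketch, assuming a file already produced by Dataset.py at the default path used in trainer_GAN.py, of how to inspect the converted HDF5 file with h5py; the internal layout depends on how the conversion script was run.

```python
# Minimal sketch: inspecting a file produced by Dataset.py with h5py.
# The path below matches the default in trainer_GAN.py; the internal
# group/dataset layout is whatever the conversion script produced.
import h5py

with h5py.File("data/flowers.hdf5", "r") as f:
    f.visit(print)  # print every group/dataset path stored in the file
```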

Usage

Training

trainer_GAN.py is the main file to be run, where all global paths for data, saved models and figures should be defined.

You also need to set the following parameters before running the line disc_loss, genr_loss = model.train(train_loader, dataset) (a configuration sketch follows the argument list). Arguments:

  • batch_size: Batch size for the dataloader. Default = 512
  • lr: The learning rate. Default = 0.0002
  • epochs: Number of training epochs. Default = 1
  • num_channels: Number of input channels. Default = 3
  • G_type: GAN architecture to use for the Generator (vanilla_gan | cgan | wgan). Default = vanilla_gan
  • D_type: GAN architecture to use for the Discriminator (vanilla_gan | cgan | wgan). Default = vanilla_gan
  • d_beta1: Optimizer beta_1 for the Discriminator. Default = 0.5
  • d_beta2: Optimizer beta_2 for the Discriminator. Default = 0.999
  • g_beta1: Optimizer beta_1 for the Generator. Default = 0.5
  • g_beta2: Optimizer beta_2 for the Generator. Default = 0.999
  • save_path: Path for saving the models. Default = ckpt
  • l1_coef: L1 loss coefficient in the generator loss function for cgan and wgan. Default = 50
  • l2_coef: Feature matching coefficient in the generator loss function for cgan and wgan. Default = 100
  • idx: Index into embeddings selecting the embedding type. Default = 3
  • embeddings: Types of embeddings: ['default', 'all-mpnet-base-v2', 'all-distilroberta-v1', 'all-MiniLM-L12-v2']
  • names: Names of the embedding types: ['default', 'MPNET', 'DistilROBERTa', 'miniLM-L12']
  • dataset: Dataset to use. Default = T2IGANDataset(dataset_file="data/flowers.hdf5", split="train", emb_type=embeddings[idx])
  • train_loader: DataLoader for the training set. Default = DataLoader(dataset, batch_size=batch_size, shuffle=True)
  • embed_size: Size of the embeddings. Default = dataset.embed_dim (required when using CGAN or WGAN)
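
A minimal sketch of how this configuration might be assembled in trainer_GAN.py, using the defaults listed above; the import path for T2IGANDataset and the construction of model are assumptions and may differ from the actual code.

```python
# Minimal sketch of the configuration block in trainer_GAN.py, using the
# defaults listed above. The import path for T2IGANDataset and the
# construction of `model` are assumptions and may differ from the repository.
from torch.utils.data import DataLoader

from Dataset import T2IGANDataset  # assumed module name for the dataset/conversion script

# Training hyperparameters
batch_size = 512
lr = 0.0002
epochs = 1
num_channels = 3
G_type = "vanilla_gan"           # vanilla_gan | cgan | wgan
D_type = "vanilla_gan"           # vanilla_gan | cgan | wgan
d_beta1, d_beta2 = 0.5, 0.999    # optimizer betas for the Discriminator
g_beta1, g_beta2 = 0.5, 0.999    # optimizer betas for the Generator
save_path = "ckpt"
l1_coef, l2_coef = 50, 100       # generator loss coefficients (cgan / wgan)

# Embedding selection
embeddings = ["default", "all-mpnet-base-v2", "all-distilroberta-v1", "all-MiniLM-L12-v2"]
names = ["default", "MPNET", "DistilROBERTa", "miniLM-L12"]
idx = 3

# Data
dataset = T2IGANDataset(dataset_file="data/flowers.hdf5", split="train", emb_type=embeddings[idx])
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
embed_size = dataset.embed_dim   # required for CGAN / WGAN

# model = ...  # build the GAN trainer from the parameters above (project-specific)
# disc_loss, genr_loss = model.train(train_loader, dataset)
```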

Plotting the generated images

Run plot_gan_losses(disc_loss, genr_loss) in trainer_GAN.py.
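
A minimal sketch of a loss-plotting helper in the spirit of plot_gan_losses (not the repository's exact implementation), showing how the returned disc_loss and genr_loss lists could be visualized with matplotlib:

```python
# Minimal sketch of a loss-plotting helper in the spirit of plot_gan_losses;
# not the repository's exact implementation.
import matplotlib.pyplot as plt

def plot_losses(disc_loss, genr_loss):
    """Plot discriminator and generator losses over training iterations."""
    plt.figure(figsize=(8, 4))
    plt.plot(disc_loss, label="Discriminator loss")
    plt.plot(genr_loss, label="Generator loss")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
    plt.legend()
    plt.tight_layout()
    plt.show()
```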

References

[1] Generative Adversarial Text-to-Image Synthesis. https://arxiv.org/abs/1605.05396

[2] Text-to-Image-Synthesis. https://github.com/aelnouby/Text-to-Image-Synthesis/tree/master