This repository presents an example implementation of a Deep Convolutional Generative Adversarial Network (DCGAN). The architecture used for building the model differs slightly from the one used in the official paper, yet its final form is clearly inspired by the publication of Alec Radford et al. The data used for model training comes from the CIFAR10 dataset available through Torchvision.
The concept of GANs was originally published in the Generative Adversarial Nets paper written by Ian J. Goodfellow et al.
A generative adversarial network is a framework consisting of a generative model G, which captures the data distribution, and a discriminative model D, which estimates the probability that a sample came from the training data rather than from G.
Both parts of the framework learn simultaneously during the training process, in which G tries to produce samples that fool D while D learns to tell real samples from generated ones.
This minimax game can be summarized by the value function
$$\min\limits_{G} \max\limits_{D} V(D,G) = \mathbb{E}_{\mathbf{x} \sim p_{data}(\mathbf{x})}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})}[\log(1 - D(G(\mathbf{z})))]$$
where:
- $D(\mathbf{x})$ - probability assigned by the discriminator that $\mathbf{x}$ comes from the training data,
- $G(\mathbf{z})$ - sample generated from a random latent vector $\mathbf{z}$,
- $p_{data}$ - distribution of the training data,
- $p_{\mathbf{z}}$ - prior distribution of the latent vector.
The authors of Generative Adversarial Nets mention that the above equation should not be used directly - in the early phase of learning, when G is still poor, D can reject generated samples with high confidence, so the term $\log(1 - D(G(\mathbf{z})))$ saturates and provides weak gradients. Instead, they suggest training G to maximize $\log D(G(\mathbf{z}))$.
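As a quick numerical sketch of why the trick matters (the probabilities below are invented for illustration, not real model outputs):

```python
import torch

# Illustrative Monte Carlo estimate of V(D, G) and of the generator's
# objective on a batch; the values stand in for D(x) on real samples
# and D(G(z)) on generated ones.
d_real = torch.tensor([0.90, 0.80, 0.95])   # D(x), real samples
d_fake = torch.tensor([0.01, 0.02, 0.05])   # D(G(z)), weak generator early on

# The two expectations of the value function V(D, G)
value = torch.log(d_real).mean() + torch.log(1.0 - d_fake).mean()

# Original (saturating) generator objective: minimize log(1 - D(G(z))).
# When D confidently rejects the fakes, this term is nearly flat.
saturating = torch.log(1.0 - d_fake).mean()

# Non-saturating alternative: minimize -log D(G(z)) instead,
# which stays large (and steep) while the generator is weak.
non_saturating = -torch.log(d_fake).mean()
```

With a confident discriminator, the saturating term is close to zero while the non-saturating loss remains large, which is what keeps gradients flowing to G early in training.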
The idea of building more complex GAN-based models came with Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks written by Alec Radford et al. The authors proposed a new architecture which takes advantage of transposed convolution layers in the generator and standard convolutional layers in the discriminator.
The key point of the publication is to mirror the generator's layers in the discriminator's structure, but in reversed order. For example, the parameters (padding, stride and kernel size) of the first transposed convolution layer in the generator should match the equivalent parameters of the last convolutional layer in the discriminator. In this way the training process is stabilized to some extent.
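A minimal sketch of this mirroring in PyTorch; the layer sizes (latent length 100, 512 feature maps) are hypothetical and chosen only to illustrate the matching parameters:

```python
import torch
import torch.nn as nn

# The first transposed convolution of the generator and the last
# convolution of the discriminator share kernel_size, stride, padding.
kernel_size, stride, padding = 4, 1, 0

# Generator's first layer: latent vector (100, 1, 1) -> feature map (512, 4, 4)
g_first = nn.ConvTranspose2d(100, 512, kernel_size, stride, padding, bias=False)

# Discriminator's last layer: feature map (512, 4, 4) -> single score (1, 1, 1)
d_last = nn.Conv2d(512, 1, kernel_size, stride, padding, bias=False)

z = torch.randn(8, 100, 1, 1)   # batch of 8 latent vectors
feat = g_first(z)               # -> (8, 512, 4, 4)
score = d_last(feat)            # -> (8, 1, 1, 1)
```

Because the two layers use the same parameters, the discriminator's last layer exactly inverts the spatial shape change made by the generator's first layer.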
The GIFs below present generated images (dog/frog/ship/cat/horse/truck) using a constant reference random vector across all epochs:
!python /path/to/main.py --n_epochs 300 \
--batch_size 32 \
--ref_images_dir 'path/to/directory/for/images' \
--download_datasets 'true' \
--root_datasets_dir 'CIFAR10_dataset' \
--class_name 'cat' \
--latent_vector_length 100 \
--init_generator_weights 'true' \
--init_discriminator_weights 'true'
Args used in command:
- n_epochs - number of epochs
- batch_size - number of images in a batch
- ref_images_dir - path to the directory where reference images will be saved
- download_datasets - download the dataset from the Torchvision repo or use an already existing dataset
- root_datasets_dir - path where the dataset should be downloaded or where it is already stored
- class_name - name of one of the 10 classes available in CIFAR10
- latent_vector_length - length of the random vector which will be transformed into an image by the generator
- init_generator_weights - initialize the generator's weights using a normal distribution
- init_discriminator_weights - initialize the discriminator's weights using a normal distribution
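The two init flags refer to drawing the initial weights from a normal distribution. A sketch of how such an init is typically written in PyTorch, following the DCGAN paper's choice of mean 0.0 and std 0.02 for convolutional layers (the function name and the small network below are illustrative, not this repository's exact code):

```python
import torch
import torch.nn as nn

def init_weights(module: nn.Module) -> None:
    # Conv / ConvTranspose weights ~ N(0.0, 0.02), as in the DCGAN paper;
    # BatchNorm scales ~ N(1.0, 0.02) with zero bias.
    classname = module.__class__.__name__
    if "Conv" in classname:
        nn.init.normal_(module.weight.data, mean=0.0, std=0.02)
    elif "BatchNorm" in classname:
        nn.init.normal_(module.weight.data, mean=1.0, std=0.02)
        nn.init.zeros_(module.bias.data)

# Hypothetical generator stub to demonstrate the init
generator = nn.Sequential(
    nn.ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(inplace=True),
)
generator.apply(init_weights)  # recursively applies init_weights to every submodule
```

`Module.apply` visits every submodule, so a single call covers the whole network regardless of its depth.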