
DCGAN_CIFAR

This repository presents an example implementation of a Deep Convolutional Generative Adversarial Network (DCGAN). The architecture used to build the model differs slightly from the one in the official paper, yet its final form is clearly inspired by the publication by Alec Radford et al. The data used for model training comes from the CIFAR10 dataset available through Torchvision.

GAN

The GAN concept was first published in the Generative Adversarial Nets paper by Ian J. Goodfellow et al.

A generative adversarial network is a framework consisting of a generative model $G(\mathbf{z}, \theta_G)$, where $\theta_G$ are the parameters of $G$, which creates data from random noise $\mathbf{z}$, and a discriminative network $D(\mathbf{x}, \theta_D)$ which returns the probability that observation $\mathbf{x}$ comes from the training set.

Both parts of the framework learn simultaneously during training: $D$ tries to guess whether observation $\mathbf{x}$ was taken from the training set or generated by $G$, so $D$'s goal is to maximize the probability of labeling samples correctly. $G$, on the other hand, tries to fool the discriminator, which amounts to minimizing $\log(1 - D(G(\mathbf{z})))$.

This minimax game can be summarized by the value function $V(D,G)$:

$$\min\limits_{G} \max\limits_{D} V(D,G) = \mathbb{E}_{\mathbf{x} \sim p_{data}(\mathbf{x})}[\log D(\mathbf{x})] + \mathbb{E}_{\mathbf{z} \sim p_{\mathbf{z}}(\mathbf{z})}[\log(1 - D(G(\mathbf{z})))]$$

where:

$p_{data}(\mathbf{x})$ - distribution of original dataset;

$p_{\mathbf{z}}(\mathbf{z})$ - distribution of the input noise variables.
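The two sides of this game can be written directly as losses over a batch of discriminator outputs. Below is a minimal sketch in PyTorch; the helper names (`discriminator_loss`, `generator_loss_minimax`) are illustrative and are not part of this repository's code:

```python
import torch

def discriminator_loss(d_real, d_fake):
    """Ascends V(D, G) for D: maximize log D(x) + log(1 - D(G(z))).

    `d_real` = D(x) on real samples, `d_fake` = D(G(z)) on generated
    samples; both are probabilities in (0, 1). The objective is negated
    so a standard optimizer can minimize it.
    """
    return -(torch.log(d_real) + torch.log(1.0 - d_fake)).mean()

def generator_loss_minimax(d_fake):
    """Descends V(D, G) for G: minimize log(1 - D(G(z)))."""
    return torch.log(1.0 - d_fake).mean()
```

In practice these expectations are estimated batch-wise, and the same effect is usually obtained via `torch.nn.BCELoss` with real/fake labels.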

The authors of Generative Adversarial Nets note that the equation above should not be used directly: in the early phase of training, when $G$ generates observations obviously different from the training set, $D$ can classify them faultlessly as products of $G$. That leads to saturation of $\log(1 - D(G(\mathbf{z})))$, which indicates that a better way to train $G$ is to maximize $\log D(G(\mathbf{z}))$.
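This non-saturating alternative can be sketched as follows (the function name is illustrative, not from this repository):

```python
import torch

def generator_loss_nonsaturating(d_fake):
    # Maximize log D(G(z)) (minimize its negation) instead of minimizing
    # log(1 - D(G(z))): early in training D(G(z)) is close to 0, where
    # log(1 - D(G(z))) is nearly flat but -log D(G(z)) still produces
    # strong gradients.
    return -torch.log(d_fake).mean()
```

For example, at $D(G(\mathbf{z})) = 0.01$ the minimax gradient $\frac{d}{dx}\log(1-x) = -\frac{1}{1-x} \approx -1.01$, while the non-saturating gradient $\frac{d}{dx}(-\log x) = -\frac{1}{x} = -100$, so $G$ keeps learning even when $D$ rejects its samples confidently.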

DCGAN

The idea of building more complex GAN-based models came with Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks by Alec Radford et al. The authors proposed a new architecture that takes advantage of transposed convolution layers in the generator and standard convolutional layers in the discriminator.

The crux of the publication is to mirror in the discriminator's structure the layers used in the generator, but in reverse. For example, the parameters (padding, stride and kernel size) of the first transposed convolution layer in the generator should match the equivalent parameters of the last convolutional layer in the discriminator. This stabilizes the training process to some extent.
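A minimal sketch of one such mirrored pair, assuming a latent vector of length 100 and 512 feature maps (illustrative values, not necessarily those used in this repository):

```python
import torch
import torch.nn as nn

# The first transposed convolution in the generator and the last
# convolution in the discriminator share kernel_size=4, stride=1,
# padding=0: one maps a 1x1 latent "pixel" to a 4x4 feature map,
# the other maps a 4x4 feature map back to a 1x1 output.
gen_first = nn.ConvTranspose2d(100, 512, kernel_size=4, stride=1, padding=0)
disc_last = nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0)

z = torch.randn(1, 100, 1, 1)   # latent vector as a 1x1 spatial map
feat = gen_first(z)             # shape (1, 512, 4, 4)
out = disc_last(feat)           # shape (1, 1, 1, 1)
```

Because the two layers share their spatial parameters, the discriminator's receptive-field geometry exactly inverts the generator's upsampling path.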

Example training processes

The GIFs below present images (dog/frog/ship/cat/horse/truck) generated from a constant reference random vector across all epochs:

dogs.gif frogs.gif ships.gif cats.gif horses.gif trucks.gif

How to work with project

Training a model

!python /path/to/main.py --n_epochs 300 \
                         --batch_size 32 \
                         --ref_images_dir 'path/to/directory/for/images'\
                         --download_datasets 'true'\
                         --root_datasets_dir 'CIFAR10_dataset' \
                         --class_name 'cat' \
                         --latent_vector_length 100 \
                         --init_generator_weights 'true' \
                         --init_discriminator_weights 'true'

Args used in command:

  • n_epochs - number of epochs
  • batch_size - number of images in a batch
  • ref_images_dir - path to the directory where reference images will be saved
  • download_datasets - download the dataset from the Torchvision repo or use an already existing one
  • root_datasets_dir - path where the dataset should be downloaded or where it is already stored
  • class_name - name of one of the 10 classes available in CIFAR10
  • latent_vector_length - length of the random vector which the generator transforms into an image
  • init_generator_weights - initialize the generator's weights using a normal distribution
  • init_discriminator_weights - initialize the discriminator's weights using a normal distribution
