Add "Progressive Growing of GANs" (ProGAN) model #1105

Open · wants to merge 19 commits into master
10 changes: 10 additions & 0 deletions models/base_model.py
@@ -227,3 +227,13 @@ def set_requires_grad(self, nets, requires_grad=False):
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad

def make_data_parallel(self):
"""Make models data parallel"""
if len(self.gpu_ids) == 0:
return
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net = torch.nn.DataParallel(net, self.gpu_ids) # multi-GPUs
setattr(self, 'net' + name, net)
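For context, a minimal sketch (not part of the diff) of the ordering this helper enables: the CycleGAN model below initializes Apex AMP on the bare modules first and only then calls make_data_parallel, because amp.initialize needs to patch the plain modules and optimizers before any DataParallel wrapper is applied. The netG/netD/optimizer names here are placeholders.

# --- Sketch (illustration only, not part of this PR) ---
# amp.initialize must see the bare modules; DataParallel wrapping comes afterwards.
import torch
from apex import amp  # assumes NVIDIA Apex is installed

def wrap_for_training(netG, netD, optimizer_G, optimizer_D, gpu_ids, opt_level='O1'):
    [netG, netD], [optimizer_G, optimizer_D] = amp.initialize(
        [netG, netD], [optimizer_G, optimizer_D], opt_level=opt_level, num_losses=2)
    if len(gpu_ids) > 0:
        netG = torch.nn.DataParallel(netG, gpu_ids)
        netD = torch.nn.DataParallel(netD, gpu_ids)
    return netG, netD, optimizer_G, optimizer_D
# --- end sketch ---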
43 changes: 36 additions & 7 deletions models/cycle_gan_model.py
@@ -3,7 +3,12 @@
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
from torch.utils.checkpoint import checkpoint

try:
from apex import amp
except ImportError:
print("Please install NVIDIA Apex for safe mixed precision if you want to use non default --opt_level")

class CycleGANModel(BaseModel):
"""
@@ -96,6 +101,13 @@ def __init__(self, opt):
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)

if opt.apex:
[self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D] = amp.initialize(
[self.netG_A, self.netG_B, self.netD_A, self.netD_B], [self.optimizer_G, self.optimizer_D], opt_level=opt.opt_level, num_losses=3)

# networks must be wrapped in DataParallel only after amp.initialize
self.make_data_parallel()

def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.

@@ -112,11 +124,17 @@ def set_input(self, input):
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
if not self.isTrain or not self.opt.checkpointing:
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
else:
self.rec_A = checkpoint(self.netG_B, self.fake_B)
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
if not self.isTrain or not self.opt.checkpointing:
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
else:
self.rec_B = checkpoint(self.netG_A, self.fake_A)
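Aside on the new --checkpointing path (illustration only): torch.utils.checkpoint.checkpoint runs the wrapped module without storing its intermediate activations and recomputes them during the backward pass, trading extra compute for lower memory. A standalone toy example of the call pattern used above:

# --- Sketch (illustration only, not part of this PR) ---
import torch
from torch.utils.checkpoint import checkpoint

net = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

y = checkpoint(net, x)   # activations are recomputed during backward instead of being stored
y.sum().backward()
# --- end sketch ---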

def backward_D_basic(self, netD, real, fake):
def backward_D_basic(self, netD, real, fake, loss_id):
"""Calculate GAN loss for the discriminator

Parameters:
@@ -135,18 +153,23 @@ def backward_D_basic(self, netD, real, fake):
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
if self.opt.apex:
with amp.scale_loss(loss_D, self.optimizer_D, loss_id=loss_id) as loss_D_scaled:
loss_D_scaled.backward()
else:
loss_D.backward()

return loss_D

def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B, loss_id=0)

def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A, loss_id=1)

def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
@@ -175,7 +198,13 @@ def backward_G(self):
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()

if self.opt.apex:
with amp.scale_loss(self.loss_G, self.optimizer_G, loss_id=2) as loss_G_scaled:
loss_G_scaled.backward()
else:
self.loss_G.backward()


def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
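For reference, a minimal standalone sketch (not part of the diff) of the Apex AMP pattern used above: each backward pass goes through amp.scale_loss so gradients are scaled and later unscaled consistently, and num_losses/loss_id give each of the three losses (D_A, D_B, G) its own loss scale.

# --- Sketch (illustration only, not part of this PR) ---
import torch
from apex import amp  # assumes NVIDIA Apex and a CUDA device are available

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', num_losses=1)

x = torch.randn(4, 10).cuda()
loss = model(x).pow(2).mean()
with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
    scaled_loss.backward()   # backward on the scaled loss; AMP unscales before the step
optimizer.step()
optimizer.zero_grad()
# --- end sketch ---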
218 changes: 210 additions & 8 deletions models/networks.py
@@ -2,8 +2,10 @@
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler

from torch.nn.functional import interpolate
from torch.optim import lr_scheduler
import numpy as np

###############################################################################
# Helper Functions
@@ -98,7 +100,7 @@ def init_func(m): # define the initialization function
net.apply(init_func) # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], init_weights_=True):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
@@ -111,12 +113,13 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
if init_weights_:
init_weights(net, init_type, init_gain=init_gain)
return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[],
init_weights=True, **kwargs):
"""Create a generator

Parameters:
@@ -154,12 +157,16 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'unet_256':
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'progan':
net = GeneratorProGanV2(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps']+1, out_channels=output_nc)
# net = GeneratorProGan(input_code_dim=input_nc, in_channel=ngf, max_steps=kwargs['max_steps'], out_channels=output_nc)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights)
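A hypothetical call (values are illustrative and not taken from this PR's training code) showing how the new 'progan' branch and the forwarded **kwargs fit together:

# --- Sketch (hypothetical usage, not part of this PR) ---
# For netG='progan', input_nc is reused as the latent code size, ngf as the base
# channel count, and max_steps must be passed via **kwargs (the generator is built
# with max_steps + 1 internally).
netG = define_G(input_nc=512, output_nc=3, ngf=512, netG='progan',
                init_type='normal', init_gain=0.02, gpu_ids=[],
                init_weights=False,   # presumably skipped because the ProGAN blocks use equalized lr
                max_steps=6)
# --- end sketch ---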


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[],
init_weights=True, **kwargs):
"""Create a discriminator

Parameters:
@@ -198,9 +205,12 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
elif netD == 'pixel': # classify if each pixel is real or fake
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
elif netD == 'progan':
net = DiscriminatorProGanV2(feat_dim=ndf, in_dim=input_nc, max_steps=kwargs['max_steps']+1)
# net = DiscriminatorProGan(feat_dim=ndf, in_dim=input_nc, max_steps=kwargs['max_steps'])
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
return init_net(net, init_type, init_gain, gpu_ids)
return init_net(net, init_type, init_gain, gpu_ids, init_weights_=init_weights)


##############################################################################
@@ -613,3 +623,195 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
def forward(self, input):
"""Standard forward."""
return self.net(input)


# ========================================================================================
# Generator Module of ProGAN
# can be used with ProGAN or standalone (for inference)
# Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/
# ========================================================================================


class GeneratorProGanV2(nn.Module):
""" Generator of the GAN network """

def __init__(self, max_steps=7, input_code_dim=512, in_channel=512, out_channels=3, use_eql=True):
"""
constructor for the Generator class
:param max_steps: required depth of the Network
:param input_code_dim: size of the latent manifold
:param use_eql: whether to use equalized learning rate
"""
from .progan_layers import GenGeneralConvBlock, GenInitialBlock, _equalized_conv2d

super(GeneratorProGanV2, self).__init__()

assert input_code_dim != 0 and ((input_code_dim & (input_code_dim - 1)) == 0), \
"latent size not a power of 2"
if max_steps >= 4:
assert in_channel >= np.power(2, max_steps - 4), "in_channel size will diminish to zero"

# state of the generator:
self.use_eql = use_eql
self.depth = max_steps
self.latent_size = input_code_dim
self.channels_conv = in_channel

# register the modules required for the GAN
self.initial_block = GenInitialBlock(in_channels=self.latent_size, out_channels=in_channel, use_eql=self.use_eql)

# create a module list of the other required general convolution blocks
self.layers = nn.ModuleList([]) # initialize to empty list

# create the ToRGB layers for various outputs:
if self.use_eql:
self.toRGB = lambda in_channels: \
_equalized_conv2d(in_channels, out_channels, (1, 1), bias=True)
else:
from torch.nn import Conv2d
self.toRGB = lambda in_channels: Conv2d(in_channels, out_channels, (1, 1), bias=True)

self.rgb_converters = nn.ModuleList([self.toRGB(self.channels_conv)])

# create the remaining layers
for i in range(self.depth - 1):
if i <= 2:
layer = GenGeneralConvBlock(self.channels_conv,
self.channels_conv, use_eql=self.use_eql)
rgb = self.toRGB(self.channels_conv)
else:
layer = GenGeneralConvBlock(
int(self.channels_conv // np.power(2, i - 3)),
int(self.channels_conv // np.power(2, i - 2)),
use_eql=self.use_eql
)
rgb = self.toRGB(int(self.channels_conv // np.power(2, i - 2)))
self.layers.append(layer)
self.rgb_converters.append(rgb)

# register the temporary upsampler
self.temporaryUpsampler = lambda x: interpolate(x, scale_factor=2)

def forward(self, x, step, alpha):
"""
forward pass of the Generator
:param x: input noise
:param step: current depth from where output is required
:param alpha: value of alpha for fade-in effect
:return: y => output
"""
# step = step - 1
assert step < self.depth, "Requested output depth cannot be produced"

y = self.initial_block(x)

if step > 0:
for block in self.layers[:step - 1]:
y = block(y)

residual = self.rgb_converters[step - 1](self.temporaryUpsampler(y))
straight = self.rgb_converters[step](self.layers[step - 1](y))

out = (alpha * straight) + ((1 - alpha) * residual)

else:
out = self.rgb_converters[0](y)

return out
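To make the step/alpha arguments concrete (illustration only): with the usual ProGAN layout the initial block produces a 4x4 map and every further block doubles the resolution, so step s yields a 4 * 2**s output; alpha blends the new block's toRGB output with the upsampled toRGB output of the previous stage during fade-in. Assuming the class definition above:

# --- Sketch (illustration only, not part of this PR) ---
g = GeneratorProGanV2(max_steps=5, input_code_dim=256, in_channel=256, out_channels=3)
print(len(g.layers))          # 4 upsampling blocks
print(len(g.rgb_converters))  # 5 toRGB heads, one per available step
# step=0 uses only the initial block; step=4 with alpha=0.5 is halfway through the last fade-in
# --- end sketch ---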

# ========================================================================================
# Discriminator Module of ProGAN
# can be used with ProGAN or standalone (for inference).
# Thanks to https://raw.githubusercontent.com/akanimax/pro_gan_pytorch/
# ========================================================================================


class DiscriminatorProGanV2(nn.Module):
""" Discriminator of the GAN """

def __init__(self, max_steps=7, feat_dim=512, in_dim=3, use_eql=True):
"""
constructor for the class
:param max_steps: total height of the discriminator (Must be equal to the Generator depth)
:param feat_dim: size of the deepest features extracted
(Must be equal to Generator latent_size)
:param use_eql: whether to use equalized learning rate
"""
from torch.nn import ModuleList, AvgPool2d
from .progan_layers import DisGeneralConvBlock, DisFinalBlock, _equalized_conv2d

super(DiscriminatorProGanV2, self).__init__()

assert feat_dim != 0 and ((feat_dim & (feat_dim - 1)) == 0), \
"latent size not a power of 2"
if max_steps >= 4:
assert feat_dim >= np.power(2, max_steps - 4), "feature size cannot be produced"

# create state of the object
self.use_eql = use_eql
self.height = max_steps
self.feature_size = feat_dim

self.final_block = DisFinalBlock(self.feature_size, use_eql=self.use_eql)

# create a module list of the other required general convolution blocks
self.layers = ModuleList([]) # initialize to empty list

# create the fromRGB layers for various inputs:
if self.use_eql:
self.fromRGB = lambda out_channels: \
_equalized_conv2d(in_dim, out_channels, (1, 1), bias=True)
else:
from torch.nn import Conv2d
self.fromRGB = lambda out_channels: Conv2d(in_dim, out_channels, (1, 1), bias=True)

self.rgb_to_features = ModuleList([self.fromRGB(self.feature_size)])

# create the remaining layers
for i in range(self.height - 1):
if i > 2:
layer = DisGeneralConvBlock(
int(self.feature_size // np.power(2, i - 2)),
int(self.feature_size // np.power(2, i - 3)),
use_eql=self.use_eql
)
rgb = self.fromRGB(int(self.feature_size // np.power(2, i - 2)))
else:
layer = DisGeneralConvBlock(self.feature_size,
self.feature_size, use_eql=self.use_eql)
rgb = self.fromRGB(self.feature_size)

self.layers.append(layer)
self.rgb_to_features.append(rgb)

# register the temporary downSampler
self.temporaryDownsampler = AvgPool2d(2)

def forward(self, x, step, alpha):
"""
forward pass of the discriminator
:param x: input to the network
:param step: current height of operation (Progressive GAN)
:param alpha: current value of alpha for fade-in
:return: out => raw prediction values (WGAN-GP)
"""
# step = step - 1
assert step < self.height, "Requested output depth cannot be produced"

if step > 0:
residual = self.rgb_to_features[step - 1](self.temporaryDownsampler(x))

straight = self.layers[step - 1](
self.rgb_to_features[step](x)
)

y = (alpha * straight) + ((1 - alpha) * residual)

for block in reversed(self.layers[:step - 1]):
y = block(y)
else:
y = self.rgb_to_features[0](x)

out = self.final_block(y)

return out
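How step and alpha are driven during training is not shown in this file; a hypothetical schedule (an assumption, not this PR's code) that grows the resolution stage by stage and ramps alpha linearly over the first half of each stage:

# --- Sketch (hypothetical schedule, not part of this PR) ---
def progressive_schedule(iters_per_step, max_steps):
    """Yield (step, alpha) pairs: one resolution stage at a time, alpha ramping 0 -> 1."""
    for step in range(max_steps):
        for it in range(iters_per_step):
            if step == 0:
                alpha = 1.0  # no fade-in at the base resolution
            else:
                alpha = min(1.0, it / (iters_per_step / 2))
            yield step, alpha

# Typical use: fake = netG(z, step, alpha); pred = netD(fake, step, alpha)
for step, alpha in progressive_schedule(iters_per_step=10000, max_steps=6):
    pass  # one G/D update per (step, alpha)
# --- end sketch ---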