# AnimeGANInitTrain.py
from collections import OrderedDict
from glob import glob

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn
from torch.optim import Adam
import wandb

from net.backtone import VGGCaffePreTrained
from net.generator import Generator
from tools.ops import *
from tools.utils import *
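

# ------------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the init phase below
# trains the generator with a perceptual content loss between the input photo
# and the generator output. `con_loss` itself comes from tools.ops; the helper
# here only shows the typical form of such a loss, assuming the VGG wrapper
# returns a feature map when called on an image batch. The name, signature and
# the L1 distance are assumptions for illustration, not the project's exact
# implementation.
def _reference_content_loss(vgg, real, fake):
    # Compare the two images in VGG feature space rather than pixel space.
    real_feat = vgg(real)
    fake_feat = vgg(fake)
    return (fake_feat - real_feat).abs().mean()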
##################################################################################
# Model
##################################################################################
class AnimeGANInitTrain(pl.LightningModule):
    """Initialization (pretraining) stage of AnimeGAN: the generator is trained
    with a VGG content loss only, so it learns to reconstruct the input photo
    before adversarial training starts."""

    def __init__(self, img_size=None, dataset_name=None, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        if img_size is None:
            img_size = [256, 256]
        self.img_size = img_size
        # Frozen, Caffe-pretrained VGG used as the perceptual feature extractor.
        self.p_model = VGGCaffePreTrained().eval()
        """ Define Generator """
        self.generated = Generator()

    def on_fit_start(self):
        # Move the pretrained VGG to the training device.
        self.p_model.setup(self.device)

    def forward(self, img):
        return self.generated(img)

    def training_step(self, batch, batch_idx):
        real, anime, anime_gray, anime_smooth = batch
        generator_images = self.generated(real)
        # Init phase: weighted content (reconstruction) loss only.
        init_c_loss = con_loss(self.p_model, real, generator_images)
        init_loss = self.hparams.con_weight * init_c_loss
        self.log('init_loss', init_loss, on_step=True, prog_bar=True, logger=True)
        return init_loss

    def on_fit_end(self) -> None:
        # Log a few validation images to wandb at the end of training.
        val_files = glob('./dataset/{}/*.*'.format('val'))
        val_images = []
        self.generated.eval()
        for i, sample_file in enumerate(val_files):
            print('val: ' + str(i) + ' ' + sample_file)
            if i == 0 or i == 26 or i == 5:
                with torch.no_grad():
                    sample_image = np.asarray(load_test_data(sample_file))
                    # Match dtype/device of the generator's output-layer weights.
                    test_real = torch.from_numpy(sample_image).type_as(self.generated.out_layer[0].weight)
                    test_generated_predict = self.generated(test_real)
                    test_generated_predict = test_generated_predict.permute(0, 2, 3, 1).cpu().detach().numpy()
                    test_generated_predict = np.squeeze(test_generated_predict, axis=0)
                    val_images.append(
                        wandb.Image(test_generated_predict,
                                    caption="Name:{}, epoch:{}".format(i, self.current_epoch)))
        wandb.log({"val_images": val_images})

    def configure_optimizers(self):
        G_optim = Adam(self.generated.parameters(), lr=self.hparams.init_lr, betas=(0.5, 0.999))
        return G_optim
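

# ------------------------------------------------------------------------------
# Usage sketch: a minimal, hedged example of driving the module with a
# pl.Trainer on random tensors shaped like the (real, anime, anime_gray,
# anime_smooth) batches unpacked in training_step. It is not the project's
# training entry point; the `con_weight`, `init_lr` and `dataset_name` values
# below are placeholders, and a real run would use the project's own dataloader.
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    # Four random "images" per sample, matching the batch layout in training_step.
    n, c, h, w = 8, 3, 256, 256
    dummy = TensorDataset(torch.randn(n, c, h, w), torch.randn(n, c, h, w),
                          torch.randn(n, c, h, w), torch.randn(n, c, h, w))
    loader = DataLoader(dummy, batch_size=4)

    model = AnimeGANInitTrain(img_size=[256, 256], dataset_name='Hayao',
                              con_weight=1.5,   # placeholder content-loss weight
                              init_lr=2e-4)     # placeholder learning rate

    # on_fit_end() calls wandb.log, so initialise wandb (disabled here to avoid
    # requiring an account when smoke-testing).
    wandb.init(project='AnimeGAN-init', mode='disabled')

    trainer = pl.Trainer(max_epochs=1, accelerator='auto', logger=False)
    trainer.fit(model, loader)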