forked from taoxugit/AttnGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_feature_predictor.py
110 lines (90 loc) · 5.65 KB
/
train_feature_predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging
from typing import List
import torch
from matplotlib.figure import Figure
from torch import optim, Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from attnganw.randomutils import set_random_seed
from featurepred.model import FeaturePredictorModelWrapper
from featurepred.data import ResNet50DataLoaderBuilder, RESNET50_MEANS, RESNET50_STD_DEVS
from featurepred.train import FeaturePredictorTrainer, output_to_target_class
from utils.image import plot_images_with_labels
INPUT_RESIZE: int = 224
def start_training(predictor_trainer: FeaturePredictorTrainer, train_data_loader: DataLoader,
                   validation_data_loader: DataLoader,
                   num_epochs: int, optimiser_learning_rate: float):
    """Set up optimiser, loss and device, then delegate to the trainer's loop.

    In feature-extraction mode only the final fully-connected layer (``fc``)
    is optimised; otherwise every model parameter is fine-tuned. The chosen
    device (CUDA when available, CPU otherwise) is handed to the trainer.
    """
    model: Module = predictor_trainer.model_wrapper.model

    # Feature extraction keeps the backbone frozen: only the classifier
    # head's parameters go to the optimiser.
    if predictor_trainer.model_wrapper.feature_extraction:
        logging.info("Feature extraction: Only optimizing last layer's parameters")
        trainable_parameters = model.fc.parameters()
    else:
        logging.info("Fine-tuning: All model parameters are optimised")
        trainable_parameters = model.parameters()
    optimiser = optim.Adam(params=trainable_parameters, lr=optimiser_learning_rate)

    loss_function = torch.nn.CrossEntropyLoss()

    # Prefer the GPU when CUDA is present; device placement of model/batches
    # is presumably handled inside train_predictor — not visible here.
    if torch.cuda.is_available():
        logging.info("CUDA supported. Running on GPU")
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    predictor_trainer.train_predictor(epochs=num_epochs, train_loader=train_data_loader,
                                      validation_loader=validation_data_loader,
                                      optimiser=optimiser, loss_function=loss_function, device=device)
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # logging.basicConfig(level=logging.DEBUG)

    # Training hyper-parameters and data/model paths.
    epochs: int = 3
    train_image_folder: str = 'data/feature_data/train'
    validation_image_folder: str = 'data/feature_data/val'
    batch_size: int = 32
    learning_rate: float = 0.001
    input_resize: int = INPUT_RESIZE
    model_state_file: str = 'feature_predictor.pt'
    random_seed: int = 100
    data_loader_workers: int = 4

    summary_writer: SummaryWriter = SummaryWriter()
    set_random_seed(random_seed=random_seed)

    model_wrapper: FeaturePredictorModelWrapper = FeaturePredictorModelWrapper(model_state_file=model_state_file,
                                                                               feature_extraction=True)
    trainer: FeaturePredictorTrainer = FeaturePredictorTrainer(model_wrapper=model_wrapper,
                                                               summary_writer=summary_writer)
    train_dataloader_builder: ResNet50DataLoaderBuilder = ResNet50DataLoaderBuilder(image_folder=train_image_folder,
                                                                                    batch_size=batch_size,
                                                                                    input_resize=input_resize,
                                                                                    is_training=True,
                                                                                    data_loader_workers=data_loader_workers)
    train_loader: DataLoader = train_dataloader_builder.build()

    # Before training: log the model graph, an embedding projection of the
    # first batch, and a plotted sample of training images to TensorBoard.
    train_images, train_classes = next(iter(train_loader))
    logging.info("Starting TensorBoard ...")
    summary_writer.add_graph(model=model_wrapper.model, input_to_model=train_images)
    train_labels: List[str] = [train_dataloader_builder.class_names[class_index] for class_index in train_classes]
    num_channels: int = 3
    summary_writer.add_embedding(mat=train_images.view(-1, num_channels * input_resize * input_resize),
                                 label_img=train_images, metadata=train_labels)
    training_data_plot: Figure = plot_images_with_labels(images=train_images[:10], classes=train_classes[:10],
                                                         class_names=train_dataloader_builder.class_names,
                                                         means=RESNET50_MEANS, standard_devs=RESNET50_STD_DEVS,
                                                         file_name='training_sample.png')
    summary_writer.add_figure(tag='training_data_plot', figure=training_data_plot)
    # BUG FIX: the writer was close()d here, yet add_figure() is called again
    # after training below. flush() persists the pending events while keeping
    # the writer open for later use.
    summary_writer.flush()

    valid_data_loader_builder: ResNet50DataLoaderBuilder = ResNet50DataLoaderBuilder(
        image_folder=validation_image_folder,
        batch_size=batch_size,
        input_resize=input_resize,
        is_training=False,
        data_loader_workers=data_loader_workers)
    validation_loader: DataLoader = valid_data_loader_builder.build()

    start_training(predictor_trainer=trainer, train_data_loader=train_loader, validation_data_loader=validation_loader,
                   num_epochs=epochs,
                   optimiser_learning_rate=learning_rate)

    # After training: reload the saved weights on CPU and visualise the
    # model's predictions on a sample of validation images.
    model_wrapper.load_model_from_file(device="cpu")
    model_wrapper.model.eval()
    valid_images, valid_classes = next(iter(validation_loader))
    valid_images = valid_images[:10]
    # IMPROVEMENT: run inference under no_grad (no autograd bookkeeping is
    # needed for evaluation); the redundant second [:10] slice was dropped —
    # valid_images is already truncated above.
    with torch.no_grad():
        predicted_classes: Tensor = output_to_target_class(model_output=model_wrapper.model(valid_images))
    validation_data_plot: Figure = plot_images_with_labels(images=valid_images, classes=predicted_classes,
                                                           class_names=valid_data_loader_builder.class_names,
                                                           means=RESNET50_MEANS, standard_devs=RESNET50_STD_DEVS,
                                                           file_name='prediction_sample.png')
    summary_writer.add_figure(tag='validation_data_plot', figure=validation_data_plot)
    summary_writer.close()