# main.py
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Adam
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoTokenizer

from constants import *
from dataModule import SequenceDataset
from model import BertClassifier
from preprocessor import Preprocessor
from utils import seed_everything

# Fix the global random seed for reproducible runs.
seed_everything(24)

def train(data):
    writer = SummaryWriter(log_dir=f"/share/tb/{datetime.now().strftime('%b-%d-%Y-%H-%M-%S')}")

    # Load the Dutch BERT tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("GroNLP/bert-base-dutch-cased")
    preprocessor = Preprocessor()
    train_dataset = SequenceDataset(data, tokenizer, preprocessor)

    # Model configuration.
    config = {
        'hidden_size': 768,
        'num_labels': train_dataset.label_count,
        'hidden_dropout_prob': 0.05,
    }

    # Create our custom BertClassifier model object.
    model = BertClassifier(config)
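    # Attach the dataset's label/id mapping to the model so it is pickled
    # together with the weights and is available again at inference time.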
    model.id_dict = train_dataset.id_dict
    model.to(DEVICE)

    # Hold out 20% of the examples for validation.
    validation_split = 0.2
    dataset_size = len(train_dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    shuffle_dataset = True
    if shuffle_dataset:
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # SubsetRandomSampler draws only from the indices it is given, so both
    # loaders share one dataset object without overlapping examples.
    train_sampler = SubsetRandomSampler(train_indices)
    validation_sampler = SubsetRandomSampler(val_indices)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    val_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=validation_sampler)
    print('Training Set Size {}, Validation Set Size {}'.format(len(train_indices), len(val_indices)))

    # Class-weighted cross entropy to compensate for label imbalance.
    loss_fn = nn.CrossEntropyLoss(weight=Tensor(train_dataset.class_weights).to(DEVICE))

    # Discriminative learning rates: a small LR for the pretrained BERT body,
    # a larger one for the freshly initialised classifier head.
    optimizer = Adam([
        {'params': model.bert.parameters(), 'lr': 1e-5},
        {'params': model.classifier.parameters(), 'lr': 3e-4}
    ])
    model.zero_grad()

    training_acc_list, validation_acc_list = [], []
    global_step = 0
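
    # Gradients are accumulated over GRADIENT_ACCUMULATION_STEPS batches before
    # each optimizer step, for an effective batch size of
    # BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS.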
    for epoch in range(NUM_EPOCHS):
        epoch_loss = 0.0
        train_correct_total = 0

        # Training loop
        model.train()
        train_iterator = tqdm(train_loader, desc="Train Iteration")
        for step, batch in enumerate(train_iterator):
            inputs = batch[0]
            labels = batch[1].to(DEVICE)

            logits = model(**inputs)
            loss = loss_fn(logits, labels)
            writer.add_scalar("Loss/train", loss.item(), global_step)

            # Scale the loss so the accumulated gradient matches a single
            # large-batch step.
            (loss / GRADIENT_ACCUMULATION_STEPS).backward()
            epoch_loss += loss.item()

            if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()

            _, predicted = torch.max(logits.data, 1)
            correct_reviews_in_batch = (predicted == labels).sum().item()
            train_correct_total += correct_reviews_in_batch
            writer.add_scalar("Accuracy/train", train_correct_total / ((step + 1) * BATCH_SIZE),
                              global_step)
            global_step += 1

        print('Epoch {} - Loss {:.2f}'.format(epoch + 1, epoch_loss / len(train_indices)))
        # Validation loop
        with torch.no_grad():
            val_correct_total = 0
            model.eval()
            val_iterator = tqdm(val_loader, desc="Validation Iteration")
            for step, batch in enumerate(val_iterator):
                inputs = batch[0]
                labels = batch[1].to(DEVICE)

                logits = model(**inputs)
                _, predicted = torch.max(logits.data, 1)
                correct_reviews_in_batch = (predicted == labels).sum().item()
                val_correct_total += correct_reviews_in_batch

            training_acc_list.append(train_correct_total * 100 / len(train_indices))
            validation_acc_list.append(val_correct_total * 100 / len(val_indices))
            print('Training Accuracy {:.4f} - Validation Accuracy {:.4f}'.format(
                train_correct_total * 100 / len(train_indices), val_correct_total * 100 / len(val_indices)))

        # Checkpoint after every epoch, e.g. model.pth -> model_epoch-0.pth.
        torch.save(model, MODEL_FILE_PATH[:-4] + f"_epoch-{epoch}" + ".pth")

    torch.save(model, MODEL_FILE_PATH)
    writer.close()
    return MODEL_FILE_PATH
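

# A minimal, hypothetical usage sketch (not part of the original training code).
# Because train() pickles the whole module with torch.save, torch.load returns a
# ready BertClassifier. The tokenizer call below mirrors the training setup, but
# the exact keyword arguments expected by model(**inputs) depend on how
# SequenceDataset tokenizes, so treat this as an assumption, not the project's API.
if __name__ == "__main__":
    model = torch.load(MODEL_FILE_PATH, map_location=DEVICE)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained("GroNLP/bert-base-dutch-cased")
    # Tokenize one Dutch sentence; return_tensors="pt" yields torch tensors.
    inputs = tokenizer("Dit is een voorbeeldzin.", return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        logits = model(**inputs)
    predicted_id = logits.argmax(dim=-1).item()
    # id_dict presumably maps a class index back to its label name.
    print(predicted_id, model.id_dict.get(predicted_id))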