-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval_citation_bert.py
104 lines (88 loc) · 4.02 KB
/
eval_citation_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Evaluate a trained CitationBert model on the validation split.

Loads the config given by ``--cfg``, rebuilds the model, restores the
checkpoint from the stats directory, and reports loss, mAP, MRR and
Recall@K over the validation dataloader.
"""
import os
import yaml
import torch
import argparse
import logging
from utils.logger import ColoredLogger
from dataset import get_bert_dataset
from torch.utils.data import DataLoader
from models.models import CitationBert
from utils.criterion import CrossEntropyLoss
from utils.metrics import ResultRecorder
# Route all loggers created after this point through the project's
# colored logger class.
logging.setLoggerClass(ColoredLogger)
logger = logging.getLogger(__name__)
# Parse Arguments: the only CLI flag is the path to the YAML config file.
parser = argparse.ArgumentParser()
parser.add_argument('--cfg', default = os.path.join('configs', 'citation_bert.yaml'), help = 'Config File', type = str)
FLAGS = parser.parse_args()
CFG_FILE = FLAGS.cfg
with open(CFG_FILE, 'r') as cfg_file:
    # FullLoader kept for parity with training; the config is a local,
    # trusted file, so arbitrary-tag loading is acceptable here.
    cfg_dict = yaml.load(cfg_file, Loader = yaml.FullLoader)
# Hyper-parameters and paths; defaults must match the training setup so a
# bare config still reproduces the trained model's architecture.
MULTIGPU = cfg_dict.get('multigpu', False)
EMBEDDING_DIM = cfg_dict.get('embedding_dim', 768)
COSINE_SOFTMAX_S = cfg_dict.get('cosine_softmax_S', 1)
BERT_CASED = cfg_dict.get('bert_cased', False)
BATCH_SIZE = cfg_dict.get('batch_size', 4)
MAX_LENGTH = cfg_dict.get('max_length', 512)
SEQ_LEN = cfg_dict.get('seq_len', 50)
END_YEAR = cfg_dict.get('end_year', 2015)
FREQUENCY = cfg_dict.get('frequency', 5)
RECALL_K = cfg_dict.get('recall_K', [5, 10, 30, 50, 80])
STATS_DIR = cfg_dict.get('stats_dir', os.path.join('stats', 'citation_bert'))
DATA_PATH = cfg_dict.get('data_path', os.path.join('data', 'citation.csv'))
EMBEDDING_PATH = cfg_dict.get('embedding_path', os.path.join('stats', 'vgae', 'embedding.npy'))
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(STATS_DIR, exist_ok = True)
checkpoint_file = os.path.join(STATS_DIR, 'checkpoint.tar')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Load data & Build dataset
logger.info('Reading bert dataset & citation dataset ...')
# Only the validation split is used for evaluation; the training split is
# discarded.  `paper_info` describes every candidate paper, so its length
# defines the model's output (class) space.
_, val_dataset, paper_info = get_bert_dataset(DATA_PATH, seq_len = SEQ_LEN, year = END_YEAR, frequency = FREQUENCY)
paper_num = len(paper_info)
logger.info('Finish reading and dividing into training and testing sets.')
# NOTE(review): shuffle = True is unusual for an evaluation loader; the
# metrics below do not appear order-dependent, but confirm intent.
val_dataloader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = True)
# Build model from configs
model = CitationBert(num_classes = paper_num, embedding_dim = EMBEDDING_DIM, max_length = MAX_LENGTH, S = COSINE_SOFTMAX_S, cased = BERT_CASED)
model.to(device)
# Load pre-computed paper embeddings (default path points at the VGAE
# stage's output) into the model.
model.set_paper_embeddings(filename = EMBEDDING_PATH, device = device)
# Define criterion
criterion = CrossEntropyLoss()
# Read checkpoints.  A trained checkpoint is mandatory for evaluation;
# abort if it is missing.
if os.path.isfile(checkpoint_file):
    logger.info('Load checkpoint from {} ...'.format(checkpoint_file))
    # map_location moves the weights onto whatever device is available,
    # so a GPU-trained checkpoint still loads on CPU.
    checkpoint = torch.load(checkpoint_file, map_location = device)
    model.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    logger.info('Checkpoint {} (epoch {}) loaded.'.format(checkpoint_file, start_epoch))
else:
    raise AttributeError('No checkpoint file!')
# Wrap in DataParallel only AFTER loading the state dict, so the
# checkpoint's keys match the unwrapped model.
if MULTIGPU is True:
    model = torch.nn.DataParallel(model)
# Result Recorder: accumulates per-batch predictions and later yields
# mAP / MRR / Recall@K over the whole validation set.
recorder = ResultRecorder(paper_num, include_mAP = True, recall_K = RECALL_K)
def evaluate():
    """Run one full pass over the validation set and log ranking metrics.

    Uses the module-level model, dataloader, criterion and recorder; logs
    the per-batch loss and, at the end, mAP, MRR and Recall@K.
    """
    logger.info('Start evaluation process.')
    model.eval()
    recorder.clear()
    num_batches = len(val_dataloader)
    # Unpack each batch directly in the loop header.
    for batch_idx, (left_context, right_context, label, source_label) in enumerate(val_dataloader):
        # Tokenize both context sides for the two encoders, then move
        # everything onto the evaluation device.
        tokens_bert, tokens_specter = model.convert_tokens(list(left_context), list(right_context))
        tokens_bert, tokens_specter = tokens_bert.to(device), tokens_specter.to(device)
        label = torch.LongTensor(label).to(device)
        # No gradients are needed for evaluation.
        with torch.no_grad():
            res, res_softmax = model(tokens_bert, tokens_specter)
            loss = criterion(res, label)
        logger.info('Val batch {}/{}, loss: {:.6f}'.format(batch_idx + 1, num_batches, loss.item()))
        recorder.add_record(res_softmax, label, source_label)
    logger.info('Finish evaluation process. Now calculating metrics ...')
    logger.info('mAP: {:.6f}'.format(recorder.calc_mAP()))
    logger.info('MRR: {:.6f}'.format(recorder.calc_mRR()))
    recalls = recorder.calc_recall_K()
    for pos, k in enumerate(RECALL_K):
        logger.info('Recall@{}: {:.6f}'.format(k, recalls[pos]))


if __name__ == '__main__':
    evaluate()