-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
161 lines (122 loc) · 5.25 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from pathlib import Path
import pandas as pd
import torch
from transformers import BertTokenizer, BertConfig
from config import Config
from model.model import BertForClassification
from utils.functions import prepare_device, set_rs, get_token_logits, get_map_location, filter_model_state_dict, \
get_dataset_type
class Inferencer:
"""Class for inference models.
Note:
Source .csv files should be placed in `data/inference` directory.
You can choose what model to inference in `config.json`.
"""
def __init__(self):
self.config = Config(config_path="config.json")
self.directory = self.config["inference_dir"]
self.tokenizer = BertTokenizer.from_pretrained(self.config["pretrained_model_name"])
dataset_type = get_dataset_type(self.config["table_serialization_type"])
self.preprocess_tables()
self.dataset = dataset_type(
tokenizer=self.tokenizer,
num_rows=self.config["dataset"]["num_rows"],
data_dir=self.config["dataset"]["data_dir"] + "inference/preprocessed/",
file_name=None
)
self.model = BertForClassification(
BertConfig.from_pretrained(self.config["pretrained_model_name"], num_labels=self.config["num_labels"])
)
checkpoint = torch.load(
self.config["checkpoint_dir"] + self.config["inference_model_name"],
map_location=get_map_location()
)
self.model.load_state_dict(filter_model_state_dict(checkpoint["model_state_dict"]))
self.device, device_ids = prepare_device(self.config["num_gpu"])
self.model = self.model.to(self.device)
if len(device_ids) > 1:
self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
def inference(self) -> None:
"""Inference tables and save result.
Returns:
None
"""
inference_result = self._inference()
sem_types = pd.read_csv("data/sem_types.csv")
inference_result["labels"] = inference_result["labels"].apply(lambda x: x[0])
def helper(x: list):
r = []
for i in x:
label = sem_types.iloc[i]["label"]
r.append(label)
return r
inference_result["labels"] = inference_result["labels"].apply(lambda x: helper(x))
inference_result.to_csv(self.directory + "result.csv", index=False, sep="|")
def _inference(self):
"""Inference samples in dataset.
Returns:
pd.DataFrame: Contains `table_id` and labels.
"""
set_rs(self.config["random_seed"])
result_df = []
self.model.eval()
with torch.no_grad():
for sample in self.dataset:
logits = []
data = sample["data"].to(self.device)
seq = data.unsqueeze(0)
attention_mask = torch.clone(seq != 0)
probs = self.model(seq, attention_mask=attention_mask)
if isinstance(probs, tuple):
probs = probs[0]
cls_probs = get_token_logits(self.device, seq, probs, self.tokenizer.cls_token_id)
logits.append(cls_probs.argmax(1).cpu().detach().numpy().tolist())
result_df.append([sample["table_id"], logits])
return pd.DataFrame(
result_df,
columns=["table_id", "labels"]
)
def preprocess_tables(self) -> None:
"""Collect all tables to inference and call preprocess.
Collects csv files from `data/inference` and then calls
preprocess function.
Returns:
None
"""
files = [
f for f in Path(self.directory).iterdir()
if f.is_file() and f.name.endswith(".csv") and f.name != "result.csv"
]
preprocessed_dir = self.directory + "preprocessed/"
Path(preprocessed_dir).mkdir(parents=True, exist_ok=True)
if len(files) == 1:
preprocessed_table = self.preprocess_table(files[0].name)
preprocessed_table.to_csv(f"{preprocessed_dir}/data.csv", index=False, sep="|")
else:
for i, file_name in enumerate(files):
preprocessed_table = self.preprocess_table(file_name.name)
preprocessed_table.to_csv(f"{preprocessed_dir}/data_{i}.csv", index=False, sep="|")
def preprocess_table(self, filename: str):
"""Preprocess table.
Preprocess given table and save in `data/inference/preprocess/` directory.
Args:
filename: Table filename.
Returns:
pd.DataFrame: Table as dataframe.
"""
table = pd.read_csv(f"{self.directory}{filename}", header=None)
data_list = []
for i in table.columns:
column_id = i
label_id = 0
label = "none"
column_data = " ".join(list(map(lambda x: str(x).strip(), table[i])))
data_list.append([filename, column_id, label_id, label, column_data])
preprocessed_table = pd.DataFrame(
data_list,
columns=["table_id", "column_id", "label_id", "label", "column_data"]
)
return preprocessed_table
if __name__ == "__main__":
    # Constructing the Inferencer loads the config, preprocesses the input
    # tables and loads the model checkpoint; inference() then writes result.csv.
    inferencer = Inferencer()
    inferencer.inference()