-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
86 lines (69 loc) · 2.64 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import argparse
from sklearn.model_selection import StratifiedKFold
import torch
from fvcore.nn import FlopCountAnalysis
import torch.nn.functional as F
from config import Dataset, EnumAction
from dataset import get_dataset
def eval_(model, device, data_loader):
    """Evaluate a classification model over every batch yielded by *data_loader*.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)``. Its output
            is fed to ``F.nll_loss``, so it must be log-probabilities with one row
            per sample and one column per class.
        device: Device each batch is moved to (via ``batch.to(device)``) before
            the forward pass.
        data_loader: Iterable of batch objects exposing ``x``, ``edge_index``,
            ``batch`` and ``y`` attributes plus a ``to(device)`` method. Samples
            are counted as they are seen, so the iterable does not need a
            ``.dataset`` attribute.

    Returns:
        Tuple ``(acc, avg_loss, actual, predicted)``:
        - ``acc``: fraction of samples whose argmax prediction equals ``y``.
        - ``avg_loss``: mean per-sample NLL loss over all samples seen.
        - ``actual`` and ``predicted``: true and predicted labels as int32 CPU
          tensors, in iteration order.
        For an empty loader, ``acc`` and ``avg_loss`` are 0.0 and both label
        tensors are empty.
    """
    model.eval()
    correct = 0
    loss_sum = 0.0
    num_samples = 0
    # Collect per-batch tensors and concatenate once at the end. Concatenating
    # inside the loop would copy the growing tensor on every batch (quadratic).
    predicted_chunks = []
    actual_chunks = []

    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data in data_loader:
            data = data.to(device)
            out = model(data.x, data.edge_index, batch=data.batch)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            # Sum (not mean) per batch, so that dividing by the total sample
            # count below gives the true per-sample mean. Averaging per-batch
            # means and dividing by the dataset size would under-report it.
            loss_sum += F.nll_loss(out, data.y, reduction="sum").item()
            num_samples += data.y.numel()
            predicted_chunks.append(pred)
            actual_chunks.append(data.y)

    if num_samples == 0:
        empty = torch.empty(0, dtype=torch.int)
        return 0.0, 0.0, empty, empty.clone()

    return (
        correct / num_samples,
        loss_sum / num_samples,
        torch.cat(actual_chunks).to("cpu", dtype=torch.int),
        torch.cat(predicted_chunks).to("cpu", dtype=torch.int),
    )
# https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7
def count_parameters(model):
    """Return the total number of scalar elements across *model*'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
def get_flops(model, device, x, edge_index, batch=None):
    """Return the total count reported by fvcore's ``FlopCountAnalysis`` for one forward pass.

    The model is switched to eval mode first. ``x``, ``edge_index`` and (when
    given) ``batch`` are moved to *device* and passed positionally as the traced
    inputs; a ``batch`` of ``None`` is forwarded as ``None``.
    (Presumably ``x`` holds node features and ``edge_index`` graph connectivity,
    matching how ``eval_`` calls the model — confirm against the model.)
    """
    model.eval()
    node_features = x.to(device)
    edges = edge_index.to(device)
    batch_vector = batch.to(device) if batch is not None else None
    sample_inputs = (node_features, edges, batch_vector)
    analysis = FlopCountAnalysis(model, inputs=sample_inputs)
    return analysis.total()
if __name__ == "__main__":
    # Local import: only the script entry point needs a DataLoader.
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser(description="Evaluate a model")
    parser.add_argument(
        "-d",
        "--dataset",
        action=EnumAction,
        enum_type=Dataset,
        required=True,
        help="Choose a dataset from: %(choices)s",
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Model path")
    parser.add_argument("--k-folds", type=int, default=5, help="Number of folds")
    parser.add_argument("--fold", type=int, default=0, help="Fold to evaluate")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--device",
        type=str,
        # Fall back to CPU so the script still runs on machines without CUDA.
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Device",
    )
    args = parser.parse_args()
    if not 0 <= args.fold < args.k_folds:
        parser.error(f"--fold must be in [0, {args.k_folds - 1}]")

    ds = get_dataset(args.dataset)
    # Deterministic stratified split (shuffle + fixed seed). This presumably
    # mirrors the split used at training time; keep --k-folds/--seed in sync
    # with the training run so that --fold selects truly held-out data.
    skf = StratifiedKFold(n_splits=args.k_folds, shuffle=True, random_state=args.seed)
    _, test_index = list(skf.split(ds, ds.y))[args.fold]
    test_ds = ds[test_index]

    # The checkpoint is a pickled nn.Module, so weights_only=False is required
    # (newer PyTorch releases default it to True). NOTE(security): unpickling
    # can execute arbitrary code, so only load trusted model files.
    # map_location lets a GPU-saved model be loaded on the requested device.
    model = torch.load(args.model, map_location=args.device, weights_only=False).to(
        args.device
    )

    # batch_size=None disables automatic batching: each graph is yielded as-is
    # (with batch=None, as get_flops also passes), and the loader exposes
    # `.dataset`, which eval_ may rely on.
    test_loader = DataLoader(test_ds, batch_size=None)
    with torch.no_grad():
        acc, _, _, _ = eval_(model, args.device, test_loader)
    print("Acc:", acc)
    param_count = count_parameters(model)
    print("Parameters:", param_count)