-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_aa.py
141 lines (108 loc) · 4.35 KB
/
eval_aa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
Evaluate a trained model's adversarial robustness with AutoAttack.
python eval_aa.py --fname_input xxx --eps_eval xxx --batch_size_for_eval xxx
"""
import json
import time
import argparse
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from autoattack import AutoAttack
from core.data import get_data_info
from core.data import load_data
from core.models import create_model
from core.utils import Logger
from core.utils import seed
from core.utils.utils import str2bool, str2float
# Setup
def parser_eval():
    """
    Build the command-line parser for robustness evaluation.

    Shared by the evaluation scripts (eval-adv.py, eval-corr.py, eval-aa.py).

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser(description='Robustness evaluation.')
    parser.add_argument('--norm_attack', type=str, default='Linf', choices=['Linf', 'L2'])
    # Previously the help text here was a copy-paste of '--seed' ('Random seed.').
    parser.add_argument('--eps_eval', type=float, default=8,
                        help='Perturbation budget for evaluation (e.g. 8 for Linf, 0.5 for L2).')
    parser.add_argument('--fname_input', type=str, default='...')
    parser.add_argument('--batch_size_for_eval', type=int, default=1024)
    parser.add_argument('--early_stopping', type=str2bool, default=True)
    parser.add_argument('--train', action='store_true', default=False, help='Evaluate on training set.')
    parser.add_argument('-v', '--version', type=str, default='standard', choices=['custom', 'plus', 'standard'],
                        help='Version of AA.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed.')
    return parser
parse = parser_eval()
args = parse.parse_args()
# Linf budgets are given in 1/255 units (e.g. 8 -> 8/255); L2 budgets are used
# unchanged. Computed here, before the training args are merged into `args`.
eps_eval = args.eps_eval / 255. if args.norm_attack == 'Linf' else args.eps_eval
# Merge in the arguments saved by the training run (e.g. model, data, attack,
# which parser_eval does not define).
# NOTE(review): on key collisions the values from args.txt override the
# evaluation flags parsed above (e.g. --seed) — confirm this is intended.
with open(args.fname_input + '/args.txt', 'r') as f:
    training_args = json.load(f)
args.__dict__ = {**vars(args), **training_args}
# Resolve data / checkpoint locations and open the AutoAttack log.
DATA_DIR = args.data_dir + args.data
# early_stopping picks the 'val_best' checkpoint, otherwise the latest one.
WEIGHTS = args.fname_input + ('/val_best.pt' if args.early_stopping else '/latest_checkpoint.pt')
log_path = args.fname_input + '/log-aa_{}.log'.format(args.early_stopping)
logger = Logger(log_path)
logger.log('\n\n')  # visual separator in the log
info = get_data_info(DATA_DIR)
BATCH_SIZE, BATCH_SIZE_VALIDATION = args.batch_size, args.batch_size_for_eval
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Load data (no augmentation, unshuffled training loader) and materialise the
# chosen split as two tensors, which AutoAttack consumes in one call.
seed(args.seed)
_, _, train_dataloader, test_dataloader = load_data(DATA_DIR, BATCH_SIZE, BATCH_SIZE_VALIDATION, use_augmentation=False,
                                                    shuffle_train=False)
if args.train:
    logger.log('Evaluating on training set.')
    eval_dataloader = train_dataloader
else:
    eval_dataloader = test_dataloader
# Single pass over the loader: the data is read once, and inputs/labels are
# guaranteed to stay paired (two separate passes would re-read everything and
# would mismatch x and y if the loader ever shuffled).
x_batches, y_batches = [], []
for x_batch, y_batch in eval_dataloader:
    x_batches.append(x_batch)
    y_batches.append(y_batch)
x_test = torch.cat(x_batches, 0)
y_test = torch.cat(y_batches, 0)
del x_batches, y_batches  # torch.cat copied the data; drop the per-batch tensors
logger.log('evaluation data size:{}'.format(y_test.size(0)))
# Model
# +
# Rebuild the architecture from the training args; the original note says the
# returned model is DataParallel-wrapped (presumably done inside create_model).
model = create_model(args.model, args.normalize, info, device, GroupNorm=args.GroupNorm)  ## dataParallel
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
checkpoint = torch.load(WEIGHTS, map_location=device)
try:
    model.load_state_dict(checkpoint['unaveraged_model'])
except RuntimeError:
    # State-dict key mismatch: the checkpoint was not saved from a DataParallel
    # model, so load it into the wrapped module instead.
    model.module.load_state_dict(checkpoint['unaveraged_model'])
model.eval()
# -
# AA Evaluation
# +
seed(args.seed)
# args.attack is not defined by parser_eval, so it comes from the training
# args.txt; require it to belong to the same threat model as the evaluation.
# Explicit raises (not assert) so the check survives `python -O`.
_ATTACKS_BY_NORM = {
    'Linf': ['fgsm', 'linf-pgd', 'linf-df', 'linf-apgd'],
    'L2': ['fgm', 'l2-pgd', 'l2-df', 'l2-apgd'],
}
if args.norm_attack not in _ATTACKS_BY_NORM:
    raise ValueError('Invalid norm_attack for evaluation')
if args.attack not in _ATTACKS_BY_NORM[args.norm_attack]:
    raise ValueError('Training attack {!r} is not compatible with norm_attack={!r}; expected one of {}'.format(
        args.attack, args.norm_attack, _ATTACKS_BY_NORM[args.norm_attack]))
adversary = AutoAttack(model, norm=args.norm_attack, eps=eps_eval, log_path=log_path, version=args.version, seed=args.seed)
# -
logger.log('{} AA evaluation on:\n{}\n'.format(args.norm_attack, WEIGHTS))
# A checkpoint may record 'test_best' instead of 'val_best'; only a missing key
# should trigger the fallback (the previous bare except hid every error).
try:
    logger.log('epoch {} with val_best {}'.format(checkpoint['epoch'], checkpoint['val_best']))
except KeyError:
    logger.log('epoch {} with test_best {}'.format(checkpoint['epoch'], checkpoint['test_best']))
del checkpoint  # weights are already loaded into the model
logger.log('eps:{:.4f} batch size:{}\n'.format(eps_eval, BATCH_SIZE_VALIDATION))
if args.version == 'custom':
    # 'custom' runs APGD-CE, APGD-T and FAB-T with a single restart each.
    adversary.attacks_to_run = ['apgd-ce', 'apgd-t', 'fab-t']
    adversary.apgd.n_restarts = 1
    adversary.apgd_targeted.n_restarts = 1
# NOTE(review): AutoAttack presumably re-enables gradients internally for its
# gradient-based attacks (its own examples wrap this call in no_grad) — confirm.
with torch.no_grad():
    x_adv = adversary.run_standard_evaluation(x_test, y_test, bs=BATCH_SIZE_VALIDATION)
print('Script Completed.')