-
Notifications
You must be signed in to change notification settings - Fork 0
/
tpt_classification.py
494 lines (396 loc) · 18.5 KB
/
tpt_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
import argparse
import time
from copy import deepcopy
from PIL import Image
import numpy as np
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
import torchvision.models as models
from clip.custom_clip import get_coop
from clip.cocoop import get_cocoop
from data.imagnet_prompts import imagenet_classes
from data.datautils import AugMixAugmenter, build_dataset
from utils.tools import Summary, AverageMeter, ProgressMeter, accuracy, load_model_weight, set_random_seed
from data.cls_to_names import *
from data.fewshot_datasets import fewshot_datasets
from data.imagenet_variants import thousand_k_to_200, imagenet_a_mask, imagenet_r_mask, imagenet_v_mask
import ipdb
import math
import pickle
# Sorted names of every lowercase, non-dunder, callable attribute exported by
# torchvision.models (presumably the model builder functions).
model_names = sorted(
    attr_name
    for attr_name, attr_value in models.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and callable(attr_value)
)
def ECE_Loss(num_bins, predictions, confidences, correct):
    """Compute the Expected Calibration Error (ECE) over equal-width bins.

    The interval [0, 1] is split into ``num_bins`` equal-width bins and bin
    ``b`` collects the samples whose confidence lies in ``(lower_b, upper_b]``
    (a confidence of exactly 0 therefore falls in no bin).

    Args:
        num_bins: number of equal-width confidence bins.
        predictions: per-sample predictions; only its length is used, as the
            total number of samples.
        confidences: per-sample confidence values, indexed in parallel with
            ``predictions``.
        correct: per-sample booleans (or 0/1 values) marking correct
            predictions, indexed in parallel with ``predictions``.

    Returns:
        ``(ece, bin_accuracy, bin_confidence, bin_num_sample)``: ``ece`` is the
        sample-weighted mean of |accuracy - confidence| over the bins; the
        three lists hold, per bin, the mean accuracy, the mean confidence and
        the number of samples (means are left at 0 for empty bins).
        For empty input ``ece`` is 0.0 rather than a ZeroDivisionError.
    """
    # Same float32 boundaries as torch.linspace produces, converted to Python
    # floats once instead of via .item() for every (sample, bin) pair.
    bin_boundaries = torch.linspace(0, 1, num_bins + 1).tolist()
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]
    bin_accuracy = [0] * num_bins
    bin_confidence = [0] * num_bins
    bin_num_sample = [0] * num_bins

    n_samples = len(predictions)
    if n_samples == 0:
        return 0.0, bin_accuracy, bin_confidence, bin_num_sample

    for idx in range(n_samples):
        confidence = confidences[idx]
        for bin_idx, (bin_lower, bin_upper) in enumerate(zip(bin_lowers, bin_uppers)):
            # Bins are half-open (lower, upper] and share boundaries, so at
            # most one bin matches; stop searching once it is found.
            if bin_lower < confidence <= bin_upper:
                bin_num_sample[bin_idx] += 1
                bin_accuracy[bin_idx] += correct[idx]
                bin_confidence[bin_idx] += confidence
                break

    # Turn per-bin sums into per-bin means (empty bins stay at 0).
    for idx in range(num_bins):
        if bin_num_sample[idx] != 0:
            bin_accuracy[idx] = bin_accuracy[idx] / bin_num_sample[idx]
            bin_confidence[idx] = bin_confidence[idx] / bin_num_sample[idx]

    ece_loss = 0.0
    for idx in range(num_bins):
        temp_abs = abs(bin_accuracy[idx] - bin_confidence[idx])
        ece_loss += (temp_abs * bin_num_sample[idx]) / n_samples
    return ece_loss, bin_accuracy, bin_confidence, bin_num_sample
def Calculator(result_dict):
    """Print and return top-1 accuracy and ECE of the collected predictions.

    Args:
        result_dict: dict holding equal-length lists under the keys
            ``'max_confidence'``, ``'prediction'`` and ``'label'``.

    Returns:
        ``(acc, ece)`` as fractions (they are printed multiplied by 100), so
        callers can unpack the result. When no samples were collected both
        values are NaN instead of raising ZeroDivisionError.
    """
    list_max_confidence = result_dict['max_confidence']
    list_prediction = result_dict['prediction']
    list_label = result_dict['label']
    if len(list_prediction) == 0:
        print('no samples collected; acc/ece undefined')
        return float('nan'), float('nan')

    torch_correct = torch.tensor(list_prediction).int() == torch.tensor(list_label).int()
    list_correct = torch_correct.tolist()
    ece = ECE_Loss(20, list_prediction, list_max_confidence, list_correct)[0]
    acc = sum(list_correct) / len(list_correct)
    print('acc: ', acc * 100)
    print('ece: ', ece * 100)
    return acc, ece
def select_confident_samples(logits, top):
    """Keep the lowest-entropy fraction ``top`` of the rows of ``logits``.

    Entropy is taken over dim 1 of the softmax distribution; the number of
    rows kept is ``int(num_rows * top)``.

    Returns:
        ``(selected_logits, indices)`` with indices ordered by increasing
        entropy.
    """
    probs = logits.softmax(1)
    log_probs = logits.log_softmax(1)
    row_entropy = -(probs * log_probs).sum(1)
    keep = int(row_entropy.size()[0] * top)
    order = torch.argsort(row_entropy, descending=False)
    chosen = order[:keep]
    return logits[chosen], chosen
def avg_entropy(outputs):
    """Entropy of the mean predicted distribution across rows of ``outputs``.

    Works in log space: rows are log-normalised over the last dim, averaged
    over dim 0 with logsumexp - log(N), clamped to the dtype's finite minimum
    to avoid -inf, and the entropy of that averaged distribution is returned.
    """
    log_probs = outputs - torch.logsumexp(outputs, dim=-1, keepdim=True)
    num_rows = log_probs.shape[0]
    mean_log_probs = torch.logsumexp(log_probs, dim=0) - np.log(num_rows)
    floor = torch.finfo(mean_log_probs.dtype).min
    mean_log_probs = mean_log_probs.clamp(min=floor)
    return -(mean_log_probs * mean_log_probs.exp()).sum(dim=-1)
def test_time_tuning(model, inputs, optimizer, scaler, args):
    """Adapt the prompt on one test sample for ``args.tta_steps`` steps.

    Per-step loss, selected by substrings of ``args.run_type``:
      * contains ``'tpt'``: marginal entropy (``avg_entropy``) of the
        ``args.selection_p`` most confident rows of the model output; the
        selected row indices are chosen on the first step and reused after.
      * contains ``'ctpt'``: adds ``-args.lambda_term * model.l2_norm_mean_training``
        (assumed to be populated by the model's forward pass — TODO confirm).
      * otherwise the loss is 0.
    With ``args.two_step`` the entropy loss is stepped first, then a fresh
    forward pass is run and only the C-TPT term is stepped afterwards.
    No optimizer step is taken for the run types in the baseline list below.

    Args:
        model: prompt-tunable CLIP wrapper (project type).
        inputs: the image batch for CoOp, or ``(image_feature, pgen_ctx)``
            when ``args.cocoop``.
        optimizer: optimizer over the prompt parameters; replaced by a fresh
            AdamW over ``pgen_ctx`` when ``args.cocoop``.
        scaler: AMP GradScaler used for backward/step.
        args: parsed CLI namespace; for C-TPT ``args.image`` (set by the
            caller) is used for a single-image forward when no other forward
            pass ran in the step.

    Returns:
        The tuned ``pgen_ctx`` tensor when ``args.cocoop``, otherwise None.
    """
    output = None
    output2 = None
    single_output = None  # forward result unused; the call is kept for its effect on model state
    if args.cocoop:
        image_feature, pgen_ctx = inputs
        pgen_ctx.requires_grad = True
        optimizer = torch.optim.AdamW([pgen_ctx], args.lr)

    selected_idx = None
    for j in range(args.tta_steps):
        if 'tpt' in args.run_type:
            with torch.cuda.amp.autocast():
                if args.cocoop:
                    output = model((image_feature, pgen_ctx))
                else:
                    output = model(inputs)

                if selected_idx is not None:
                    output = output[selected_idx]
                else:
                    output, selected_idx = select_confident_samples(output, args.selection_p)

                loss = avg_entropy(output)
        else:
            loss = 0

        if args.two_step and 'tpt' in args.run_type:
            optimizer.zero_grad()
            # compute gradient and do SGD step
            scaler.scale(loss).backward(retain_graph=True)
            # Unscales the gradients of optimizer's assigned params in-place
            scaler.step(optimizer)
            scaler.update()
            loss = 0
            # Fresh forward pass with the updated prompt for the C-TPT term.
            with torch.cuda.amp.autocast():
                if args.cocoop:
                    output2 = model((image_feature, pgen_ctx))
                else:
                    output2 = model(inputs)

        if 'ctpt' in args.run_type:
            # Identity checks (not ==) so tensor outputs are never compared
            # element-wise against None.
            if output is None and output2 is None:
                single_output = model(args.image)
            lambda_ = args.lambda_term
            loss += (-lambda_ * model.l2_norm_mean_training)

        if args.run_type not in ['baseline', 'baseline_cocoop', 'baseline_coop', 'baseline_ts']:
            optimizer.zero_grad()
            # compute gradient and do SGD step
            scaler.scale(loss).backward()
            # Unscales the gradients of optimizer's assigned params in-place
            scaler.step(optimizer)
            scaler.update()

    if args.cocoop:
        return pgen_ctx

    return
def main(args, result_dict):
    """Seed the RNGs, then run the single-GPU evaluation worker.

    ``result_dict`` is handed through to ``main_worker`` unchanged.
    """
    set_random_seed(args.seed)
    device_id = args.gpu
    # Only the single-GPU setting has been exercised with this codebase.
    assert device_id is not None
    main_worker(device_id, args, result_dict)
def _dataset_classnames(set_id):
    """Return the ``<set_id>_classes`` list imported from data.cls_to_names.

    Resolved through this module's globals instead of ``eval`` so that a
    malformed ``--test_sets`` value cannot execute arbitrary code.

    Raises:
        ValueError: if no such class-name list is in scope.
    """
    key = "{}_classes".format(set_id.lower())
    try:
        return globals()[key]
    except KeyError:
        raise ValueError("no class-name list `{}` for test set {!r}".format(key, set_id)) from None


def _imagenet_variant_classnames(set_id):
    """Return the ImageNet class names used for single-letter test set ``set_id``.

    'R' keeps the classes whose entry in ``imagenet_r_mask`` is truthy,
    'A' and 'V' pick classes by the indices in their mask lists, and
    'K'/'I' use every ImageNet class.
    """
    assert set_id in ['A', 'R', 'K', 'V', 'I']
    classnames_all = imagenet_classes
    if set_id == 'R':
        return [classnames_all[i] for i, m in enumerate(imagenet_r_mask) if m]
    if set_id in ('A', 'V'):
        label_mask = imagenet_a_mask if set_id == 'A' else imagenet_v_mask
        return [classnames_all[i] for i in label_mask]
    return classnames_all


def main_worker(gpu, args, result_dict):
    """Build the prompt-tuned CLIP model and evaluate it on ``args.test_sets``.

    Creates a CoCoOp (``args.cocoop``) or CoOp model, freezes everything but
    the tunable part, sets up AdamW (CoOp only) and an AMP GradScaler, then
    for each test set builds the transform/dataset/loader, resets the class
    names and runs ``test_time_adapt_eval``, printing per-set and summary
    accuracies. Per-sample predictions are appended to ``result_dict``.
    """
    args.gpu = gpu
    set_random_seed(args.seed)
    print("Use GPU: {} for training".format(args.gpu))

    # Build the prompt-tunable CLIP model for backbone args.arch.
    if args.cocoop:
        model = get_cocoop(args.arch, args.test_sets, 'cpu', args.n_ctx)
        assert args.load is not None
        load_model_weight(args.load, model, 'cpu', args)  # to load to cuda: device="cuda:{}".format(args.gpu)
        model_state = deepcopy(model.state_dict())
    else:
        model = get_coop(args.arch, args.test_sets, args.gpu, args.n_ctx, args.ctx_init)
        if args.load is not None:
            print("Use pre-trained soft prompt (CoOp) as initialization")
            pretrained_ctx = torch.load(args.load)['state_dict']['ctx']
            assert pretrained_ctx.size()[0] == args.n_ctx
            with torch.no_grad():
                model.prompt_learner.ctx.copy_(pretrained_ctx)
                model.prompt_learner.ctx_init_state = pretrained_ctx
        model_state = None

    # Freeze everything except the prompt learner (CoOp) / text encoder (CoCoOp).
    for name, param in model.named_parameters():
        if not args.cocoop:
            if "prompt_learner" not in name:
                param.requires_grad_(False)
        else:
            if "text_encoder" not in name:
                param.requires_grad_(False)

    print("=> Model created: visual backbone {}".format(args.arch))

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    else:
        assert args.gpu is not None
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)

    # define optimizer (CoCoOp builds its own per sample in test_time_tuning)
    if args.cocoop:
        optimizer = None
        optim_state = None
    else:
        trainable_param = model.prompt_learner.parameters()
        optimizer = torch.optim.AdamW(trainable_param, args.lr)
        optim_state = deepcopy(optimizer.state_dict())

    # setup automatic mixed-precision (Amp) loss scaling
    scaler = torch.cuda.amp.GradScaler(init_scale=1000)

    print('=> Using native Torch AMP. Training in mixed precision.')

    cudnn.benchmark = True

    # norm stats from clip.load()
    normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                     std=[0.26862954, 0.26130258, 0.27577711])

    # iterating through eval datasets
    datasets = args.test_sets.split("/")
    assert len(datasets) == 1
    results = {}
    for set_id in datasets:
        if args.tpt:
            base_transform = transforms.Compose([
                transforms.Resize(args.resolution, interpolation=BICUBIC),
                transforms.CenterCrop(args.resolution)])
            preprocess = transforms.Compose([
                transforms.ToTensor(),
                normalize])
            if args.I_augmix:
                data_transform = AugMixAugmenter(base_transform, preprocess, n_views=args.batch_size - 1,
                                                 augmix=len(set_id) >= 1)
            else:
                data_transform = AugMixAugmenter(base_transform, preprocess, n_views=args.batch_size - 1,
                                                 augmix=len(set_id) > 1)
            batchsize = 1
        else:
            data_transform = transforms.Compose([
                transforms.Resize(args.resolution, interpolation=BICUBIC),
                transforms.CenterCrop(args.resolution),
                transforms.ToTensor(),
                normalize,
            ])
            batchsize = args.batch_size

        print("evaluating: {}".format(set_id))
        # Reset classnames of custom CLIP model: multi-letter ids are
        # fine-grained datasets, single letters are ImageNet variants.
        if len(set_id) > 1:
            classnames = _dataset_classnames(set_id)
        else:
            classnames = _imagenet_variant_classnames(set_id)
        if args.cocoop:
            model.prompt_generator.reset_classnames(classnames, args.arch)
            model = model.cpu()
            model_state = model.state_dict()
            model = model.cuda(args.gpu)
        else:
            model.reset_classnames(classnames, args.arch)

        val_dataset = build_dataset(set_id, data_transform, args.data, mode=args.dataset_mode)
        print("number of test samples: {}".format(len(val_dataset)))
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=batchsize, shuffle=True,
            num_workers=args.workers, pin_memory=True)

        results[set_id] = test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, args, result_dict)
        del val_dataset, val_loader
        try:
            print("=> Acc. on testset [{}]: @1 {}/ @5 {}".format(set_id, results[set_id][0], results[set_id][1]))
        except (TypeError, IndexError, KeyError):
            # Result is not a [top1, top5] pair; print it as-is.
            print("=> Acc. on testset [{}]: {}".format(set_id, results[set_id]))

    print("======== Result Summary ========")
    print("params: nstep lr bs")
    print("params: {} {} {}".format(args.tta_steps, args.lr, args.batch_size))
    print("\t\t [set_id] \t\t Top-1 acc. \t\t Top-5 acc.")
    for set_name in results.keys():
        print("{}".format(set_name), end=" ")
    print("\n")
    for set_name in results.keys():
        print("{:.2f}".format(results[set_name][0]), end=" ")
    print("\n")
def test_time_adapt_eval(val_loader, model, model_state, optimizer, optim_state, scaler, args, result_dict):
    """Tune the prompt on each test batch, then evaluate on it.

    For every batch from ``val_loader``: resets the CoOp prompt and optimizer
    state (or regenerates the CoCoOp context), runs ``test_time_tuning``,
    predicts with the tuned prompt, optionally applies temperature scaling
    (run types containing ``'ts'``), and appends per-sample max confidence,
    predicted index and label to ``result_dict``.

    Returns:
        ``[top1.avg, top5.avg]`` accuracies over the whole loader.

    Raises:
        ValueError: for a ``'ts'`` run type whose ``args.arch`` has no entry
            in ``temperature_value`` (previously dropped into a debugger).
    """
    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)

    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1, top5],
        prefix='Test: ')

    # reset model and switch to evaluate mode
    model.eval()
    if not args.cocoop:  # no need to reset cocoop because it's fixed
        with torch.no_grad():
            model.reset()
    end = time.time()
    softmax = torch.nn.Softmax(dim=1)
    # Ask the model to track the quantity used by the C-TPT loss term.
    model.l2_norm_cal = 'ctpt' in args.run_type

    for i, (images, target) in enumerate(val_loader):
        assert args.gpu is not None
        if isinstance(images, list):
            for k in range(len(images)):
                images[k] = images[k].cuda(args.gpu, non_blocking=True)
            image = images[0]
        else:
            if len(images.size()) > 4:
                # when using ImageNet Sampler as the dataset
                assert images.size()[0] == 1
                images = images.squeeze(0)
            images = images.cuda(args.gpu, non_blocking=True)
            image = images
        target = target.cuda(args.gpu, non_blocking=True)
        if args.tpt:
            images = torch.cat(images, dim=0)
        if 'ctpt' in args.run_type:
            args.image = image

        # reset the tunable prompt to its initial state
        if not args.cocoop:  # no need to reset cocoop because it's fixed
            if args.tta_steps > 0:
                with torch.no_grad():
                    model.reset()
            optimizer.load_state_dict(optim_state)
            test_time_tuning(model, images, optimizer, scaler, args)
        else:
            with torch.no_grad():
                with torch.cuda.amp.autocast():
                    image_feature, pgen_ctx = model.gen_ctx(images, args.tpt)
            optimizer = None
            pgen_ctx = test_time_tuning(model, (image_feature, pgen_ctx), optimizer, scaler, args)

        # The actual inference goes here
        if args.tpt:
            if args.cocoop:
                image_feature = image_feature[0].unsqueeze(0)

        with torch.no_grad():
            with torch.cuda.amp.autocast():
                if args.cocoop:
                    output = model((image_feature, pgen_ctx))
                else:
                    output = model(image)

        if 'ts' not in args.run_type:
            softmax_output = softmax(output)
        elif 'ViT' in args.arch:
            softmax_output = softmax(output / temperature_value['ViT'])
        elif 'RN' in args.arch:
            softmax_output = softmax(output / temperature_value['RN'])
        else:
            raise ValueError(
                "temperature scaling has no temperature for arch {!r}".format(args.arch))

        # maximum confidence of the softmax_output and its index
        max_confidence, max_index = torch.max(softmax_output, 1)
        # Record one entry per sample; flattening + tolist() handles batches of
        # any size (.item() only worked for single-sample batches).
        result_dict['max_confidence'].extend(max_confidence.reshape(-1).tolist())
        result_dict['prediction'].extend(max_index.reshape(-1).tolist())
        result_dict['label'].extend(target.reshape(-1).tolist())

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1[0], image.size(0))
        top5.update(acc5[0], image.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            progress.display(i)

    progress.display_summary()

    return [top1.avg, top5.avg]
# Softmax temperature per backbone family, used by the temperature-scaling
# ('ts') run types.
temperature_value = dict(ViT=1.16, RN=1.15)
if __name__ == '__main__':
    # CLI entry point: parse options, run the evaluation, then report
    # accuracy / ECE over every collected prediction.
    parser = argparse.ArgumentParser(description='Test-time Prompt Tuning')
    parser.add_argument('data', metavar='DIR', help='path to dataset root')
    parser.add_argument('--test_sets', type=str, default='A/R/V/K/I', help='test dataset (multiple datasets split by slash)')
    parser.add_argument('--dataset_mode', type=str, default='test', help='which split to use: train/val/test')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='RN50')
    parser.add_argument('--resolution', default=224, type=int, help='CLIP image resolution')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch-size', default=64, type=int, metavar='N')
    parser.add_argument('--lr', '--learning-rate', default=5e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('-p', '--print-freq', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU id to use.')
    parser.add_argument('--tpt', action='store_true', default=False, help='run test-time prompt tuning')
    parser.add_argument('--selection_p', default=0.1, type=float, help='confidence selection percentile')
    parser.add_argument('--tta_steps', default=1, type=int, help='test-time-adapt steps')
    parser.add_argument('--n_ctx', default=4, type=int, help='number of tunable tokens')
    parser.add_argument('--ctx_init', default=None, type=str, help='init tunable prompts')
    parser.add_argument('--cocoop', action='store_true', default=False, help="use cocoop's output as prompt initialization")
    parser.add_argument('--load', default=None, type=str, help='path to a pre-trained coop/cocoop')
    parser.add_argument('--seed', type=int, default=0)
    # added args for c-tpt --------------------------------
    parser.add_argument('--lambda_term', type=float, default=0.0, help='lambda for c-tpt')
    # argparse does not validate defaults against `choices`, so the default
    # 'baseline_tpt' was silently accepted yet rejected when passed
    # explicitly; it is now a listed choice. The code only tests run_type by
    # substring ('tpt', 'ctpt', 'ts') and against a baseline list, under which
    # 'baseline_tpt' is handled the same as 'tpt'.
    parser.add_argument('--run_type', type=str, default='baseline_tpt',
                        choices=['baseline', 'baseline_tpt', 'tpt', 'tpt_ctpt', 'tpt_ts'])
    parser.add_argument('--two_step', action='store_true', default=False, help='two step training')
    parser.add_argument('--I_augmix', action='store_true', default=False, help='augmix for I')
    # ------------------------------------------------
    args = parser.parse_args()
    # The C-TPT weight only applies to C-TPT runs.
    if 'ctpt' not in args.run_type:
        args.lambda_term = 0.0
    result_dict = {'max_confidence': [], 'prediction': [], 'label': []}
    main(args, result_dict)
    # Calculator prints acc/ece itself; its return value is not needed here,
    # so don't unpack it (unpacking a None return raised TypeError).
    Calculator(result_dict)