ewc.py
import os
import time
import torch
from copy import deepcopy
import itertools
from src.args import parse_arguments
from src.datasets.common import get_dataloader, maybe_dictionarize
from src.datasets.registry import get_dataset
from src.modeling import ImageEncoder, ImageClassifier
from src.utils import cosine_lr, LabelSmoothing
from src.cl_utils import get_dataset_and_classifier_for_split
from src.eval import evaluate

def compute_fisher_matrix_diag(model, trn_loader):
    print("Computing diagonal of the Fisher information matrix")
    # Store Fisher information for every trainable parameter
    fisher = {n: torch.zeros(p.shape).to("cuda") for n, p in model.named_parameters()
              if p.requires_grad}
    # Compute Fisher information for the specified number of samples -- rounded up to the batch size
    num_samples = len(trn_loader.dataset)
    n_samples_batches = (num_samples // trn_loader.batch_size + 1) if num_samples > 0 \
        else (len(trn_loader.dataset) // trn_loader.batch_size)
    # Do forward and backward passes to compute the Fisher information
    model.train()
    model = model.cuda()
    for images, targets in itertools.islice(trn_loader, n_samples_batches):
        images = images.to("cuda")
        outputs = model(images)
        # Empirical Fisher: use the model's own predictions as targets
        preds = outputs.argmax(1)
        loss = torch.nn.functional.cross_entropy(outputs, preds)
        # Clear gradients left over from the previous batch before accumulating
        model.zero_grad()
        loss.backward()
        # Accumulate the squared gradients, weighted by the batch size
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.pow(2) * len(targets)
    # Apply mean across all samples
    n_samples = n_samples_batches * trn_loader.batch_size
    fisher = {n: (p / n_samples) for n, p in fisher.items()}
    print("Finished computing diagonal of the Fisher information matrix")
    return fisher
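
# --- Illustrative sketch (not part of the training pipeline) -----------------
# A minimal, CPU-only example of the same empirical-Fisher estimator used above:
# square the gradients of the cross-entropy taken against the model's own
# predictions and average them over the samples seen. The model, data and sizes
# below are made up purely for illustration.
def _fisher_diag_toy_example():
    from torch.utils.data import DataLoader, TensorDataset

    toy_model = torch.nn.Linear(4, 3)
    toy_loader = DataLoader(
        TensorDataset(torch.randn(32, 4), torch.randint(0, 3, (32,))),
        batch_size=8,
    )

    fisher = {n: torch.zeros_like(p) for n, p in toy_model.named_parameters()}
    for inputs, _ in toy_loader:
        outputs = toy_model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, outputs.argmax(1))
        toy_model.zero_grad()
        loss.backward()
        for n, p in toy_model.named_parameters():
            fisher[n] += p.grad.pow(2) * len(inputs)
    # Average over the number of samples, as compute_fisher_matrix_diag does
    return {n: f / 32 for n, f in fisher.items()}
# -----------------------------------------------------------------------------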

def calc_ewc_loss(backbone, fisher, older_params):
    """Return the elastic weight consolidation (EWC) quadratic penalty."""
    loss_reg = 0
    # Eq. 3: penalize deviation from the previous task's parameters,
    # weighted by the diagonal Fisher importance of each parameter
    for n, p in backbone.named_parameters():
        if n in fisher.keys():
            loss_reg += torch.sum(fisher[n] * (p - older_params[n]).pow(2)) / 2
    return loss_reg
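
# --- Illustrative sketch (not part of the training pipeline) -----------------
# calc_ewc_loss implements L_EWC = 1/2 * sum_i F_i * (theta_i - theta*_i)^2: it
# is zero while the parameters equal the previous task's parameters (theta*) and
# grows quadratically as they drift, with the diagonal Fisher F acting as a
# per-parameter importance weight. Toy check with made-up tensors:
def _ewc_penalty_toy_example():
    toy_model = torch.nn.Linear(2, 2)
    older_params = {n: p.clone().detach() for n, p in toy_model.named_parameters()}
    fisher = {n: torch.ones_like(p) for n, p in toy_model.named_parameters()}

    assert calc_ewc_loss(toy_model, fisher, older_params).item() == 0.0  # no drift yet

    with torch.no_grad():
        toy_model.weight += 0.1  # drift the 2x2 weight away from the anchor
    # 0.5 * 4 * 0.1 ** 2 = 0.02 (the bias has not moved, so it contributes nothing)
    return calc_ewc_loss(toy_model, fisher, older_params)
# -----------------------------------------------------------------------------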

def finetune(args):
    train_dataset = args.dataset

    # Finetune on each split separately, in sequence
    for split_idx in range(args.n_splits):
        print(f"\n##### SPLIT {split_idx} #####")
        ckpdir = os.path.join(args.save, f"{train_dataset}-{args.n_splits}",
                              f"ft-epochs-{args.epochs}-seed:{args.seed}-lamb:{args.ewc_lamb}")
        ft_path = os.path.join(ckpdir, f'finetuned_{split_idx}.pt')

        if split_idx == 0:
            print('Building image encoder.')
            image_encoder = ImageEncoder(args, keep_lang=True)

        preprocess_fn = image_encoder.train_preprocess
        print_every = 10

        dataset = get_dataset(
            train_dataset,
            preprocess_fn,
            location=args.data_location,
            batch_size=args.batch_size
        )
        dataset, classification_head = get_dataset_and_classifier_for_split(
            dataset, split_idx, image_encoder, args
        )

        model = ImageClassifier(image_encoder, classification_head)
        model.freeze_head()
        model.freeze_lang()

        devices = list(range(torch.cuda.device_count()))
        print('Using devices', devices)
        model = torch.nn.DataParallel(model, device_ids=devices)

        if args.ls > 0:
            loss_fn = LabelSmoothing(args.ls)
        else:
            loss_fn = torch.nn.CrossEntropyLoss()

        params = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(params, lr=args.lr, weight_decay=args.wd)

        num_batches = len(dataset.train_loader)
        scheduler = cosine_lr(optimizer, args.lr, args.warmup_length, args.epochs * num_batches)

        data_loader = get_dataloader(dataset, is_train=True, args=args, image_encoder=None)
        n_batches = len(data_loader)

        if split_idx == 0:
            # No previous task yet: start with zero importance for every trainable parameter
            fisher = {n: torch.zeros(p.shape, device="cuda") for n, p in model.named_parameters() if p.requires_grad}

        if args.save is not None:
            os.makedirs(ckpdir, exist_ok=True)

        model = model.cuda()
        model.train()

        for epoch in range(args.epochs):
            for i, batch in enumerate(data_loader):
                start_time = time.time()

                step = i + epoch * num_batches
                scheduler(step)
                optimizer.zero_grad()

                batch = maybe_dictionarize(batch)
                inputs = batch['images'].to('cuda:0')
                labels = batch['labels'].to('cuda:0')
                data_time = time.time() - start_time

                if split_idx > 0:
                    # Classification loss plus the EWC penalty weighted by args.ewc_lamb
                    logits = model(inputs)
                    clsf_loss = loss_fn(logits, labels)
                    ewc_loss = calc_ewc_loss(model, fisher, old_params)
                    loss = clsf_loss + args.ewc_lamb * ewc_loss
                else:
                    logits = model(inputs)
                    loss = loss_fn(logits, labels)

                loss.backward()
                torch.nn.utils.clip_grad_norm_(params, 1.0)
                optimizer.step()
                batch_time = time.time() - start_time

                if step % print_every == 0 or i + 1 == n_batches:
                    percent_complete = 100 * i / len(data_loader)
                    if split_idx == 0:
                        print(
                            f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t"
                            f"Loss: {loss.item():.6f}\tData (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
                        )
                    else:
                        print(
                            f"Train Epoch: {epoch} [{percent_complete:.0f}% {i}/{len(dataset.train_loader)}]\t"
                            f"Loss: {loss.item():.6f}\tLoss clsf: {clsf_loss.item():.6f}\tLoss EWC: {ewc_loss:.6f}\t"
                            f"Data (t) {data_time:.3f}\tBatch (t) {batch_time:.3f}", flush=True
                        )

        # Save the finetuned encoder for this split
        image_encoder = model.module.image_encoder
        if args.save is not None:
            image_encoder.save(ft_path)

        # Store the current parameters as the anchor for the next task's EWC penalty
        old_params = {n: p.clone().detach().to("cuda") for n, p in model.named_parameters() if p.requires_grad}

        # Calculate Fisher information on the current task's data
        curr_fisher = compute_fisher_matrix_diag(model, data_loader)

        # Merge Fisher information: keep a single running average rather than one
        # Fisher matrix per task; with alpha = 0.5 each older task's contribution
        # is halved whenever a new task arrives
        alpha = 0.5
        for n in fisher.keys():
            fisher[n] = alpha * fisher[n] + (1 - alpha) * curr_fisher[n]

        evaluate(image_encoder, args)

if __name__ == '__main__':
    args = parse_arguments()

    # args.model = 'ViT-B-16'
    args.lr = 1e-5
    args.batch_size = 128
    args.sequential_finetuning = True
    args.split_strategy = 'class'
    args.save = f'checkpoints/{args.model}/ewc'
    args.eval_datasets = [args.dataset]

    print('=' * 100)
    print(f'Finetuning {args.model} on {args.dataset} ({args.n_splits} splits)')
    print('=' * 100)

    finetune(args)
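
# Usage note: one encoder checkpoint is written per split, as
#   checkpoints/{model}/ewc/{dataset}-{n_splits}/ft-epochs-{epochs}-seed:{seed}-lamb:{ewc_lamb}/finetuned_{split_idx}.pt
# matching the ft_path construction in finetune(). The remaining hyperparameters
# (dataset, n_splits, epochs, seed, ewc_lamb, wd, ls, warmup_length, ...) come
# from src.args.parse_arguments; the exact command-line flags are defined by
# that parser and are not repeated here.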