merge_8datasets.py
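
"""Merge task vectors from eight fine-tuned ViT-B-16 checkpoints and evaluate
several merging strategies (random mix, max-abs, sum, TIES) on all datasets."""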
import torch
import numpy as np
import wandb
import os

from src.args import parse_arguments
from src.merging.task_vectors import TaskVector, merge_max_abs, merge_rnd_mix
from src.merging.ties import merge_methods, state_dict_to_vector, vector_to_state_dict
from src.eval import eval_single_dataset, eval_task_aware

# Config
args = parse_arguments()
args.model = 'ViT-B-16'
pretrained_checkpoint = f'checkpoints/{args.model}/zeroshot.pt'
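

# Sweep the given merging strategies over their scaling coefficients,
# evaluating each merged encoder on every dataset and logging top-1 to wandb.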
def search_evaluate_merging(datasets, task_vectors, args, n_coeffs=20):
    # The commented-out linspace ranges are the full coefficient sweeps;
    # the single values below are fixed operating points.
    funcs_and_coeffs = [
        # (merge_rnd_mix, np.linspace(0.5, 1.5, num=n_coeffs+1)[1:]),
        # (merge_max_abs, np.linspace(0.0, 1.0, num=n_coeffs+1)[1:]),
        # (sum, np.linspace(0.0, 2.0/args.n_splits, num=n_coeffs+1)[1:]),
        (merge_rnd_mix, [1.0]),
        (merge_max_abs, [0.5]),
        (sum, [1.0 / args.n_splits]),
    ]
    for f, coeffs in funcs_and_coeffs:
        print(f"\nMerging with function: {f.__name__}")
        merged_tv = f(task_vectors)

        # Apply the resulting task vector
        results = {}
        for scaling_coef in coeffs:
            print(f"Scaling coeff: {scaling_coef}")
            image_encoder = merged_tv.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef)

            # Evaluate
            _r = {}
            print(datasets)
            for ds in datasets:
                _r[f"merging/{ds}/{f.__name__}"] = eval_single_dataset(image_encoder, ds, args)['top1'] * 100.0
            wandb.log({
                **_r,
                "helpers/merging/alpha": scaling_coef,
            })
            results[scaling_coef] = _r
        print(f"Results with function {f.__name__}:\n{results}")

    # TIES merging
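    # TIES-Merging (Yadav et al., 2023): trim each task vector to its
    # largest-magnitude entries (top `reset_thresh` percent with 'topk'),
    # elect a per-parameter sign by total magnitude ('mass'), and average
    # only the entries that agree with that sign ('dis-mean' = disjoint mean).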
    reset_type = 'topk'
    reset_thresh = 20
    resolve = 'mass'
    merge = 'dis-mean'

    tv_flat_checks = torch.vstack([state_dict_to_vector(tv.vector) for tv in task_vectors])

    print(f"\nMerging with TIES merging: pruning {reset_type}-{reset_thresh}, resolve sign by {resolve}, merge by {merge}")
    merged_flat_tv = merge_methods(
        reset_type,
        tv_flat_checks,
        reset_thresh=reset_thresh,
        resolve_method=resolve,
        merge_func=merge,
    )
    merged_tv = vector_to_state_dict(
        merged_flat_tv, task_vectors[0].vector, remove_keys=[]
    )
    merged_tv = TaskVector(vector=merged_tv)

    # Apply the resulting task vector
    results = {}
    # for scaling_coef in np.linspace(0.55, 1.5, num=n_coeffs+1):
    for scaling_coef in [0.55]:
        print(f"Scaling coeff: {scaling_coef}")
        image_encoder = merged_tv.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef)

        # Evaluate
        _r = {}
        for ds in datasets:
            _r[f"merging/{ds}/TIES"] = eval_single_dataset(image_encoder, ds, args)['top1'] * 100.0
        wandb.log({
            **_r,
            "helpers/merging/alpha": scaling_coef,
        })
        results[scaling_coef] = _r
    print(f"Results with function TIES:\n{results}")


if __name__ == '__main__':
    datasets = ['Cars', 'MNIST', 'EuroSAT', 'SVHN', 'RESISC45', 'SUN397', 'DTD', 'GTSRB']
    epochs = {
        'Cars': 35,
        'DTD': 75,
        'EuroSAT': 12,
        'GTSRB': 11,
        'MNIST': 5,
        'RESISC45': 15,
        'SUN397': 14,
        'SVHN': 4,
        'ImageNet': 4,
    }

    args.eval_datasets = [args.dataset]
    args.lr = 1e-5
    args.batch_size = 128
    args.n_splits = len(datasets)

    method = 'seq-ft' if args.sequential_finetuning else 'ind-ft'
    wandb.init(
        project="magmax",
        group="merging-8datasets",
        entity=args.wandb_entity_name,
        mode='online',
        name=f"merging-8datasets-{method}",
        config=args,
        tags=["merging", "8datasets", method],  # fix: `{method}` was a set literal; tags must be strings
    )

    if args.sequential_finetuning:
        base_path = f"checkpoints/{args.model}/8datasets/{'->'.join(datasets)}"
    else:
        base_path = f'checkpoints/{args.model}/8datasets/ind'
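
    # Build one task vector per dataset: the element-wise difference between
    # fine-tuned and pre-trained weights (Ilharco et al., 2022), so
    # TaskVector(pretrained, finetuned) stores finetuned - pretrained.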
    task_vectors = []
    for ds in datasets:
        ft_path = os.path.join(base_path, f'{ds}/checkpoint-epochs:{epochs[ds]}-seed:{args.seed}.pt')
        tv = TaskVector(pretrained_checkpoint, ft_path)
        task_vectors.append(tv)

    print('=' * 100)
    print("\nEVAL 8 datasets")
    print('=' * 100)

    search_evaluate_merging(datasets, task_vectors, args, n_coeffs=10)
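
# Example invocation (hypothetical flag names; the actual flags are defined
# in src.args.parse_arguments):
#   python merge_8datasets.py --seed 5 --sequential_finetuning \
#       --wandb_entity_name my-entity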