-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathhubconf.py
182 lines (151 loc) · 6.9 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
dependencies = ["torch", "timm", "einops"]
import os
from typing import Dict, Any, Optional, Union, List
import warnings
import torch
from torch.hub import load_state_dict_from_url
from timm.models import clean_state_dict
from radio.adaptor_registry import adaptor_registry
from radio.common import DEFAULT_VERSION, RadioResource, RESOURCE_MAP
from radio.enable_spectral_reparam import disable_spectral_reparam, configure_spectral_reparam_from_args
from radio.feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from radio.radio_model import RADIOModel, create_model_from_args
from radio.input_conditioner import get_default_conditioner
from radio.vitdet import apply_vitdet_arch, VitDetArgs
def _extract_adaptor_state(state_dict: Dict[str, Any], tidx: int, adaptor_name: str) -> Dict[str, Any]:
    """Collect the head and feature-projection entries that belong to one adaptor.

    Entries may be stored under the teacher's index or under its name. Keys under
    ``_heads.<id>`` are re-rooted at ``summary`` and keys under
    ``_feature_projections.<id>`` are re-rooted at ``feature``.

    A prefix only matches at a ``.`` boundary (or when it equals the whole key), so
    that index ``1`` does not also capture ``_heads.10.*`` and a name such as
    ``clip`` does not also capture ``_heads.clip2.*``.
    """
    # Order matters: the first matching prefix wins (index before name, heads before features).
    remap = (
        (f'_heads.{tidx}', 'summary'),
        (f'_heads.{adaptor_name}', 'summary'),
        (f'_feature_projections.{tidx}', 'feature'),
        (f'_feature_projections.{adaptor_name}', 'feature'),
    )

    adaptor_state = dict()
    for k, v in state_dict.items():
        for prefix, new_root in remap:
            if k == prefix or k.startswith(prefix + '.'):
                adaptor_state[new_root + k[len(prefix):]] = v
                break
    return adaptor_state


def radio_model(
    version: str = "",
    progress: bool = True,
    adaptor_names: Optional[Union[str, List[str]]] = None,
    vitdet_window_size: Optional[int] = None,
    return_checkpoint: bool = False,
    **kwargs,
) -> RADIOModel:
    """Build a RADIO model from a local checkpoint file or a named release.

    Args:
        version: Path to a checkpoint file on disk, or a key of ``RESOURCE_MAP``.
            An empty value selects ``DEFAULT_VERSION``.
        progress: Forwarded to ``load_state_dict_from_url`` when the checkpoint is
            downloaded.
        adaptor_names: Name, or list of names, of entries in the checkpoint's
            ``args.teachers`` for which an adaptor is created. ``None`` creates no
            adaptors.
        vitdet_window_size: When not ``None``, it is passed to ``RADIOModel`` as
            ``window_size`` and ``apply_vitdet_arch`` is applied to the backbone.
        return_checkpoint: When true, return ``(model, checkpoint)`` instead of
            only the model.
        **kwargs: Accepted but not used by this function.

    Returns:
        The constructed ``RADIOModel``, or a ``(RADIOModel, checkpoint)`` tuple when
        ``return_checkpoint`` is true.

    Raises:
        ValueError: A requested adaptor name matches no teacher in the checkpoint.
        KeyError: Presumably raised by the ``RESOURCE_MAP`` lookup when ``version``
            is neither an existing file nor a known release name (assumes
            ``RESOURCE_MAP`` is a plain mapping).
    """
    if not version:
        version = DEFAULT_VERSION

    if os.path.isfile(version):
        # NOTE: torch.load unpickles the file, and the checkpoint carries a non-tensor
        # ``args`` object in addition to weights. Only load checkpoints from trusted sources.
        chk = torch.load(version, map_location="cpu")
        # A local checkpoint has no release entry, so the resolution / patch-size hints are unset.
        resource = RadioResource(version, patch_size=None, max_resolution=None, preferred_resolution=None)
    else:
        resource = RESOURCE_MAP[version]
        chk = load_state_dict_from_url(
            resource.url, progress=progress, map_location="cpu"
        )

    # Prefer the EMA weights when the checkpoint provides them; in that case the
    # spectral reparametrization flag is cleared before the model is created.
    if "state_dict_ema" in chk:
        state_dict = chk["state_dict_ema"]
        chk['args'].spectral_reparam = False
    else:
        state_dict = chk["state_dict"]

    args = chk["args"]
    mod = create_model_from_args(args)

    mod_state_dict = get_prefix_state_dict(state_dict, "base_model.")

    if args.spectral_reparam:
        configure_spectral_reparam_from_args(mod, args, state_dict_guidance=mod_state_dict)

    # NOTE(review): the backbone weights above were extracted before this cleaning step;
    # confirm that ordering is intended.
    state_dict = clean_state_dict(state_dict)

    # Non-strict load: mismatched keys are reported as warnings instead of raising.
    key_warn = mod.load_state_dict(mod_state_dict, strict=False)
    if key_warn.missing_keys:
        warnings.warn(f'Missing keys in state dict: {key_warn.missing_keys}')
    if key_warn.unexpected_keys:
        warnings.warn(f'Unexpected keys in state dict: {key_warn.unexpected_keys}')

    if args.spectral_reparam:
        # Spectral reparametrization uses PyTorch's "parametrizations" API. The idea behind
        # the method is that instead of there being a `weight` tensor for certain Linear layers
        # in the model, we make it a dynamically computed function. During training, this
        # helps stabilize the model. However, for downstream use cases, it shouldn't be necessary.
        # Disabling it in this context means that instead of recomputing `w' = f(w)` on every use,
        # we compute `w' = f(w)` once, during this function call, and replace the parametrization
        # with the realized weights. This makes the model run faster, and also use less memory.
        disable_spectral_reparam(mod)
        args.spectral_reparam = False

    conditioner = get_default_conditioner()
    conditioner.load_state_dict(get_prefix_state_dict(state_dict, "input_conditioner."))

    # Fall back to float32 when the checkpoint args carry no ``dtype`` attribute.
    dtype = getattr(args, 'dtype', torch.float32)
    mod.to(dtype=dtype)
    conditioner.dtype = dtype

    teachers = args.teachers

    # For every distinct teacher name that uses a summary (the default when the flag is
    # absent), remember the index of its first occurrence.
    name_to_idx_map = dict()
    for i, t in enumerate(teachers):
        if t.get('use_summary', True):
            name_to_idx_map.setdefault(t['name'], i)
    summary_idxs = torch.tensor(sorted(name_to_idx_map.values()), dtype=torch.int64)

    if adaptor_names is None:
        adaptor_names = []
    elif isinstance(adaptor_names, str):
        adaptor_names = [adaptor_names]

    adaptors = dict()
    for adaptor_name in adaptor_names:
        # Locate the first teacher with this name; the for/else raises when none matches.
        for tidx, tconf in enumerate(teachers):
            if tconf["name"] == adaptor_name:
                break
        else:
            raise ValueError(f'Unable to find the specified adaptor name. Known names: {list(t["name"] for t in teachers)}')

        ttype = tconf["type"]
        adaptor_state = _extract_adaptor_state(state_dict, tidx, adaptor_name)

        adaptor = adaptor_registry.create_adaptor(ttype, args, tconf, adaptor_state)
        adaptor.head_idx = tidx
        adaptors[adaptor_name] = adaptor

    # The normalizers are optional: each is only built when the checkpoint has entries for it.
    feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
    feature_normalizer = None
    if feat_norm_sd:
        feature_normalizer = FeatureNormalizer(feat_norm_sd['mean'].shape[0], dtype=dtype)
        feature_normalizer.load_state_dict(feat_norm_sd)

    inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
    inter_feature_normalizer = None
    if inter_feat_norm_sd:
        inter_feature_normalizer = IntermediateFeatureNormalizer(
            *inter_feat_norm_sd['means'].shape[:2],
            rot_per_layer=inter_feat_norm_sd['rotation'].ndim == 3,
            dtype=dtype
        )
        inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

    radio = RADIOModel(
        mod,
        conditioner,
        summary_idxs=summary_idxs,
        patch_size=resource.patch_size,
        max_resolution=resource.max_resolution,
        window_size=vitdet_window_size,
        preferred_resolution=resource.preferred_resolution,
        adaptors=adaptors,
        feature_normalizer=feature_normalizer,
        inter_feature_normalizer=inter_feature_normalizer,
    )

    if vitdet_window_size is not None:
        # Applied after RADIOModel is built because it needs ``radio.num_summary_tokens``.
        apply_vitdet_arch(
            mod,
            VitDetArgs(
                vitdet_window_size,
                radio.num_summary_tokens,
                num_windowed=resource.vitdet_num_windowed,
                num_global=resource.vitdet_num_global,
            ),
        )

    if return_checkpoint:
        return radio, chk
    return radio
def get_prefix_state_dict(state_dict: Dict[str, Any], prefix: str):
    """Select the entries whose key begins with ``prefix`` and strip that prefix.

    Args:
        state_dict: Mapping from key strings to arbitrary values.
        prefix: Leading string a key must have to be selected.

    Returns:
        A new dict holding the selected values, keyed by the remainder of each
        original key after ``prefix``. Entries keep the iteration order of
        ``state_dict``; the input mapping is not modified.
    """
    cut = len(prefix)
    selected = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue
        selected[key[cut:]] = value
    return selected