forked from jy0205/LaVIT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
123 lines (89 loc) · 3.54 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import torch
from torch import nn
import torch.distributed as dist
import timm.models.hub as timm_hub
def is_dist_avail_and_initialized():
    """Report whether torch.distributed can be used right now.

    Returns:
        True only when the distributed package is available in this torch build
        AND a default process group has been initialized; False otherwise.
    """
    # Short-circuit: is_initialized() is only consulted when the package is available.
    return bool(dist.is_available() and dist.is_initialized())
def get_world_size():
    """Return the number of processes in the default process group.

    Falls back to 1 when torch.distributed is unavailable or not initialized,
    so single-process runs behave like a world of one.
    """
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
def get_rank():
    """Return this process's rank in the default process group.

    Falls back to 0 when torch.distributed is unavailable or not initialized.
    """
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
def is_main_process():
    """Return True when this process has rank 0 (always True when not distributed)."""
    rank = get_rank()
    return rank == 0
def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.

    Args:
        url: URL of the file to fetch.
        check_hash: Forwarded to ``timm_hub.download_cached_file``.
        progress: Forwarded to ``timm_hub.download_cached_file``.

    Returns:
        Local path ``os.path.join(timm_hub.get_cache_dir(), <basename of the URL path>)``,
        computed on every rank.
    """
    # Use the stdlib parser directly; ``torch.hub.urlparse`` is only an incidental
    # re-export inside torch.hub, not part of its public API.
    from urllib.parse import urlparse

    def get_cached_file_path():
        # A hack to sync the file path across processes: every rank derives the
        # path itself from the URL and timm's cache dir instead of receiving it from
        # rank 0. NOTE(review): this assumes timm stores the download under the same
        # name/dir — confirm against the installed timm version.
        filename = os.path.basename(urlparse(url).path)
        return os.path.join(timm_hub.get_cache_dir(), filename)

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    # Non-main ranks must not touch the file until rank 0 has finished writing it.
    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()
def convert_weights_to_fp16(model: nn.Module):
    """Cast, in place, the weight and bias of every Conv1d/Conv2d/Linear submodule to torch.float16.

    Parameters of any other module type are left untouched.
    """
    target_types = (nn.Conv1d, nn.Conv2d, nn.Linear)

    def _cast_layer(module):
        if not isinstance(module, target_types):
            return
        module.weight.data = module.weight.data.to(torch.float16)
        # Layers built with bias=False have no bias tensor to cast.
        if module.bias is not None:
            module.bias.data = module.bias.data.to(torch.float16)

    # nn.Module.apply visits every submodule recursively, plus the module itself.
    model.apply(_cast_layer)
def convert_weights_to_bf16(model: nn.Module):
    """Convert applicable model parameters to bf16, in place.

    Casts the weight (and the bias, when present) of every nn.Conv1d, nn.Conv2d
    and nn.Linear submodule to torch.bfloat16. Parameters of any other module
    type are left untouched.
    """

    def _convert_weights_to_bf16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(torch.bfloat16)
            # Layers built with bias=False have no bias tensor to cast.
            if l.bias is not None:
                l.bias.data = l.bias.data.to(torch.bfloat16)

    # nn.Module.apply visits every submodule recursively, plus the module itself.
    model.apply(_convert_weights_to_bf16)
def save_result(result, result_dir, filename, remove_duplicate=""):
    """Dump each rank's results to JSON and merge them on the main process.

    Every rank writes ``<result_dir>/<filename>_rank<rank>.json``. After a barrier,
    the main process concatenates the per-rank lists in rank order. When
    ``remove_duplicate`` is set, it keeps only the first entry for each value of
    that key. It then writes ``<result_dir>/<filename>.json``.

    Args:
        result: JSON-serializable results of this rank. Presumably a list of dicts
            (the merge concatenates lists and indexes entries by
            ``remove_duplicate``) — confirm with callers.
        result_dir: Directory for the per-rank and merged files; created if missing.
        filename: Base name (without ``.json``) of the output files.
        remove_duplicate: If non-empty, the key used to de-duplicate merged entries.

    Returns:
        Path of the merged result file. It is returned on every rank, but only the
        main process writes it.
    """
    import json

    print("Dump result")

    # Make the temp dir for saving results. The barrier is deliberately NOT gated on
    # os.path.exists(): the main process may create the directory before another rank
    # checks, and that rank would then skip the barrier and desynchronize every later one.
    if is_main_process():
        os.makedirs(result_dir, exist_ok=True)
    if is_dist_avail_and_initialized():
        torch.distributed.barrier()

    rank_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, get_rank()))
    final_result_file = os.path.join(result_dir, "%s.json" % filename)

    # `with` closes (and flushes) the file before the barrier below releases the reader.
    with open(rank_file, "w") as f:
        json.dump(result, f)

    if is_dist_avail_and_initialized():
        torch.distributed.barrier()

    if is_main_process():
        print("rank %d starts merging results." % get_rank())
        # combine results from all processes
        merged = []
        for rank in range(get_world_size()):
            part_file = os.path.join(result_dir, "%s_rank%d.json" % (filename, rank))
            with open(part_file, "r") as f:
                merged += json.load(f)

        print("Remove duplicate")
        if remove_duplicate:
            deduped = []
            seen_ids = set()
            for entry in merged:
                key = entry[remove_duplicate]
                if key not in seen_ids:
                    seen_ids.add(key)
                    deduped.append(entry)
            merged = deduped

        with open(final_result_file, "w") as f:
            json.dump(merged, f)
        print("result file saved to %s" % final_result_file)

    return final_result_file