Showing 24 changed files with 4,402 additions and 0 deletions.
@@ -0,0 +1,20 @@
*.swp
**/__pycache__/**
**/.ipynb_checkpoints/**
.DS_Store
.idea/*
.vscode/*
llava/
_vis_cached/
_auto_*
ckpt/
log/
tb*/
img*/
local_output*
*.pth
*.pth.tar
*.ckpt
*.log
*.txt
*.ipynb
@@ -0,0 +1,302 @@
import datetime
import functools
import os
import sys
from typing import List
from typing import Union

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__rank_str_zfill = '0'
__initialized = False


def initialized():
    return __initialized


def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))

    global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')


def get_rank():
    return __rank


def get_rank_str_zfill():
    return __rank_str_zfill


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def set_gpu_id(gpu_id: int):
    if gpu_id is None: return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def new_group(ranks: List[int]):
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None


def new_local_machine_group():
    if __initialized:
        cur_subgroup, subgroups = tdist.new_subgroups()
        return cur_subgroup
    return None


def barrier():
    if __initialized:
        tdist.barrier()


def allreduce(t: torch.Tensor, async_op=False):
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)


def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]


def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def for_visualize(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_master():
            # with torch.no_grad():
            ret = func(*args, **kwargs)
        else:
            ret = None
        return ret
    return wrapper


def finalize():
    if __initialized:
        tdist.destroy_process_group()


def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
    try:
        __initialize(fork=False, timeout_minutes=timeout_minutes)
        barrier()
    except RuntimeError as e:
        print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
        raise e

    if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
    _change_builtin_print(is_local_master())
    if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
        sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)


def _change_builtin_print(is_master):
    import builtins as __builtin__

    builtin_print = __builtin__.print
    if type(builtin_print) != type(open):
        return

    def prt(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
                time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
                builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
            else:
                builtin_print(*args, **kwargs)

    __builtin__.print = prt


class BackupStreamToFile(object):
    def __init__(self, local_output_dir, for_stdout=True):
        self.for_stdout = for_stdout
        self.terminal_stream = sys.stdout if for_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
        existing = os.path.exists(fname)
        self.file_stream = open(fname, 'a')
        if existing:
            time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
            self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
        self.file_stream.flush()
        self.enabled = True

    def write(self, message):
        self.terminal_stream.write(message)
        self.file_stream.write(message)

    def flush(self):
        self.terminal_stream.flush()
        self.file_stream.flush()

    def close(self):
        if not self.enabled:
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.for_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()

    def __del__(self):
        self.close()
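
A minimal usage sketch for the helpers in the file above. This is an assumption-laden illustration, not code from the commit: the import name `dist`, the `torchrun` launch line, and the output directory are placeholders, since the scraped diff does not show file names or an entry script.

# Hypothetical usage of the distributed helpers above; import path and paths below are assumed.
# Launch example (assumption): torchrun --nproc_per_node=8 train.py
import torch

import dist  # assumed import name for the module shown above

dist.init_distributed_mode(local_out_path='local_output/demo', timeout_minutes=30)

# Each rank contributes a (2, 4) block; allgather concatenates them along dim 0.
x = torch.full((2, 4), float(dist.get_rank()), device=dist.get_device())
gathered = dist.allgather(x, cat=True)  # shape: (2 * world_size, 4)

# dist_fmt_vals places one scalar per rank into a vector, all-reduces it, and formats every entry.
per_rank_loss = dist.dist_fmt_vals(0.1 * (dist.get_rank() + 1), fmt='%.3f')
print(f'rank {dist.get_rank()}: gathered {tuple(gathered.shape)}, losses {per_rank_loss}')

@dist.master_only
def save_checkpoint(state, path):
    # Runs on rank 0 only; the decorator's barrier keeps the other ranks in step.
    torch.save(state, path)

save_checkpoint({'step': 0}, 'local_output/demo/ckpt.pth')
dist.finalize()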
@@ -0,0 +1,70 @@
from typing import Tuple

import torch.nn as nn

from utils.arg_util import Args
from .quant import VectorQuantizer
from .vqvae import VQVAE
from .dino import DinoDisc
from .basic_vae import CNNEncoder


def build_vae_disc(args: Args) -> Tuple[VQVAE, DinoDisc]:
    # disable built-in initialization for speed
    for clz in (
        nn.Linear, nn.Embedding,
        nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
        nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
    ):
        setattr(clz, 'reset_parameters', lambda self: None)

    # build models
    vae = VQVAE(
        grad_ckpt=args.vae_grad_ckpt,
        vitamin=args.vae, drop_path_rate=args.drop_path,
        ch=args.ch, ch_mult=(1, 1, 2, 2, 4), dropout=args.drop_out,
        vocab_size=args.vocab_size, vocab_width=args.vocab_width, vocab_norm=args.vocab_norm, beta=args.vq_beta, quant_conv_k=3, quant_resi=-0.5,
    ).to(args.device)
    disc = DinoDisc(
        device=args.device, dino_ckpt_path=args.dino_path, depth=args.dino_depth, key_depths=(2, 5, 8, 11),
        ks=args.dino_kernel_size, norm_type=args.disc_norm, using_spec_norm=args.disc_spec_norm, norm_eps=1e-6,
    ).to(args.device)

    # init weights
    need_init = [
        vae.quant_conv,
        vae.quantize,
        vae.post_quant_conv,
        vae.decoder,
    ]
    if isinstance(vae.encoder, CNNEncoder):
        need_init.insert(0, vae.encoder)
    for vv in need_init:
        init_weights(vv, args.vae_init)
    init_weights(disc, args.disc_init)
    vae.quantize.init_vocab(args.vocab_init)

    return vae, disc


def init_weights(model, conv_std_or_gain):
    print(f'[init_weights] {type(model).__name__} with {"std" if conv_std_or_gain > 0 else "gain"}={abs(conv_std_or_gain):g}')
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight.data, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0.)
        elif isinstance(m, nn.Embedding):
            nn.init.trunc_normal_(m.weight.data, std=0.02)
            if m.padding_idx is not None:
                m.weight.data[m.padding_idx].zero_()
        elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
            if conv_std_or_gain > 0:
                nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
            else:
                nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
            if m.bias is not None: nn.init.constant_(m.bias.data, 0.)
            if m.weight is not None: nn.init.constant_(m.weight.data, 1.)
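
build_vae_disc is tied to the project's Args, VQVAE, and DinoDisc, so it cannot be demonstrated from this diff alone. The short sketch below only illustrates init_weights' sign convention (a positive value is a truncated-normal std for conv layers, a negative value is the magnitude of a Xavier gain); the toy module and values are illustrative assumptions, not taken from the repository.

# Illustrative only: exercising init_weights' std-vs-gain convention on a toy module.
# Assumes init_weights is importable from the module above; the diff does not show its path.
import torch.nn as nn

# Never run forward; built only to show which layer types the initializer touches.
toy = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # conv: governed by conv_std_or_gain
    nn.GroupNorm(4, 16),                         # norm: weight -> 1, bias -> 0
    nn.Linear(16, 8),                            # linear: always trunc_normal(std=0.02)
)

init_weights(toy, conv_std_or_gain=0.02)   # > 0: conv weights ~ trunc_normal(std=0.02)
init_weights(toy, conv_std_or_gain=-0.5)   # < 0: conv weights ~ xavier_normal(gain=0.5)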