Skip to content

Commit

Permalink
[init]
Browse files Browse the repository at this point in the history
  • Loading branch information
keyu-tian committed Jun 23, 2024
1 parent af0a2a9 commit c08db77
Show file tree
Hide file tree
Showing 24 changed files with 4,402 additions and 0 deletions.
20 changes: 20 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
*.swp
**/__pycache__/**
**/.ipynb_checkpoints/**
.DS_Store
.idea/*
.vscode/*
llava/
_vis_cached/
_auto_*
ckpt/
log/
tb*/
img*/
local_output*
*.pth
*.pth.tar
*.ckpt
*.log
*.txt
*.ipynb
302 changes: 302 additions & 0 deletions dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
import datetime
import functools
import os
import sys
from typing import List
from typing import Union

import pytz
import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__rank_str_zfill = '0'
__initialized = False


def initialized():
return __initialized


def __initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout_minutes=30):
global __device
if not torch.cuda.is_available():
print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
return
elif 'RANK' not in os.environ:
torch.cuda.set_device(gpu_id_if_not_distibuted)
__device = torch.empty(1).cuda().device
print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
return
# then 'RANK' must exist
global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
local_rank = global_rank % num_gpus
torch.cuda.set_device(local_rank)

# ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
if mp.get_start_method(allow_none=True) is None:
method = 'fork' if fork else 'spawn'
print(f'[dist initialize] mp method={method}')
mp.set_start_method(method)
tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60))

global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill
__local_rank = local_rank
__rank, __world_size = tdist.get_rank(), tdist.get_world_size()
__rank_str_zfill = str(__rank).zfill(len(str(__world_size)))
__device = torch.empty(1).cuda().device
__initialized = True

assert tdist.is_initialized(), 'torch.distributed is not initialized!'
print(f'[lrk={get_local_rank()}, rk={get_rank()}]')


def get_rank():
return __rank


def get_rank_str_zfill():
return __rank_str_zfill


def get_local_rank():
return __local_rank


def get_world_size():
return __world_size


def get_device():
return __device


def set_gpu_id(gpu_id: int):
if gpu_id is None: return
global __device
if isinstance(gpu_id, (str, int)):
torch.cuda.set_device(int(gpu_id))
__device = torch.empty(1).cuda().device
else:
raise NotImplementedError


def is_master():
return __rank == 0


def is_local_master():
return __local_rank == 0


def new_group(ranks: List[int]):
if __initialized:
return tdist.new_group(ranks=ranks)
return None


def new_local_machine_group():
if __initialized:
cur_subgroup, subgroups = tdist.new_subgroups()
return cur_subgroup
return None


def barrier():
if __initialized:
tdist.barrier()


def allreduce(t: torch.Tensor, async_op=False):
if __initialized:
if not t.is_cuda:
cu = t.detach().cuda()
ret = tdist.all_reduce(cu, async_op=async_op)
t.copy_(cu.cpu())
else:
ret = tdist.all_reduce(t, async_op=async_op)
return ret
return None


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
if __initialized:
if not t.is_cuda:
t = t.cuda()
ls = [torch.empty_like(t) for _ in range(__world_size)]
tdist.all_gather(ls, t)
else:
ls = [t]
if cat:
ls = torch.cat(ls, dim=0)
return ls


def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
if __initialized:
if not t.is_cuda:
t = t.cuda()

t_size = torch.tensor(t.size(), device=t.device)
ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
tdist.all_gather(ls_size, t_size)

max_B = max(size[0].item() for size in ls_size)
pad = max_B - t_size[0].item()
if pad:
pad_size = (pad, *t.size()[1:])
t = torch.cat((t, t.new_empty(pad_size)), dim=0)

ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
tdist.all_gather(ls_padded, t)
ls = []
for t, size in zip(ls_padded, ls_size):
ls.append(t[:size[0].item()])
else:
ls = [t]
if cat:
ls = torch.cat(ls, dim=0)
return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
if __initialized:
if not t.is_cuda:
cu = t.detach().cuda()
tdist.broadcast(cu, src=src_rank)
t.copy_(cu.cpu())
else:
tdist.broadcast(t, src=src_rank)


def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
if not initialized():
return torch.tensor([val]) if fmt is None else [fmt % val]

ts = torch.zeros(__world_size)
ts[__rank] = val
allreduce(ts)
if fmt is None:
return ts
return [fmt % v for v in ts.cpu().numpy().tolist()]


def master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
force = kwargs.pop('force', False)
if force or is_master():
ret = func(*args, **kwargs)
else:
ret = None
barrier()
return ret
return wrapper


def local_master_only(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
force = kwargs.pop('force', False)
if force or is_local_master():
ret = func(*args, **kwargs)
else:
ret = None
barrier()
return ret
return wrapper


def for_visualize(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_master():
# with torch.no_grad():
ret = func(*args, **kwargs)
else:
ret = None
return ret
return wrapper


def finalize():
if __initialized:
tdist.destroy_process_group()


def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30):
try:
__initialize(fork=False, timeout_minutes=timeout_minutes)
barrier()
except RuntimeError as e:
print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True)
raise e

if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
_change_builtin_print(is_local_master())
if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path):
sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False)


def _change_builtin_print(is_master):
import builtins as __builtin__

builtin_print = __builtin__.print
if type(builtin_print) != type(open):
return

def prt(*args, **kwargs):
force = kwargs.pop('force', False)
clean = kwargs.pop('clean', False)
deeper = kwargs.pop('deeper', False)
if is_master or force:
if not clean:
f_back = sys._getframe().f_back
if deeper and f_back.f_back is not None:
f_back = f_back.f_back
file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
else:
builtin_print(*args, **kwargs)

__builtin__.print = prt


class BackupStreamToFile(object):
def __init__(self, local_output_dir, for_stdout=True):
self.for_stdout = for_stdout
self.terminal_stream = sys.stdout if for_stdout else sys.stderr
fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt')
existing = os.path.exists(fname)
self.file_stream = open(fname, 'a')
if existing:
time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]')
self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n')
self.file_stream.flush()
self.enabled = True

def write(self, message):
self.terminal_stream.write(message)
self.file_stream.write(message)

def flush(self):
self.terminal_stream.flush()
self.file_stream.flush()

def close(self):
if not self.enabled:
return
self.enabled = False
self.file_stream.flush()
self.file_stream.close()
if self.for_stdout:
sys.stdout = self.terminal_stream
sys.stdout.flush()
else:
sys.stderr = self.terminal_stream
sys.stderr.flush()

def __del__(self):
self.close()
70 changes: 70 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Tuple

import torch.nn as nn

from utils.arg_util import Args
from .quant import VectorQuantizer
from .vqvae import VQVAE
from .dino import DinoDisc
from .basic_vae import CNNEncoder


def build_vae_disc(args: Args) -> Tuple[VQVAE, DinoDisc]:
# disable built-in initialization for speed
for clz in (
nn.Linear, nn.Embedding,
nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
):
setattr(clz, 'reset_parameters', lambda self: None)

# build models
vae = VQVAE(
grad_ckpt=args.vae_grad_ckpt,
vitamin=args.vae, drop_path_rate=args.drop_path,
ch=args.ch, ch_mult=(1, 1, 2, 2, 4), dropout=args.drop_out,
vocab_size=args.vocab_size, vocab_width=args.vocab_width, vocab_norm=args.vocab_norm, beta=args.vq_beta, quant_conv_k=3, quant_resi=-0.5,
).to(args.device)
disc = DinoDisc(
device=args.device, dino_ckpt_path=args.dino_path, depth=args.dino_depth, key_depths=(2, 5, 8, 11),
ks=args.dino_kernel_size, norm_type=args.disc_norm, using_spec_norm=args.disc_spec_norm, norm_eps=1e-6,
).to(args.device)

# init weights
need_init = [
vae.quant_conv,
vae.quantize,
vae.post_quant_conv,
vae.decoder,
]
if isinstance(vae.encoder, CNNEncoder):
need_init.insert(0, vae.encoder)
for vv in need_init:
init_weights(vv, args.vae_init)
init_weights(disc, args.disc_init)
vae.quantize.init_vocab(args.vocab_init)

return vae, disc


def init_weights(model, conv_std_or_gain):
print(f'[init_weights] {type(model).__name__} with {"std" if conv_std_or_gain > 0 else "gain"}={abs(conv_std_or_gain):g}')
for m in model.modules():
if isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight.data, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias.data, 0.)
elif isinstance(m, nn.Embedding):
nn.init.trunc_normal_(m.weight.data, std=0.02)
if m.padding_idx is not None:
m.weight.data[m.padding_idx].zero_()
elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
if conv_std_or_gain > 0:
nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
else:
nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
if m.bias is not None: nn.init.constant_(m.bias.data, 0.)
if m.weight is not None: nn.init.constant_(m.weight.data, 1.)
Loading

0 comments on commit c08db77

Please sign in to comment.