diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..da02bc4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,20 @@ +*.swp +**/__pycache__/** +**/.ipynb_checkpoints/** +.DS_Store +.idea/* +.vscode/* +llava/ +_vis_cached/ +_auto_* +ckpt/ +log/ +tb*/ +img*/ +local_output* +*.pth +*.pth.tar +*.ckpt +*.log +*.txt +*.ipynb diff --git a/dist.py b/dist.py new file mode 100644 index 0000000..7e7b7b2 --- /dev/null +++ b/dist.py @@ -0,0 +1,302 @@ +import datetime +import functools +import os +import sys +from typing import List +from typing import Union + +import pytz +import torch +import torch.distributed as tdist +import torch.multiprocessing as mp + +__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu' +__rank_str_zfill = '0' +__initialized = False + + +def initialized(): + return __initialized + + +def __initialize(fork=False, backend='nccl', gpu_id_if_not_distributed=0, timeout_minutes=30): + global __device + if not torch.cuda.is_available(): + print('[dist initialize] cuda is not available, use cpu instead', file=sys.stderr) + return + elif 'RANK' not in os.environ: + torch.cuda.set_device(gpu_id_if_not_distributed) + __device = torch.empty(1).cuda().device + print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr) + return + # then 'RANK' must exist + global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count() + local_rank = global_rank % num_gpus + torch.cuda.set_device(local_rank) + + # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29 + if mp.get_start_method(allow_none=True) is None: + method = 'fork' if fork else 'spawn' + print(f'[dist initialize] mp method={method}') + mp.set_start_method(method) + tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout_minutes * 60)) + + global __rank, __local_rank, __world_size, __initialized, __rank_str_zfill + __local_rank = local_rank + __rank, __world_size = tdist.get_rank(), tdist.get_world_size() + __rank_str_zfill = str(__rank).zfill(len(str(__world_size))) + __device = torch.empty(1).cuda().device + __initialized = True + + assert tdist.is_initialized(), 'torch.distributed is not initialized!'
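+ # log both ranks once init succeeds: lrk is the node-local rank (this process's GPU index on its machine), rk is the global rank across all nodes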
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]') + + +def get_rank(): + return __rank + + +def get_rank_str_zfill(): + return __rank_str_zfill + + +def get_local_rank(): + return __local_rank + + +def get_world_size(): + return __world_size + + +def get_device(): + return __device + + +def set_gpu_id(gpu_id: int): + if gpu_id is None: return + global __device + if isinstance(gpu_id, (str, int)): + torch.cuda.set_device(int(gpu_id)) + __device = torch.empty(1).cuda().device + else: + raise NotImplementedError + + +def is_master(): + return __rank == 0 + + +def is_local_master(): + return __local_rank == 0 + + +def new_group(ranks: List[int]): + if __initialized: + return tdist.new_group(ranks=ranks) + return None + + +def new_local_machine_group(): + if __initialized: + cur_subgroup, subgroups = tdist.new_subgroups() + return cur_subgroup + return None + + +def barrier(): + if __initialized: + tdist.barrier() + + +def allreduce(t: torch.Tensor, async_op=False): + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + ret = tdist.all_reduce(cu, async_op=async_op) + t.copy_(cu.cpu()) + else: + ret = tdist.all_reduce(t, async_op=async_op) + return ret + return None + + +def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + ls = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls, t) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]: + if __initialized: + if not t.is_cuda: + t = t.cuda() + + t_size = torch.tensor(t.size(), device=t.device) + ls_size = [torch.empty_like(t_size) for _ in range(__world_size)] + tdist.all_gather(ls_size, t_size) + + max_B = max(size[0].item() for size in ls_size) + pad = max_B - t_size[0].item() + if pad: + pad_size = (pad, *t.size()[1:]) + t = torch.cat((t, t.new_empty(pad_size)), dim=0) + + ls_padded = [torch.empty_like(t) for _ in range(__world_size)] + tdist.all_gather(ls_padded, t) + ls = [] + for t, size in zip(ls_padded, ls_size): + ls.append(t[:size[0].item()]) + else: + ls = [t] + if cat: + ls = torch.cat(ls, dim=0) + return ls + + +def broadcast(t: torch.Tensor, src_rank) -> None: + if __initialized: + if not t.is_cuda: + cu = t.detach().cuda() + tdist.broadcast(cu, src=src_rank) + t.copy_(cu.cpu()) + else: + tdist.broadcast(t, src=src_rank) + + +def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]: + if not initialized(): + return torch.tensor([val]) if fmt is None else [fmt % val] + + ts = torch.zeros(__world_size) + ts[__rank] = val + allreduce(ts) + if fmt is None: + return ts + return [fmt % v for v in ts.cpu().numpy().tolist()] + + +def master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop('force', False) + if force or is_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + return wrapper + + +def local_master_only(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + force = kwargs.pop('force', False) + if force or is_local_master(): + ret = func(*args, **kwargs) + else: + ret = None + barrier() + return ret + return wrapper + + +def for_visualize(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if is_master(): + # with torch.no_grad(): + ret = func(*args, **kwargs) + else: + ret = None + return ret + return wrapper + + +def finalize(): + if __initialized: + 
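+ # destroying the default process group releases communicator resources (e.g. NCCL handles) for a clean exit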
tdist.destroy_process_group() + + +def init_distributed_mode(local_out_path, only_sync_master=False, timeout_minutes=30): + try: + __initialize(fork=False, timeout_minutes=timeout_minutes) + barrier() + except RuntimeError as e: + print(f'{"!"*80} dist init error (NCCL Error?), stopping training! {"!"*80}', flush=True) + raise e + + if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True) + _change_builtin_print(is_local_master()) + if (is_master() if only_sync_master else is_local_master()) and local_out_path is not None and len(local_out_path): + sys.stdout, sys.stderr = BackupStreamToFile(local_out_path, for_stdout=True), BackupStreamToFile(local_out_path, for_stdout=False) + + +def _change_builtin_print(is_master): + import builtins as __builtin__ + + builtin_print = __builtin__.print + if type(builtin_print) != type(open): + return + + def prt(*args, **kwargs): + force = kwargs.pop('force', False) + clean = kwargs.pop('clean', False) + deeper = kwargs.pop('deeper', False) + if is_master or force: + if not clean: + f_back = sys._getframe().f_back + if deeper and f_back.f_back is not None: + f_back = f_back.f_back + file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] + time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]') + builtin_print(f'{time_str} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs) + else: + builtin_print(*args, **kwargs) + + __builtin__.print = prt + + +class BackupStreamToFile(object): + def __init__(self, local_output_dir, for_stdout=True): + self.for_stdout = for_stdout + self.terminal_stream = sys.stdout if for_stdout else sys.stderr + fname = os.path.join(local_output_dir, 'backup1_stdout.txt' if for_stdout else 'backup2_stderr.txt') + existing = os.path.exists(fname) + self.file_stream = open(fname, 'a') + if existing: + time_str = datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime('[%m-%d %H:%M:%S]') + self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str} ' + '='*55 + '\n') + self.file_stream.flush() + self.enabled = True + + def write(self, message): + self.terminal_stream.write(message) + self.file_stream.write(message) + + def flush(self): + self.terminal_stream.flush() + self.file_stream.flush() + + def close(self): + if not self.enabled: + return + self.enabled = False + self.file_stream.flush() + self.file_stream.close() + if self.for_stdout: + sys.stdout = self.terminal_stream + sys.stdout.flush() + else: + sys.stderr = self.terminal_stream + sys.stderr.flush() + + def __del__(self): + self.close() diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..52fe562 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,70 @@ +from typing import Tuple + +import torch.nn as nn + +from utils.arg_util import Args +from .quant import VectorQuantizer +from .vqvae import VQVAE +from .dino import DinoDisc +from .basic_vae import CNNEncoder + + +def build_vae_disc(args: Args) -> Tuple[VQVAE, DinoDisc]: + # disable built-in initialization for speed + for clz in ( + nn.Linear, nn.Embedding, + nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d, + nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + ): + setattr(clz, 'reset_parameters', lambda self: None) + + # build models + vae = VQVAE( + grad_ckpt=args.vae_grad_ckpt, + vitamin=args.vae, drop_path_rate=args.drop_path, + ch=args.ch, ch_mult=(1, 1, 2, 2, 4), 
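+ # note: len(ch_mult) = 5 stages gives a 2 ** (5 - 1) = 16x spatial downsample (see VQVAE.downsample_ratio)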
dropout=args.drop_out, + vocab_size=args.vocab_size, vocab_width=args.vocab_width, vocab_norm=args.vocab_norm, beta=args.vq_beta, quant_conv_k=3, quant_resi=-0.5, + ).to(args.device) + disc = DinoDisc( + device=args.device, dino_ckpt_path=args.dino_path, depth=args.dino_depth, key_depths=(2, 5, 8, 11), + ks=args.dino_kernel_size, norm_type=args.disc_norm, using_spec_norm=args.disc_spec_norm, norm_eps=1e-6, + ).to(args.device) + + # init weights + need_init = [ + vae.quant_conv, + vae.quantize, + vae.post_quant_conv, + vae.decoder, + ] + if isinstance(vae.encoder, CNNEncoder): + need_init.insert(0, vae.encoder) + for vv in need_init: + init_weights(vv, args.vae_init) + init_weights(disc, args.disc_init) + vae.quantize.init_vocab(args.vocab_init) + + return vae, disc + + +def init_weights(model, conv_std_or_gain): + print(f'[init_weights] {type(model).__name__} with {"std" if conv_std_or_gain > 0 else "gain"}={abs(conv_std_or_gain):g}') + for m in model.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight.data, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0.) + elif isinstance(m, nn.Embedding): + nn.init.trunc_normal_(m.weight.data, std=0.02) + if m.padding_idx is not None: + m.weight.data[m.padding_idx].zero_() + elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)): + if conv_std_or_gain > 0: + nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain) + else: + nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)): + if m.bias is not None: nn.init.constant_(m.bias.data, 0.) + if m.weight is not None: nn.init.constant_(m.weight.data, 1.) 
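The sign of `conv_std_or_gain` in `init_weights` above selects the conv initialization: a positive value is used as the std of a truncated normal, while a negative value is used as a Xavier gain of its absolute value (so `args.vae_init = -0.5` means Xavier-normal with gain 0.5). A minimal sketch of both call styles, assuming a `vae` freshly built by `build_vae_disc`:

    init_weights(vae.decoder, conv_std_or_gain=0.02)   # conv weights: trunc_normal_(std=0.02)
    init_weights(vae.decoder, conv_std_or_gain=-0.5)   # conv weights: xavier_normal_(gain=0.5)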
diff --git a/models/basic_vae.py b/models/basic_vae.py new file mode 100644 index 0000000..3d2821c --- /dev/null +++ b/models/basic_vae.py @@ -0,0 +1,254 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + + +# this file only provides the 2 modules used in VQVAE +__all__ = ['CNNEncoder', 'CNNDecoder', ] + + +""" +References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py +""" +# swish +def nonlinearity(x): + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample2x(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + return self.conv(F.interpolate(x, scale_factor=2, mode='nearest')) + + +class Downsample2x(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0)) + + +class BnActConvBnActConv(nn.Module): + def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout, inplace=True) if dropout > 1e-6 else nn.Identity() + self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x): + h = self.conv1(F.silu(self.norm1(x), inplace=True)) + h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True))) + return self.nin_shortcut(x) + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.C = in_channels + + self.norm = Normalize(in_channels) + self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0) + self.w_ratio = int(in_channels) ** (-0.5) + self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + qkv = self.qkv(self.norm(x)) + B, _, H, W = qkv.shape # should be B,3C,H,W + C = self.C + q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1) + + # compute attention + q = q.view(B, C, H * W).contiguous() + q = q.permute(0, 2, 1).contiguous() # B,HW,C + k = k.view(B, C, H * W).contiguous() # B,C,HW + w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j] + w = F.softmax(w, dim=2) + + # attend to values + v = v.view(B, C, H * W).contiguous() + w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q) + h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j] + h = h.view(B, C, H, W).contiguous() + + return x + self.proj_out(h) + + +def make_attn(in_channels, using_sa=True): + return AttnBlock(in_channels) if using_sa 
else nn.Identity() + + +class CNNEncoder(nn.Module): + def __init__( + self, *, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, dropout=0.0, + img_channels=3, output_channels=32, using_sa=True, using_mid_sa=True, + grad_ckpt=False, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.downsample_ratio = 2 ** (self.num_resolutions - 1) + self.num_res_blocks = num_res_blocks + self.grad_ckpt = grad_ckpt + + # downsampling + self.conv_in = torch.nn.Conv2d(img_channels, self.ch, kernel_size=3, stride=1, padding=1) + + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(BnActConvBnActConv(in_channels=block_in, out_channels=block_out, dropout=dropout)) + block_in = block_out + if i_level == self.num_resolutions - 1 and using_sa: + attn.append(make_attn(block_in, using_sa=True)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample2x(block_in) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = BnActConvBnActConv(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa) + self.mid.block_2 = BnActConvBnActConv(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, output_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = self.conv_in(x) + if not self.grad_ckpt or not self.training: + # downsampling + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](h) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + if i_level != self.num_resolutions - 1: + h = self.down[i_level].downsample(h) + # middle + h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h))) + # end + h = self.conv_out(F.silu(self.norm_out(h), inplace=True)) + else: + # downsampling + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = checkpoint(self.down[i_level].block[i_block], h, use_reentrant=False) + if len(self.down[i_level].attn) > 0: + h = checkpoint(self.down[i_level].attn[i_block], h, use_reentrant=False) + if i_level != self.num_resolutions - 1: + h = checkpoint(self.down[i_level].downsample, h, use_reentrant=False) + # middle + h = checkpoint(self.mid.block_1, h, use_reentrant=False) + h = checkpoint(self.mid.attn_1, h, use_reentrant=False) + h = checkpoint(self.mid.block_2, h, use_reentrant=False) + # end + h = F.silu(self.norm_out(h), inplace=True) + h = checkpoint(self.conv_out, h, use_reentrant=False) + + return h + + +class CNNDecoder(nn.Module): + def __init__( + self, *, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=3, dropout=0.0, + img_channels=3, input_channels=32, using_sa=True, using_mid_sa=True, + grad_ckpt=False, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.grad_ckpt = grad_ckpt + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[-1] + + # z to block_in + self.conv_in = torch.nn.Conv2d(input_channels, block_in, kernel_size=3, stride=1, 
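+ # input_channels equals vocab_width here: the decoder consumes post-quant feature maps (see models/vqvae.py below)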
padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = BnActConvBnActConv(in_channels=block_in, out_channels=block_in, dropout=dropout) + self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa) + self.mid.block_2 = BnActConvBnActConv(in_channels=block_in, out_channels=block_in, dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(BnActConvBnActConv(in_channels=block_in, out_channels=block_out, dropout=dropout)) + block_in = block_out + if i_level == self.num_resolutions-1 and using_sa: + attn.append(make_attn(block_in, using_sa=True)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample2x(block_in) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, img_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, z): + if not self.grad_ckpt or not self.training: + # z to block_in and middle + h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z)))) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + else: + # z to block_in and middle + h = checkpoint(self.conv_in, z, use_reentrant=False) + h = checkpoint(self.mid.block_1, h, use_reentrant=False) + h = checkpoint(self.mid.attn_1, h, use_reentrant=False) + h = checkpoint(self.mid.block_2, h, use_reentrant=False) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks): + h = checkpoint(self.up[i_level].block[i_block], h, use_reentrant=False) + if len(self.up[i_level].attn) > 0: + h = checkpoint(self.up[i_level].attn[i_block], h, use_reentrant=False) + if i_level != 0: + h = checkpoint(self.up[i_level].upsample, h, use_reentrant=False) + + return self.conv_out(F.silu(self.norm_out(h), inplace=True)) diff --git a/models/basic_var.py b/models/basic_var.py new file mode 100644 index 0000000..84afa3e --- /dev/null +++ b/models/basic_var.py @@ -0,0 +1,174 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.helpers import DropPath, drop_path + + +# this file only provides the 3 blocks used in VAR transformer +__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead'] + + +# automatically import fused operators +dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm + from flash_attn.ops.fused_dense import fused_mlp_func +except ImportError: pass +# automatically import faster attention implementations +try: from xformers.ops import memory_efficient_attention +except ImportError: pass +try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq +except ImportError: pass +try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc +except ImportError: + def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0): + attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL + if attn_mask is not None: attn.add_(attn_mask) + return 
(F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value + + +class FFN(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True): + super().__init__() + self.fused_mlp_func = fused_mlp_func if fused_if_available else None + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = nn.GELU(approximate='tanh') + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity() + + def forward(self, x): + if self.fused_mlp_func is not None: + return self.drop(self.fused_mlp_func( + x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias, + activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0, + heuristic=0, process_group=None, + )) + else: + return self.drop(self.fc2( self.act(self.fc1(x)) )) + + def extra_repr(self) -> str: + return f'fused_mlp_func={self.fused_mlp_func is not None}' + + +class SelfAttention(nn.Module): + def __init__( + self, block_idx, embed_dim=768, num_heads=12, + attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True, + ): + super().__init__() + assert embed_dim % num_heads == 0 + self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64 + self.attn_l2_norm = attn_l2_norm + if self.attn_l2_norm: + self.scale = 1 + self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True) + self.max_scale_mul = torch.log(torch.tensor(100)).item() + else: + self.scale = 0.25 / math.sqrt(self.head_dim) + + self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False) + self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim)) + self.register_buffer('zero_k_bias', torch.zeros(embed_dim)) + + self.proj = nn.Linear(embed_dim, embed_dim) + self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity() + self.attn_drop: float = attn_drop + self.using_flash = flash_if_available and flash_attn_func is not None + self.using_xform = flash_if_available and memory_efficient_attention is not None + + # only used during inference + self.caching, self.cached_k, self.cached_v = False, None, None + + def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None + + # NOTE: attn_bias is None during inference because kv cache is enabled + def forward(self, x, attn_bias): + B, L, C = x.shape + + qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim) + main_type = qkv.dtype + # qkv: BL3Hc + + using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32 + if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc + else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc + + if self.attn_l2_norm: + scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp() + if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1 + q = F.normalize(q, dim=-1).mul(scale_mul) + k = F.normalize(k, dim=-1) + + if self.caching: + if self.cached_k is None: self.cached_k = k; self.cached_v = v + else: k = 
self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat) + + dropout_p = self.attn_drop if self.training else 0.0 + if using_flash: + oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C) + elif self.using_xform: + oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C) + else: + oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C) + + return self.proj_drop(self.proj(oup)) + # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL + # attn = self.attn_drop(attn.softmax(dim=-1)) + # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC + + def extra_repr(self) -> str: + return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}' + + +class AdaLNSelfAttn(nn.Module): + def __init__( + self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer, + num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False, + flash_if_available=False, fused_if_available=True, + ): + super(AdaLNSelfAttn, self).__init__() + self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim + self.C, self.D = embed_dim, cond_dim + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available) + self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available) + + self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False) + self.shared_aln = shared_aln + if self.shared_aln: + self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5) + else: + lin = nn.Linear(cond_dim, 6*embed_dim) + self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin) + + self.fused_add_norm_fn = None + + # NOTE: attn_bias is None during inference because kv cache is enabled + def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim + if self.shared_aln: + gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C + else: + gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2) + x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1)) + x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used + return x + + def extra_repr(self) -> str: + return f'shared_aln={self.shared_aln}' + + +class AdaLNBeforeHead(nn.Module): + def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim + super().__init__() + self.C, self.D = C, D + self.ln_wo_grad = norm_layer(C, elementwise_affine=False) + self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C)) + + def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor): + scale, shift = 
self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2) + return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift) diff --git a/models/dino.py b/models/dino.py new file mode 100644 index 0000000..db83560 --- /dev/null +++ b/models/dino.py @@ -0,0 +1,387 @@ +import math +import os.path +import random +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils.spectral_norm import SpectralNorm +from torchvision.transforms import RandomCrop + +import dist + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm + from flash_attn.ops.fused_dense import fused_mlp_func +except: + dropout_add_layer_norm = fused_mlp_func = None + +try: + from flash_attn import flash_attn_qkvpacked_func # qkv: BL3Hc, ret: BLHcq +except: + flash_attn_qkvpacked_func = None + +try: + assert torch.cuda.is_available() + from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc +except: + def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0): + attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL + if attn_mask is not None: attn.add_(attn_mask) + return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value + + +class MLPNoDrop(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, fused_if_available=True): + super().__init__() + self.fused_mlp_func = fused_mlp_func if (torch.cuda.is_available() and fused_if_available) else None + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = nn.GELU(approximate='tanh') + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + if self.fused_mlp_func is not None: + return self.fused_mlp_func( + x=x, + weight1=self.fc1.weight, + weight2=self.fc2.weight, + bias1=self.fc1.bias, + bias2=self.fc2.bias, + activation='gelu_approx', + save_pre_act=self.training, + return_residual=False, + checkpoint_lvl=0, + heuristic=0, + process_group=None, + ) + else: + return self.fc2(self.act(self.fc1(x))) + + def extra_repr(self) -> str: + return f'fused_mlp_func={self.fused_mlp_func is not None}' + + +class SelfAttentionNoDrop(nn.Module): + def __init__( + self, block_idx, embed_dim=768, num_heads=12, flash_if_available=True, + ): + super().__init__() + assert embed_dim % num_heads == 0 + self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64 + self.scale = 1 / math.sqrt(self.head_dim) + self.qkv, self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=True), nn.Linear(embed_dim, embed_dim, bias=True) + self.using_flash_attn = torch.cuda.is_available() and flash_if_available and flash_attn_qkvpacked_func is not None + + def forward(self, x): + B, L, C = x.shape + qkv = self.qkv(x).view(B, L, 3, self.num_heads, self.head_dim) + if self.using_flash_attn and qkv.dtype != torch.float32: + oup = flash_attn_qkvpacked_func(qkv, softmax_scale=self.scale).view(B, L, C) + else: + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) # BHLc + oup = slow_attn(query=q, key=k, value=v, scale=self.scale).transpose(1, 2).reshape(B, L, C) + return self.proj(oup) + + def extra_repr(self) -> str: + return f'using_flash_attn={self.using_flash_attn}' + +class SABlockNoDrop(nn.Module): + def __init__(self, block_idx, embed_dim, num_heads, mlp_ratio, norm_eps): + super(SABlockNoDrop, 
self).__init__() + self.norm1 = nn.LayerNorm(embed_dim, eps=norm_eps) + self.attn = SelfAttentionNoDrop(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, flash_if_available=True) + self.norm2 = nn.LayerNorm(embed_dim, eps=norm_eps) + self.mlp = MLPNoDrop(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), fused_if_available=True) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +class ResidualBlock(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.ratio = 1 / np.sqrt(2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x = x.float() + return (self.fn(x).add(x)).mul_(self.ratio) + + +class SpectralConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12) + + +class BatchNormLocal(nn.Module): + def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 8, eps: float = 1e-6): + super().__init__() + self.virtual_bs = virtual_bs + self.eps = eps + self.affine = affine + + if self.affine: + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shape = x.size() + x = x.float() + + # Reshape batch into groups. + G = np.ceil(x.size(0) / self.virtual_bs).astype(int) + x = x.view(G, -1, x.size(-2), x.size(-1)) + + # Calculate stats. + mean = x.mean([1, 3], keepdim=True) + var = x.var([1, 3], keepdim=True, unbiased=False) + x = (x - mean) / (torch.sqrt(var + self.eps)) + + if self.affine: + x = x * self.weight[None, :, None] + self.bias[None, :, None] + + return x.view(shape) + + +def make_block(channels: int, kernel_size: int, norm_type: str, norm_eps: float, using_spec_norm: bool) -> nn.Module: + if norm_type == 'bn': norm = BatchNormLocal(channels, eps=norm_eps) + elif norm_type == 'sbn': norm = nn.SyncBatchNorm(channels, eps=norm_eps, process_group=None) + elif norm_type in {'lbn', 'hbn'}: norm = nn.SyncBatchNorm(channels, eps=norm_eps, process_group=dist.new_local_machine_group()) + elif norm_type == 'gn': norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=norm_eps, affine=True) + else: raise NotImplementedError + + return nn.Sequential( + (SpectralConv1d if using_spec_norm else nn.Conv1d)(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode='circular'), + norm, + nn.LeakyReLU(negative_slope=0.2, inplace=True), + ) + + +class DinoDisc(nn.Module): + def __init__(self, device, dino_ckpt_path, ks, depth=12, key_depths=(2, 5, 8, 11), norm_type='bn', using_spec_norm=True, norm_eps=1e-6): + super().__init__() + # load state + state = torch.load(dino_ckpt_path, 'cpu') + for k in sorted(state.keys()): + if '.attn.qkv.bias' in k: + bias = state[k] + C = bias.numel() // 3 + bias[C:2*C].zero_() # zero out k_bias + # build DINO + key_depths = tuple(d for d in key_depths if d < depth) + d = FrozenDINOSmallNoDrop(depth=depth, key_depths=key_depths, norm_eps=norm_eps) + missing, unexpected = d.load_state_dict(state, strict=False) + missing = [m for m in missing if all(x not in m for x in { + 'x_scale', 'x_shift', + })] + if torch.cuda.is_available(): + assert len(missing) == 0, f'missing keys: {missing}' + assert len(unexpected) == 0, f'unexpected keys: {unexpected}' + + # todo: don't compile! 
reduce-overhead would raise CudaERR + self.dino_proxy: Tuple[FrozenDINOSmallNoDrop] = (d.to(device=device),) + dino_C = self.dino_proxy[0].embed_dim + # if 'KEVIN_LOCAL' in os.environ: + # torch.manual_seed(0) + # np.random.seed(0) + # random.seed(0) + self.heads = nn.ModuleList([ + nn.Sequential( + make_block(dino_C, kernel_size=1, norm_type=norm_type, norm_eps=norm_eps, using_spec_norm=using_spec_norm), + ResidualBlock(make_block(dino_C, kernel_size=ks, norm_type=norm_type, norm_eps=norm_eps, using_spec_norm=using_spec_norm)), + (SpectralConv1d if using_spec_norm else nn.Conv1d)(dino_C, 1, kernel_size=1, padding=0) + ) + for _ in range(len(key_depths) + 1) # +1: before all attention blocks + ]) + + def forward(self, x_in_pm1, grad_ckpt=False): # x_in_pm1: image tensor normalized to [-1, 1] + dino_grad_ckpt = grad_ckpt and x_in_pm1.requires_grad + FrozenDINOSmallNoDrop.forward + activations: List[torch.Tensor] = self.dino_proxy[0](x_in_pm1.float(), grad_ckpt=dino_grad_ckpt) + B = x_in_pm1.shape[0] + return torch.cat([ + ( + h(act) if not grad_ckpt + else torch.utils.checkpoint.checkpoint(h, act, use_reentrant=False) + ).view(B, -1) + for h, act in zip(self.heads, activations) + ], dim=1) # cat 5 BL => B, 5L + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = (img_size // patch_size) ** 2 + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + x = self.proj(x).flatten(2).transpose(1, 2) # BCHW => BCL => BLC + return self.norm(x) + + +class FrozenDINOSmallNoDrop(nn.Module): + """ + Frozen DINO ViT without any dropout or droppath layers (eval node only), based on timm.create_model('vit_small_patch16_224', pretrained=False, num_classes=0) + + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__( + self, depth=12, key_depths=(2, 5, 8, 11), norm_eps=1e-6, # 4 stages: 012, 345, 678, 9 10 11 + patch_size=16, in_chans=3, num_classes=0, + embed_dim=384, num_heads=6, mlp_ratio=4., + # drop_rate=0., attn_drop_rate=0., drop_path_rate=0. 
# no drop for frozen model + ): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + self.img_size = 224 + self.patch_embed = PatchEmbed(img_size=self.img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + self.patch_size = patch_size + self.patch_nums = self.img_size // patch_size + + # x \in [-1, 1] + # x = ((x+1)/2 - m) / s = 0.5x/s + 0.5/s - m/s = (0.5/s) x + (0.5-m)/s + m, s = torch.tensor((0.485, 0.456, 0.406)), torch.tensor((0.229, 0.224, 0.225)) + self.register_buffer('x_scale', (0.5/s).reshape(1, 3, 1, 1)) + self.register_buffer('x_shift', ((0.5-m)/s).reshape(1, 3, 1, 1)) + self.crop = RandomCrop(self.img_size) + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.dist_token = None + self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_nums*self.patch_nums + 1, embed_dim)) # +1: for cls + # self.pos_drop = nn.Dropout(p=drop_rate) + # self.pos_pool = dict() + + self.key_depths = set(d for d in key_depths if d < depth) + # dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # no drop for frozen model + self.blocks = nn.Sequential(*[ + SABlockNoDrop(block_idx=i, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, norm_eps=norm_eps) + for i in range(max(depth, 1+max(self.key_depths))) + ]) + self.norm = nn.LayerNorm(embed_dim, eps=norm_eps) + + # eval mode only + self.eval() + [p.requires_grad_(False) for p in self.parameters()] + + def inter_pos_embed(self, patch_nums=(14, 14)): + if patch_nums[0] == self.patch_nums and patch_nums[1] == self.patch_nums: + return self.pos_embed + pe_cls, pe_grid = self.pos_embed[:, :1], self.pos_embed[0, 1:] + pe_grid = pe_grid.reshape(1, self.patch_nums, self.patch_nums, -1).permute(0, 3, 1, 2) + pe_grid = F.interpolate(pe_grid, size=(patch_nums[0], patch_nums[1]), mode='bilinear', align_corners=False) + pe_grid = pe_grid.permute(0, 2, 3, 1).reshape(1, patch_nums[0] * patch_nums[1], -1) + return torch.cat([pe_cls, pe_grid], dim=1) + + def forward(self, x, grad_ckpt=False): + with torch.cuda.amp.autocast(enabled=False): + x = (self.x_scale * x.float()).add_(self.x_shift) + H, W = x.shape[-2], x.shape[-1] + if H > self.img_size and W > self.img_size and random.random() <= 0.5: + x = self.crop(x) + else: + x = F.interpolate(x, size=(self.img_size, self.img_size), mode='area' if H > self.img_size else 'bicubic') + # x now must be self.img_size x self.img_size + + # patch_nums = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size + # x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), self.patch_embed(x)), dim=1) + # if patch_nums in self.pos_pool: + # x += self.pos_pool[patch_nums] + # else: + # self.pos_pool[patch_nums] = pe = self.inter_pos_embed(patch_nums) + # x += pe + # x = self.pos_drop(x) + + x = self.patch_embed(x) + + with torch.cuda.amp.autocast(enabled=False): + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x.float()), dim=1) + x = x + self.pos_embed + activations = [(x[:, 1:] + x[:, :1]).transpose_(1, 2)] # readout + for i, b in enumerate(self.blocks): + if not grad_ckpt: + x = b(x) + else: + x = torch.utils.checkpoint.checkpoint(b, x, use_reentrant=False) + if i in self.key_depths: + activations.append((x[:, 1:].float() + x[:, :1].float()).transpose_(1, 2)) # readout + # x = self.norm(x) + return activations + + +if __name__ == '__main__': + torch.manual_seed(0) + np.random.seed(0) + random.seed(0) + ks = 9 + norm_type = 'sbn' + norm_eps = 1e-6 + 
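+ # 384 is the embed_dim of DINO ViT-S/16, matching FrozenDINOSmallNoDrop's default above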
dino_C = 384 + key_layers = (2, 5, 8, 11) + using_spec_norm = True + + heads = nn.ModuleList([ + nn.Sequential( + make_block(dino_C, kernel_size=1, norm_type=norm_type, norm_eps=norm_eps, using_spec_norm=using_spec_norm), + ResidualBlock(make_block(dino_C, kernel_size=ks, norm_type=norm_type, norm_eps=norm_eps, using_spec_norm=using_spec_norm)), + (SpectralConv1d if using_spec_norm else nn.Conv1d)(dino_C, 1, kernel_size=1, padding=0) + ) + for _ in range(len(key_layers) + 1) + ]) + + ckpt = os.path.join(os.path.dirname(__file__), '/mnt/bn/foundation-lq/tiankeyu/ckpt_vae/vit_small_patch16_224.pth') + + DinoDisc.forward + dd = DinoDisc('cpu', dino_ckpt_path=ckpt, ks=ks, norm_type=norm_type, norm_eps=norm_eps, key_depths=key_layers) + dd.eval() + dd.heads.load_state_dict(heads.state_dict()) + print(f'{sum(p.numel() for p in dd.parameters() if p.requires_grad)/1e6:.2f}M') + inp = torch.linspace(-2, 2, 2*3*224*224).reshape(2, 3, 224, 224) + inp.requires_grad = True + cond = torch.rand(2, 64) + mid_ls = dd.dino_proxy[0](inp) + means = [round(m.mean().item(), 3) for m in mid_ls] + stds = [round(m.std().item(), 3) for m in mid_ls] + print(f'mean: {means}') + print(f'std: {stds}') + + o = dd(inp, grad_ckpt=True) + print(f'o: {o.abs().mean().item():.9f}, {o.abs().std().item():.9f}') + o.abs().mean().backward() + + # for n, p in dd.named_parameters(): + # tag = n.split('heads.')[-1][0] + # if p.ndim == 3: tag += '.conv1d' + # print(f'[{tag}] {n}: {p.shape}') + +""" +For the version using qkv, the output is: +7.39M +mean: [0.019, -0.028, 0.054, 0.058, 0.074] +std: [0.427, 0.142, 0.169, 0.194, 0.153] +o: 50.266475677, 91.698143005 + +For the version using zero_k_bias, the output is: +7.39M +mean: [0.019, -0.028, 0.054, 0.058, 0.074] +std: [0.427, 0.142, 0.169, 0.194, 0.153] +o: 50.266475677, 91.698143005 +""" diff --git a/models/helpers.py b/models/helpers.py new file mode 100644 index 0000000..91a4b5c --- /dev/null +++ b/models/helpers.py @@ -0,0 +1,59 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + + +def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor: # return idx, shaped (B, l) + B, l, V = logits_BlV.shape + if top_k > 0: + idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True) + logits_BlV.masked_fill_(idx_to_remove, -torch.inf) + if top_p > 0: + sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False) + sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p) + sorted_idx_to_remove[..., -1:] = False + logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf) + # sample (have to flatten to 2D because torch.multinomial only supports 2D tensors) + replacement = num_samples >= 0 + num_samples = abs(num_samples) + return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples) + + +def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor: + if rng is None: + return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim) + + gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log()) + gumbels = (logits + gumbels) / tau + y_soft = gumbels.softmax(dim) + + if hard: + index = y_soft.max(dim, keepdim=True)[1] + y_hard
= torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0) + ret = y_hard - y_soft.detach() + y_soft + else: + ret = y_soft + return ret + + +def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): # taken from timm + if drop_prob == 0. or not training: return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): # taken from timm + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'(drop_prob=...)' diff --git a/models/quant.py b/models/quant.py new file mode 100644 index 0000000..e97d5df --- /dev/null +++ b/models/quant.py @@ -0,0 +1,118 @@ +from typing import List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +from torch import distributed as tdist, nn as nn +from torch.nn import functional as F + +import dist + + +# this file only provides the VectorQuantizer2 used in VQVAE +__all__ = ['VectorQuantizer', ] + + +class NormalizedEmbedding(nn.Embedding): + def __init__(self, num_embeddings: int, embedding_dim: int): + super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32)) + + def forward(self, idx): + return F.embedding( + idx, F.normalize(self.weight, dim=1).mul_(self.norm_scale.sigmoid()), self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse + ) + + +class ResConv(nn.Conv2d): + def __init__(self, embed_dim, quant_resi): + ks = 3 if quant_resi < 0 else 1 + super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks//2) + self.resi_ratio = abs(quant_resi) + + def forward(self, h_BChw): + return h_BChw.mul(1-self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio) + + +class VectorQuantizer(nn.Module): + def __init__( + self, vocab_size: int, vocab_width: int, vocab_norm: bool, beta: float = 0.25, quant_resi=-0.5, + ): + super().__init__() + self.vocab_size: int = vocab_size + self.vocab_width: int = vocab_width + self.register_buffer('vocab_usage', torch.zeros(self.vocab_size)) + self.vocab_usage_record_times: int = 0 + + self.vocab_norm: bool = vocab_norm + self.quant_resi = ResConv(self.vocab_width, quant_resi=quant_resi) + # self.embedding = (NormalizedEmbedding if vocab_norm else nn.Embedding)(self.vocab_size, self.vocab_width) + self.embedding = nn.Embedding(self.vocab_size, self.vocab_width) + self.beta: float = beta + + def init_vocab(self, eini: float): + if eini > 0: + nn.init.trunc_normal_(self.embedding.weight.data, std=eini) + elif eini < 0: + base = self.vocab_width ** -0.5 + base /= 36 + self.embedding.weight.data.uniform_(-abs(eini) * base, abs(eini) * base) + + def extra_repr(self) -> str: + return f'beta={self.beta:g}' + + # ===================== `forward` is only used in VAE training ===================== + def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[float]]: + f_BChw = f_BChw.float() + B, C, h, w = f_BChw.shape + # 
find the nearest embedding + query_NxC = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C) + if self.vocab_norm: + query_NxC = F.normalize(query_NxC, dim=-1) + idx_N = torch.argmax(query_NxC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1) + else: + E_dist = torch.sum(query_NxC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False) + E_dist.addmm_(query_NxC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size) + idx_N = torch.argmin(E_dist, dim=1) + + prob_per_class_is_chosen = idx_N.bincount(minlength=self.vocab_size).float() + handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (self.training and dist.initialized()) else None + + # look up + idx_Bhw = idx_N.view(B, h, w) + fhat_BChw = self.quant_resi(self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()) + + # calc loss + vq_loss = F.mse_loss(fhat_BChw.detach(), f_BChw).mul_(self.beta) + F.mse_loss(fhat_BChw, f_BChw.detach()) + + # VQVAE: straight through gradient estimation, copy the gradient on fhat_BChw to f_BChw + fhat_BChw = (fhat_BChw.detach() - f_BChw.detach()).add_(f_BChw) + + # update vocab_usage + if handler is not None: handler.wait() + prob_per_class_is_chosen /= prob_per_class_is_chosen.sum() + vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100) + + if self.vocab_usage_record_times == 0: self.vocab_usage.copy_(prob_per_class_is_chosen) + elif self.vocab_usage_record_times < 100: self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1) + else: self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01) + self.vocab_usage_record_times += 1 + + entropy_loss = 0.0 # todo: not implemented yet + return fhat_BChw, vq_loss, entropy_loss, (vocab_usage if ret_usages else None) + # ===================== `forward` is only used in VAE training ===================== + + def f_to_idx(self, f_BChw: torch.Tensor) -> torch.LongTensor: + f_BChw = f_BChw.float() + B, C, h, w = f_BChw.shape + with torch.cuda.amp.autocast(enabled=False): + # find the nearest embedding + query_NxC = f_BChw.detach().permute(0, 2, 3, 1).reshape(-1, C) + if self.vocab_norm: + query_NxC = F.normalize(query_NxC, dim=-1) + idx_N = torch.argmax(query_NxC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1) + else: + E_dist = torch.sum(query_NxC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False) + E_dist.addmm_(query_NxC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size) + idx_N = torch.argmin(E_dist, dim=1) + return idx_N.view(B, h, w) diff --git a/models/vqvae.py b/models/vqvae.py new file mode 100644 index 0000000..5dd7adf --- /dev/null +++ b/models/vqvae.py @@ -0,0 +1,127 @@ +""" +References: +- VectorQuantizer: VectorQuantizer2 from https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110 +- VQVAE: VQModel from https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14 +""" +from contextlib import nullcontext +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import timm +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.basic_vae import CNNDecoder, CNNEncoder +from models.quant import VectorQuantizer + + +def identity(x, inplace=False): return x + + +class VQVAE(nn.Module): + def __init__( + self, + # for all: + grad_ckpt=False, # whether to use 
gradient checkpointing + + # vitamin encoder: + vitamin='', # 's', 'b', 'l' for using vitamin; 'cnn' or '' for using CNN + drop_path_rate=0.1, + + # CNN encoder & CNN decoder: + ch=128, # basic width of CNN encoder and CNN decoder + ch_mult=(1, 1, 2, 2, 4), # downsample_ratio would be 2 ** (len(ch_mult) - 1) + dropout=0.0, # dropout in CNN encoder and CNN decoder + + # quantizer: + vocab_size=4096, + vocab_width=32, + vocab_norm=False, # whether to limit the codebook vectors to have unit norm + beta=0.25, # commitment loss weight + quant_conv_k=3, # quant conv kernel size + quant_resi=-0.5, # + ): + super().__init__() + self.downsample_ratio = 2 ** (len(ch_mult) - 1) + + # 1. build encoder + print(f'[VQVAE] create CNN Encoder with {ch=}, {ch_mult=} {dropout=:g} ...', flush=True) + self.encoder: CNNEncoder = CNNEncoder( + ch=ch, ch_mult=ch_mult, num_res_blocks=2, dropout=dropout, + img_channels=3, output_channels=vocab_width, using_sa=True, using_mid_sa=True, + grad_ckpt=grad_ckpt, + ) + # 2. build conv before quant + self.quant_conv = nn.Conv2d(vocab_width, vocab_width, quant_conv_k, stride=1, padding=quant_conv_k // 2) + + # 3. build quant + print(f'[VQVAE] create VectorQuantizer with {vocab_size=}, {vocab_width=} {vocab_norm=}, {beta=:g} ...', flush=True) + self.quantize: VectorQuantizer = VectorQuantizer(vocab_size=vocab_size, vocab_width=vocab_width, vocab_norm=vocab_norm, beta=beta, quant_resi=quant_resi) + + # 4. build conv after quant + self.post_quant_conv = nn.Conv2d(vocab_width, vocab_width, quant_conv_k, stride=1, padding=quant_conv_k // 2) + print(f'[VQVAE] create CNN Decoder with {ch=}, {ch_mult=} {dropout=:g} ...', flush=True) + + # 5. build decoder + self.decoder = CNNDecoder( + ch=ch, ch_mult=ch_mult, num_res_blocks=3, dropout=dropout, + input_channels=vocab_width, using_sa=True, using_mid_sa=True, + grad_ckpt=grad_ckpt, + ) + self.maybe_record_function = nullcontext + + def forward(self, img_B3HW, ret_usages=False): + f_BChw = self.encoder(img_B3HW).float() + with torch.cuda.amp.autocast(enabled=False): + VectorQuantizer.forward + f_BChw, vq_loss, entropy_loss, usages = self.quantize(self.quant_conv(f_BChw), ret_usages=ret_usages) + f_BChw = self.post_quant_conv(f_BChw) + return self.decoder(f_BChw).float(), vq_loss, entropy_loss, usages + + def img_to_idx(self, img_B3HW: torch.Tensor) -> torch.LongTensor: + f_BChw = self.encoder(img_B3HW) + f_BChw = self.quant_conv(f_BChw) + return self.quantize.f_to_idx(f_BChw) + + def idx_to_img(self, idx_Bhw: torch.Tensor) -> torch.Tensor: + f_hat_BChw = self.quantize.quant_resi(self.quantize.embedding(idx_Bhw).permute(0, 3, 1, 2)) + f_hat_BChw = self.post_quant_conv(f_hat_BChw) + return self.decoder(f_hat_BChw).clamp_(-1, 1) + + def img_to_reconstructed_img(self, img_B3HW) -> torch.Tensor: + return self.idx_to_img(self.img_to_idx(img_B3HW)) + + def state_dict(self, *args, **kwargs): + d = super().state_dict(*args, **kwargs) + d['vocab_usage_record_times'] = self.quantize.vocab_usage_record_times + return d + + def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False): + if 'quantize.vocab_usage' not in state_dict or state_dict['quantize.vocab_usage'].shape[0] != self.quantize.vocab_usage.shape[0]: + state_dict['quantize.vocab_usage'] = self.quantize.vocab_usage + if 'vocab_usage_record_times' in state_dict: + self.quantize.vocab_usage_record_times = state_dict.pop('vocab_usage_record_times') + return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign) + + +if __name__ == '__main__': + 
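+ # smoke test: default inits are disabled below for speed; build a small CPU VQVAE, re-init it with init_weights, and dump a checkpoint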
for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d): + setattr(clz, 'reset_parameters', lambda self: None) + # cnn = VQVAE(ch=160, vocab_norm=False) + # print(cnn) + # numel = [p.numel() for p in cnn.parameters()] + # para = sum(numel) + # print(len(numel), para, para/1e6) + # exit(0) + + # cnn = VQVAE(ch=32, vocab_norm=True) + # vit = VQVAE(vitamin='S', vocab_norm=True) + # cnn(torch.rand(2, 3, 192, 288))[0].mean().backward() + # vit(torch.rand(2, 3, 256, 256))[0].mean().backward() + # print(cnn.state_dict()['vocab_usage_record_times']) + torch.manual_seed(0) + cnn = VQVAE(ch=32, vocab_width=16, vocab_norm=False) + print(str(cnn).replace('BnActConvBnActConv', 'ResnetBlock').replace('2x(', '(')) + from models import init_weights + init_weights(cnn, -0.5) + torch.save(cnn.state_dict(), r'C:\Users\16333\Desktop\PyCharm\vlip\local_output\cnn.pth') diff --git a/train.py b/train.py new file mode 100644 index 0000000..a158653 --- /dev/null +++ b/train.py @@ -0,0 +1,599 @@ +import gc +import glob +import math +import os +import shutil +import subprocess +import sys +import time +import warnings +from collections import deque +from contextlib import nullcontext +from functools import partial +from typing import List, Optional, Tuple + +import GPUtil +import colorama +import numpy as np +import torch +from torch.autograd.profiler import record_function +from torch.utils.data import DataLoader + +import dist +from utils import arg_util, misc +from utils.data import build_dataset, pil_load +from utils.data_sampler import DistInfiniteBatchSampler + + +def create_tb_lg(args: arg_util.Args): + tb_lg: misc.TensorboardLogger + with_tb_lg = dist.is_master() + if with_tb_lg: + os.makedirs(args.tb_log_dir_path, exist_ok=True) + # noinspection PyTypeChecker + tb_lg = misc.DistLogger(misc.TensorboardLogger(log_dir=args.tb_log_dir_online, filename_suffix=f'_{misc.time_str("%m%d_%H%M")}')) + tb_lg.flush() + else: + # noinspection PyTypeChecker + tb_lg = misc.DistLogger(None) + dist.barrier() + return tb_lg + + +def maybe_auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, str, List[Tuple[float, float]], dict, dict]: + info = [] + resume = None + if len(args.resume): + resume = args.resume + info.append(f'[auto_resume] load from args.resume @ {resume} ...') + elif not args.local_debug: + all_ckpt = sorted(glob.glob(os.path.join(args.bed, pattern)), key=os.path.getmtime, reverse=True) # latest-modified first + if len(all_ckpt) == 0: + info.append(f'[auto_resume] no ckpt found @ {pattern}') + info.append(f'[auto_resume quit]') + else: + resume = all_ckpt[0] + info.append(f'[auto_resume] auto load from @ {resume} ...') + info.append(f'[auto_resume quit]') + else: + info.append(f'[auto_resume] disabled') + info.append(f'[auto_resume quit]') + + if resume is None: + return info, 0, 0, '[no acc str]', [], {}, {} + + try: + ckpt = torch.load(resume, map_location='cpu') + except Exception as e: + info.append(f'[auto_resume] failed, {e} @ {resume}') + return info, 0, 0, '[no acc str]', [], {}, {} + + dist.barrier() + ep, it = (ckpt['epoch'], ckpt['iter']) if 'iter' in ckpt else (ckpt['epoch'] + 1, 0) + eval_milestone = ckpt.get('milestones', []) + info.append(f'[auto_resume success] resume from ep{ep}, it{it}, eval_milestone: {eval_milestone}') + return info, ep, it, ckpt.get('acc_str', '[no acc str]'), eval_milestone, ckpt['trainer'], ckpt['args'] + + +def build_things_from_args(args: arg_util.Args): + # auto-resume from the latest checkpoint under args.bed (if any)
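+ # returns: (log lines, start epoch, start iter, acc string, eval milestones, trainer state dict, args state dict)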
auto_resume_info, start_ep, start_it, acc_str, eval_milestone, trainer_state, args_state = maybe_auto_resume(args, 'ckpt*.pth') + args.load_state_dict_vae_only(args_state) + args.diffs = ' '.join([f'{d:.3f}'[2:] for d in eval_milestone]) # args updated + tb_lg = create_tb_lg(args) + print(f'global bs={args.bs}, local bs={args.lbs}') + print(f'initial args:\n{str(args)}') + + if start_ep == args.ep: + print(f'[vlip] Training finished ({acc_str}), skipping ...\n\n') + return args, tb_lg + + # build data + # swin: -1~1, resize to (reso, reso) by LANCZOS + # xl: -1~1 + if not args.local_debug: + print(f'[build PT data] ...\n') + dataset_train, val_transform = build_dataset(datasets_str=args.data, subset_ratio=args.subset, final_reso=args.img_size, mid_reso=args.mid_reso, hflip=args.hflip) + ld_train = DataLoader( + dataset=dataset_train, num_workers=args.workers, pin_memory=True, + generator=args.get_different_generator_for_each_rank(), # worker_init_fn=worker_init_fn, + batch_sampler=DistInfiniteBatchSampler( + dataset_len=len(dataset_train), glb_batch_size=args.bs, same_seed_for_all_ranks=args.same_seed_for_all_ranks, + shuffle=True, fill_last=True, rank=dist.get_rank(), world_size=dist.get_world_size(), start_ep=start_ep, start_it=start_it, + ), + ) + del dataset_train + [print(l) for l in auto_resume_info] + print(f'[dataloader multi processing] ...', end='', flush=True) + stt = time.time() + iters_train = len(ld_train) # iterations per epoch + ld_train = iter(ld_train) # the infinite batch sampler makes this iterator inexhaustible + # noinspection PyArgumentList + print(f' [dataloader multi processing](*) finished! ({time.time()-stt:.2f}s)', flush=True, clean=True) + print(f'[dataloader] gbs={args.bs}, lbs={args.lbs}, iters_train={iters_train}') + else: + # dataset_mean, dataset_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225) + iters_train = ld_train = None + from torchvision.transforms import transforms, InterpolationMode + from utils.data import normalize_01_into_pm1 + val_transform = transforms.Compose([ + transforms.Resize(round(args.img_size*1.3), interpolation=InterpolationMode.LANCZOS), # shorter edge would be the size + transforms.CenterCrop((args.img_size, args.img_size)), + transforms.ToTensor(), + # transforms.Normalize(mean, std, inplace=True), + normalize_01_into_pm1 + ]) + [print(l) for l in auto_resume_info] + + # import heavy packages after Dataloader object creation + from torch.nn.parallel import DistributedDataParallel as DDP + from models import build_vae_disc, VQVAE, DinoDisc + from trainer import VAETrainer + from utils.amp_opt import AmpOptimizer + from utils.lpips import LPIPS + from utils.lr_control import filter_params + from utils import optimizer + + # build models + vae_wo_ddp, disc_wo_ddp = build_vae_disc(args) + vae_wo_ddp: VQVAE + disc_wo_ddp: DinoDisc + + print(f'[PT] VAE model ({args.vae}) = {vae_wo_ddp}\n') + if isinstance(disc_wo_ddp, DinoDisc): + print(f'[PT] Disc model (frozen part) = {disc_wo_ddp.dino_proxy[0]}\n') + print(f'[PT] Disc model (trainable part) = {disc_wo_ddp}\n\n') + + assert all(p.requires_grad for p in vae_wo_ddp.parameters()) + assert all(p.requires_grad for p in disc_wo_ddp.parameters()) + count_p = lambda m: f'{sum(p.numel() for p in m.parameters())/1e6:.2f}' + print(f'[PT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in ( + ('VAE', vae_wo_ddp), ('VAE.enc', vae_wo_ddp.encoder), ('VAE.dec', vae_wo_ddp.decoder), ('VAE.quant', vae_wo_ddp.quantize) + )])) + print(f'[PT][#para] ' + ', '.join([f'{k}={count_p(m)}' for k, m in ( + ('Disc', disc_wo_ddp), + # 
('from_wave', disc_wo_ddp.ls_from_wavelet12c), ('resi', disc_wo_ddp.ls_resi), + # ('fpn_conv', disc_wo_ddp.ls_fpn_conv), ('head', disc_wo_ddp.ls_head), ('down', disc_wo_ddp.ls_down), + # ('glb_cls', disc_wo_ddp.glb_cls), + )]) + '\n\n') + + # build optimizers + optimizers: List[AmpOptimizer] = [] + for model_name, model_wo_ddp, opt_beta, lr, wd, clip in (('vae', vae_wo_ddp, args.vae_opt_beta, args.vae_lr, args.vae_wd, args.grad_clip), ('dis', disc_wo_ddp, args.disc_opt_beta, args.disc_lr, args.disc_wd, args.grad_clip)): + if args.local_debug: + lr, wd, clip = 5e-5, 5e-4, 20 + + # sync model parameters + for p in model_wo_ddp.parameters(): + if p.requires_grad: + dist.broadcast(p.data, src_rank=0) + ndim_dict = {name: para.ndim for name, para in model_wo_ddp.named_parameters() if para.requires_grad} + + # build optimizer + nowd_keys = { + 'cls_token', 'start_token', 'task_token', 'cfg_uncond', + 'pos_embed', 'pos_1LC', 'pos_start', 'start_pos', 'lvl_embed', + 'gamma', 'beta', + 'ada_gss', 'moe_bias', + 'class_emb', 'embedding', + 'norm_scale', + } + names, paras, para_groups = filter_params(model_wo_ddp, ndim_dict, nowd_keys=nowd_keys) + + beta1, beta2 = map(float, opt_beta.split('_')) + opt_clz = { + 'adam': partial(torch.optim.AdamW, betas=(beta1, beta2), fused=args.fuse_opt), + 'adamw': partial(torch.optim.AdamW, betas=(beta1, beta2), fused=args.fuse_opt), + 'lamb': partial(optimizer.LAMBtimm, betas=(beta1, beta2), max_grad_norm=clip), # eps=1e-7 + 'lion': partial(optimizer.Lion, betas=(beta1, beta2), max_grad_norm=clip), # eps=1e-7 + }[args.opt] + opt_kw = dict(lr=lr, weight_decay=0) + if args.oeps: opt_kw['eps'] = args.oeps + + print(f'[vlip] optim={opt_clz}, opt_kw={opt_kw}\n') + optimizers.append(AmpOptimizer(model_name, model_maybe_fsdp=None, fp16=args.fp16, bf16=args.bf16, zero=args.zero, optimizer=opt_clz(params=para_groups, **opt_kw), grad_clip=clip, n_gradient_accumulation=args.grad_accu)) + del names, paras, para_groups + vae_optim, disc_optim = optimizers[0], optimizers[1] + + vae_wo_ddp, disc_wo_ddp = args.compile_model(vae_wo_ddp, args.compile_vae), args.compile_model(disc_wo_ddp, args.compile_disc) + lpips_loss: LPIPS = args.compile_model(LPIPS(args.lpips_path).to(args.device), fast=args.compile_lpips) + + # distributed wrapper + ddp_class = DDP if dist.initialized() else NullDDP + vae: DDP = ddp_class(vae_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, static_graph=args.ddp_static, broadcast_buffers=False) + disc: DDP = ddp_class(disc_wo_ddp, device_ids=[dist.get_local_rank()], find_unused_parameters=False, static_graph=args.ddp_static, broadcast_buffers=False) + + vae_optim.model_maybe_fsdp = vae if args.zero else vae_wo_ddp + disc_optim.model_maybe_fsdp = disc if args.zero else disc_wo_ddp + + trainer = VAETrainer( + is_visualizer=dist.is_master(), + vae=vae, vae_wo_ddp=vae_wo_ddp, disc=disc, disc_wo_ddp=disc_wo_ddp, ema_ratio=args.ema, + dcrit=args.dcrit, vae_opt=vae_optim, disc_opt=disc_optim, + daug=args.disc_aug_prob, lpips_loss=lpips_loss, lp_reso=args.lpr, wei_l1=args.l1, wei_l2=args.l2, wei_entropy=args.le, wei_lpips=args.lp, wei_disc=args.ld, adapt_type=args.gada, bcr=args.bcr, bcr_cut=args.bcr_cut, reg=args.reg, reg_every=args.reg_every, + disc_grad_ckpt=args.disc_grad_ckpt, + dbg_unused=args.dbg_unused, dbg_nan=args.dbg_nan + ) + if trainer_state is not None and len(trainer_state): + trainer.load_state_dict(trainer_state, strict=False) + del vae, vae_wo_ddp, disc, disc_wo_ddp, vae_optim, disc_optim + + func = lambda x: 
os.path.basename(x) not in {'v3_008d0681123bcdf1.jpg', 'v4_00938fc5a0223cf4.jpg', 'v6_013afe5493a1a41c.jpg'} + val_imgs = list(filter(func, sorted(glob.glob(args.val_img_pattern)))) + + if args.local_debug: + inp = [] + for im in val_imgs: + im = pil_load(im, args.img_size * 2) + inp.append(val_transform(im)) + inp = torch.stack(inp, dim=0).to(args.device, non_blocking=True) + print(f'[{inp.shape=}]') + + me = misc.MetricLogger(delimiter=' ') + dbg_it = 599 + me.log_iters = {0, dbg_it} + print(f'{trainer.vae_wo_ddp.encoder.conv_in.weight.data.view(-1)[:4]=}') + args.seed_everything() + trainer.train_step( + ep=0, it=0, g_it=0, stepping=True, regularizing=False, metric_lg=me, logging_params=True, tb_lg=tb_lg, + inp=inp, + warmup_disc_schedule=0.0, fade_blur_schedule=0.8, + maybe_record_function=nullcontext, + args=args + ) + trainer.train_step( + ep=1, it=dbg_it, g_it=dbg_it, stepping=True, regularizing=True, metric_lg=me, logging_params=True, tb_lg=tb_lg, + inp=inp, + warmup_disc_schedule=0.8, fade_blur_schedule=0.0, + maybe_record_function=nullcontext, + args=args + ) + print({k: meter.global_avg for k, meter in me.meters.items()}) + + if isinstance(sys.stdout, dist.BackupStreamToFile) and isinstance(sys.stderr, dist.BackupStreamToFile): + sys.stdout.close(), sys.stderr.close() + exit(0) + + is_old_exp = False # assumption: `is_old_exp` was referenced but never defined in this diff; default to the new cache naming + vis_dir, vis_file = '_vis_cached', f'{"vae_oi1in" if is_old_exp else "vae_mine"}_8x{args.img_size}.pth' + vis_path = os.path.join(vis_dir, vis_file) + + print(f'[dld {vis_file}] before dld') + if not os.path.exists(vis_path): + if dist.is_local_master(): + misc.os_system(f'mkdir -p {vis_dir}; cp {lyoko.BNAS_DATA}/ckpt_vgpt/{vis_file} {vis_dir}/ >/dev/null 2>&1') + dist.barrier() + + print(f'[dld {vis_file}] before load') + if os.path.exists(vis_path): + inp, label = torch.load(vis_path, map_location='cpu') + inp, label = inp.to(args.device, non_blocking=True), label.to(args.device, non_blocking=True) + print(f'[dld {vis_file}] {vis_path} successfully loaded.', flush=True) + else: + print(f'[dld {vis_file}] {vis_path} not found, now create and upload.', flush=True) + inp, label = [], [] + for im in val_imgs: + im = pil_load(im, args.img_size * 2) + inp.append(val_transform(im)) + label.append(0) + inp, label = torch.stack(inp, dim=0).to(args.device, non_blocking=True), torch.tensor(label, dtype=torch.long).to(args.device, non_blocking=True) + if dist.is_master(): + torch.save([inp, label], vis_path) + misc.os_system(f'mkdir -p {lyoko.BNAS_DATA}/ckpt_vgpt; cp {vis_path} {lyoko.BNAS_DATA}/ckpt_vgpt/ >/dev/null 2>&1') + dist.barrier() + + del inp, label, val_transform + return ( + tb_lg, trainer, + start_ep, start_it, acc_str, eval_milestone, iters_train, ld_train, + ) + + +g_speed_ls = deque(maxlen=128) +def train_one_ep(ep: int, is_first_ep: bool, start_it: int, saver: CKPTSaver, args: arg_util.Args, tb_lg: misc.TensorboardLogger, ld_or_itrt, iters_train: int, trainer, logging_params_milestone): # note: CKPTSaver is referenced but never imported/defined in this diff + # import heavy packages after Dataloader object creation + from trainer import VAETrainer + from utils.lr_control import lr_wd_annealing + trainer: VAETrainer + + step_cnt = 0 + me = misc.MetricLogger(delimiter=' ') + [me.add_meter(x, misc.SmoothedValue(window_size=1, fmt='{value:.2g}')) for x in ['glr', 'dlr']] + [me.add_meter(x, misc.SmoothedValue(window_size=1, fmt='{median:.2f} ({global_avg:.2f})')) for x in ['gnm', 'dnm']] + for l in ['L1', 'NLL', 'Ld', 'Wg']: + me.add_meter(l, misc.SmoothedValue(fmt='{median:.3f} ({global_avg:.3f})')) + header = f'[Ep]: [{ep:4d}/{args.ep}]' + + touching_secs = 120 + if 
is_first_ep: + warnings.filterwarnings('ignore', category=DeprecationWarning) + warnings.filterwarnings('ignore', category=UserWarning) + g_it, wp_it, max_it = ep * iters_train, args.warmup_ep * iters_train, args.ep * iters_train + disc_start = args.disc_start_ep * iters_train + disc_wp_it, disc_max_it = args.disc_warmup_ep * iters_train, max_it - disc_start + + doing_profiling = args.prof and is_first_ep and (args.profall or dist.is_master()) + maybe_record_function = record_function if doing_profiling else nullcontext + trainer.vae_wo_ddp.maybe_record_function = maybe_record_function + if args.zero: + pref = 'hybrid' if args.hsdp else 'fsdp' + if args.buck in {'0', '0.0', '0e0', '0.0e0'}: + parallel = f'ep{ep}_{pref}{args.zero}_module_orig{args.fsdp_orig:d}' + else: + parallel = f'ep{ep}_{pref}{args.zero}_buk{args.buck}_orig{args.fsdp_orig:d}' + else: + parallel = 'ddp' + if os.getenv('NCCL_CROSS_NIC', '0') == '1': + parallel += f'_NIC1' + profiling_name = f'{args.vae}_bs{args.bs}_{parallel}_gradckpt{args.vae_grad_ckpt:d}__GPU{dist.get_rank_str_zfill()}of{dist.get_world_size()}' + + profiler = None + if doing_profiling: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=40, + warmup=3, + active=2, + repeat=1, + ), + record_shapes=True, + profile_memory=True, + with_stack=True, + on_trace_ready=TraceHandler('./', f'{profiling_name}.pt.trace.json', args.tos_profiler_file_prefix, args.bed) + ) + profiler.start() + + last_t_perf = time.perf_counter() + speed_ls: deque = g_speed_ls + FREQ = min(50, iters_train//2-1) + NVIDIA_IT_PLUS_1 = set(FREQ*i for i in (1, 2, 3, 4, 6, 8)) + PRINTABLE_IT_PLUS_1 = set(FREQ*i for i in (1, 2, 3, 4, 6, 8, 12, 16, 24, 32)) + for it, inp in me.log_every(start_it, iters_train, ld_or_itrt, max(10, iters_train // 1000), header): + if (it+1) % FREQ == 0: + speed_ls.append((time.perf_counter()-last_t_perf)/FREQ) + iter_speed = float(np.median(speed_ls)) + img_per_sec = args.bs / iter_speed + img_per_day = img_per_sec * 3600 * 24 / 1e6 + args.iter_speed, args.img_per_day = iter_speed, img_per_day + + if (it+1) in NVIDIA_IT_PLUS_1: args.max_nvidia_smi = max(args.max_nvidia_smi, max(gpu.memoryUsed for gpu in GPUtil.getGPUs()) / 1024) + mem_infos_dict = torch.cuda.memory_stats() + memory_allocated = round(mem_infos_dict['allocated_bytes.all.current']/1024**3, 2) + memory_reserved = round(mem_infos_dict['reserved_bytes.all.current']/1024**3, 2) + args.max_memory_allocated = round(mem_infos_dict['allocated_bytes.all.peak']/1024**3, 2) + args.max_memory_reserved = round(mem_infos_dict['reserved_bytes.all.peak']/1024**3, 2) + args.num_alloc_retries = mem_infos_dict['num_alloc_retries'] + if (ep <= 1 or ep == math.floor(args.disc_start_ep + 1e-4)) and (it+1) in PRINTABLE_IT_PLUS_1: + tails = list(speed_ls)[-10:] + print( + colorama.Fore.LIGHTCYAN_EX + + f"[profiling] " + f"speed: {iter_speed:.3f} ({min(tails):.3f}~{max(tails):.2f}) sec/iter | " + f"{img_per_sec:.1f} imgs/sec | " + f"{img_per_day:.2f}M imgs/day | " + f"{img_per_day*(args.img_size//trainer.vae_wo_ddp.downsample_ratio)**2/1e3:.2f}B token/day || " + f"Peak nvidia-smi: {args.max_nvidia_smi:.2f} GB || " + f"PyTorch mem - " + f"alloc: {memory_allocated:.2f} | " + f"max_alloc: {args.max_memory_allocated:.2f} | " + f"reserved: {memory_reserved:.2f} | " + f"max_reserved: {args.max_memory_reserved:.2f} | " + f"num_alloc_retries: {args.num_alloc_retries}" + colorama.Fore.RESET + 
colorama.Back.RESET + colorama.Style.RESET_ALL, + flush=True + ) + last_t_perf = time.perf_counter() + + if it < start_it: continue + if is_first_ep and it == start_it: warnings.resetwarnings() + + if doing_profiling: profiler.step() + + with maybe_record_function('before_train'): + inp = inp.to(args.device, non_blocking=True) + + g_it = ep * iters_train + it + disc_g_it = g_it - disc_start + args.cur_it = f'{it+1}/{iters_train}' + min_glr, max_glr, min_gwd, max_gwd = lr_wd_annealing(args.sche, trainer.vae_opt.optimizer, args.vae_lr, args.vae_wd, g_it, wp_it, max_it, wp0=args.wp0, wpe=args.sche_end) + if disc_g_it >= 0: + min_dlr, max_dlr, min_dwd, max_dwd = lr_wd_annealing(args.sche, trainer.disc_opt.optimizer, args.disc_lr, args.disc_wd, disc_g_it, disc_wp_it, disc_max_it, wp0=args.wp0, wpe=args.sche_end) + else: + min_dlr = max_dlr = min_dwd = max_dwd = 0 + + stepping = (g_it + 1) % args.grad_accu == 0 + step_cnt += int(stepping) + warmup_disc_schedule = 0 if disc_g_it < 0 else min(1.0, disc_g_it / disc_wp_it) + fade_blur_schedule = 0 if disc_g_it < 0 else min(1.0, disc_g_it / (disc_wp_it * 2)) + fade_blur_schedule = 1 - fade_blur_schedule + + grad_norm_g, scale_log2_g, grad_norm_d, scale_log2_d = trainer.train_step( + ep=ep, it=it, g_it=g_it, stepping=stepping, regularizing=args.reg > 0 and (g_it % args.reg_every == 0), + metric_lg=me, logging_params=stepping and step_cnt == 1 and (ep < 4 or ep in logging_params_milestone), tb_lg=tb_lg, + inp=inp, + warmup_disc_schedule=warmup_disc_schedule, + fade_blur_schedule=fade_blur_schedule, + maybe_record_function=maybe_record_function, + args=args, + ) + + with maybe_record_function('after_train'): + me.update(glr=max_glr, dlr=max_dlr) + tb_lg.set_step(step=g_it) + if tb_lg.loggable(): + if args.max_nvidia_smi > 0: + tb_lg.update(head='Profiling/speed', iter_cost=args.iter_speed, img_per_day=args.img_per_day) + tb_lg.update(head='Profiling/cuda_mem', max_nvi_smi=args.max_nvidia_smi, max_alloc=args.max_memory_allocated, max_reserve=args.max_memory_reserved, alloc_retries=args.num_alloc_retries) + + tb_lg.update(head='PT_opt_lr/lr_max', sche_glr=max_glr, sche_dlr=max_dlr) + tb_lg.update(head='PT_opt_lr/lr_min', sche_glr=min_glr, sche_dlr=min_dlr) + tb_lg.update(head='PT_opt_wd/wd_max', sche_gwd=max_gwd, sche_dwd=max_dwd) + tb_lg.update(head='PT_opt_wd/wd_min', sche_gwd=min_gwd, sche_dwd=min_dwd) + if scale_log2_g is not None: + tb_lg.update(head='PT_opt_grad/fp16', scale_log2_g=scale_log2_g, scale_log2_d=scale_log2_d) + + tb_lg.update(head='PT_opt_grad/grad', grad_norm_g=grad_norm_g, grad_norm_d=grad_norm_d) + g_ratio = 1 if grad_norm_g is None else min(1.0, args.grad_clip / (grad_norm_g + 1e-7)) + d_ratio = 1 if grad_norm_d is None else min(1.0, args.grad_clip / (grad_norm_d + 1e-7)) + tb_lg.update(head='PT_opt_lr/lr_max', actu_glr=g_ratio*max_glr, actu_dlr=d_ratio*max_dlr) + tb_lg.update(head='PT_opt_lr/lr_min', actu_glr=g_ratio*min_glr, actu_dlr=d_ratio*min_dlr) + + me.synchronize_between_processes() + return {k: meter.global_avg for k, meter in me.meters.items()}, me.iter_time.time_preds(max_it - (g_it + 1) + (args.ep - ep) * 15) # +15: other cost + + +def main_training(): + args: arg_util.Args = arg_util.init_dist_and_get_args() + if args.dbg_unused: + torch.autograd.set_detect_anomaly(True) + + ret = build_things_from_args(args) + if len(ret) < 8: + return ret + ( + tb_lg, trainer, + start_ep, start_it, acc_str, eval_milestone, iters_train, ld_train, + ) = ret + + # import heavy packages after Dataloader object creation + from trainer 
import VAETrainer + ret: Tuple[ + misc.TensorboardLogger, VAETrainer, + int, int, str, List[float], Optional[int], Optional[DataLoader], + ] # post-hoc type annotation for the tuple unpacked above (documentation only, no runtime effect) + saver = CKPTSaver(dist.is_master(), eval_milestone) # note: CKPTSaver is referenced but never imported/defined in this diff + + # train + start_time, min_Lnll, min_Ld, disc_start = time.time(), 999., 999., False + # seg8 = np.linspace(1, args.ep, 8+1, dtype=int).tolist() + seg5 = np.linspace(1, args.ep, 5+1, dtype=int).tolist() + # noinspection PyTypeChecker + logging_params_milestone: List[int] = np.linspace(1, args.ep, 10+1, dtype=int).tolist() + eval_milestone_ep = set(seg5[:]) # seg4 + vis_milestone_ep = set(seg5[:]) | set(x for x in (2, 4, 8, 16) if x <= args.ep) + for x in [6, 12, 3, 24, 18, 48, 72, 96]: + if len(vis_milestone_ep) < 10 and x <= args.ep: + vis_milestone_ep.add(x) + + # save_milestone = list(range(5, args.ep, 2)) + [args.ep - 1] + # for i, m in enumerate(save_milestone): + # if m != args.ep - 1 and m % 100 in {99, 0}: + # save_milestone[i] -= 1 + # save_milestone = set(save_milestone) + # if 0 in save_milestone: save_milestone.remove(0) + print(f'[PT milestones] eval={sorted(eval_milestone_ep)} vis={sorted(vis_milestone_ep)}') + + diff_t = torch.tensor([0.0, 0.0], dtype=torch.float32, device=args.device) + trainer.vae_opt.log_param(ep=-1, tb_lg=tb_lg) + trainer.disc_opt.log_param(ep=-1, tb_lg=tb_lg) + time.sleep(3), gc.collect(), torch.cuda.empty_cache(), time.sleep(3) + ep_lg = max(1, args.ep // 10) if args.ep <= 100 else max(1, args.ep // 20) + for ep in range(start_ep, args.ep): + if ep % ep_lg == 0 or ep == start_ep: + print(f'[PT info] this exp is from ep{start_ep} it{start_it}, acc_str: {acc_str}, diffs: {args.diffs}, ==========> bed: {args.bed} h2: {args.tb_log_dir_online} < ==========\n') + if hasattr(ld_train, 'sampler') and hasattr(ld_train.sampler, 'set_epoch'): + ld_train.sampler.set_epoch(ep) + if 0 <= ep <= 3: + print(f'[ld_train.sampler.set_epoch({ep})]') + tb_lg.set_step(ep * iters_train) + + if args.flash_attn: + sdp_kernel_select_ctx = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False) + else: + sdp_kernel_select_ctx = torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False) + if args.local_debug: + sdp_kernel_select_ctx = nullcontext() + with sdp_kernel_select_ctx: + stats, (sec, remain_time, finish_time) = train_one_ep( + ep, ep == start_ep, start_it if ep == start_ep else 0, saver, args, tb_lg, ld_train, iters_train, trainer, logging_params_milestone + ) + + Lnll, L1, Ld, wei_g = stats['NLL'], stats['L1'], stats['Ld'], stats['Wg'] + min_Lnll, min_Ld = min(min_Lnll, Lnll), min(min_Ld, min_Ld if Ld < 1e-7 else Ld) # Ld ~ 0 means the disc hasn't started; keep the previous minimum then + acc_real, acc_fake = stats.get('acc_real', -1), stats.get('acc_fake', -1) + acc_all = (acc_real + acc_fake) * 0.5 + args.last_Lnll, args.last_L1, args.last_Ld, args.last_wei_g, args.acc_all, args.acc_real, args.acc_fake = Lnll, L1, Ld, wei_g, acc_all, acc_real, acc_fake + if not math.isfinite(Lnll + Ld + L1 + wei_g): + for n, v in zip( + ('Lnll', 'Ld', 'L1', 'wei_g'), + (Lnll, Ld, L1, wei_g), + ): + if not math.isfinite(v): + # noinspection PyArgumentList + print(f'[rk{dist.get_rank():02d}] {n} is {v}, stopping training!', force=True, flush=True) + sys.exit(666) + + args.cur_phase = 'PT' + args.cur_ep = f'{ep+1}/{args.ep}' + args.remain_time, args.finish_time = remain_time, finish_time + + from torch.nn.parallel import DistributedDataParallel as DDP + if isinstance(trainer.vae, DDP): + vae_ddp_static = trainer.vae._get_ddp_logging_data().get('can_set_static_graph') + disc_ddp_static = 
trainer.disc._get_ddp_logging_data().get('can_set_static_graph') + tail = colorama.Fore.LIGHTGREEN_EX + f' | static_graph: vae={vae_ddp_static}, disc={disc_ddp_static}' + colorama.Fore.RESET + colorama.Back.RESET + colorama.Style.RESET_ALL + else: + tail = '' + if ep > args.ep // 20: + print(f' [*] [ep{ep}] Min Lnll: {min_Lnll:.3f}, Ld: {min_Ld:.3f}, Remain: {remain_time}, Finish: {finish_time}' + tail) + tb_lg.update(head='PT_y_result', step=ep+1, min_Lnll=min_Lnll, min_Ld=None if min_Ld > 200 else min_Ld) + else: + print(f' [*] [ep{ep}] Remain: {remain_time}, Finish: {finish_time}' + tail) + + disc_start = acc_all >= 0 + if disc_start: + kw = dict(L1rec=L1, Lnll=Lnll, Ld=Ld, wei_g=wei_g, acc_all=acc_all, acc_fake=acc_fake, acc_real=acc_real) + else: + kw = dict(L1rec=L1, Lnll=Lnll) + tb_lg.update(head='PT_ep_loss', step=ep+1, **kw) + tb_lg.update(head='PT_z_burnout', step=ep+1, rest_hours=round(sec / 60 / 60, 2)) + + is_val_and_also_saving = (ep + 1) % 10 == 0 or (ep + 1) == args.ep + if is_val_and_also_saving: + # stale validation log, disabled: it references names (tot, L_mean, L_tail, acc_mean, acc_tail, cost) that are never defined in this file + # print(f' [*] [ep{ep}] (val {tot}) Lm: {L_mean:.4f}, Lt: {L_tail:.4f}, Acc m&t: {acc_mean:.2f} {acc_tail:.2f}, Val cost: {cost:.2f}s') + best_updated = False # assumption: best-ckpt tracking is not part of this diff, so only ckpt-last.pth is refreshed + + if dist.is_local_master(): + local_out_ckpt = os.path.join(args.local_out_dir_path, 'ckpt-last.pth') + local_out_ckpt_best = os.path.join(args.local_out_dir_path, 'ckpt-best.pth') + print(f'[saving ckpt] ...', end='', flush=True) + torch.save({ + 'epoch': ep+1, + 'iter': 0, + 'trainer': trainer.state_dict(), + 'args': args.state_dict(), + }, local_out_ckpt) + if best_updated: + shutil.copy(local_out_ckpt, local_out_ckpt_best) + print(f' [saving ckpt](*) finished! @ {local_out_ckpt}', flush=True, clean=True) + dist.barrier() + + total_time = f'{(time.time() - start_time) / 60 / 60:.1f}h' + print('\n\n') + print(f' [*] [finished] Total Time: {total_time}, Lg: {min_Lnll:.3f}, Ld: {min_Ld:.3f}') + print('\n\n') + + del iters_train, ld_train + tb_lg.flush(); tb_lg.close() + dist.barrier() + + +class NullDDP(torch.nn.Module): + def __init__(self, module, *args, **kwargs): + super(NullDDP, self).__init__() + self.module = module + self.require_backward_grad_sync = False + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + +if __name__ == '__main__': + try: main_training() + finally: + dist.finalize() + if isinstance(sys.stdout, dist.BackupStreamToFile) and isinstance(sys.stderr, dist.BackupStreamToFile): + sys.stdout.close(), sys.stderr.close() + diff --git a/trainer.py b/trainer.py new file mode 100644 index 0000000..e283e7e --- /dev/null +++ b/trainer.py @@ -0,0 +1,339 @@ +import sys +from copy import deepcopy +from pprint import pformat +from typing import Callable, Optional, Tuple + +import seaborn as sns +import torch +import torch.nn as nn +import torch.nn.functional as F +from matplotlib.colors import ListedColormap +from torch.nn.parallel import DistributedDataParallel as DDP + +from models import VectorQuantizer, VQVAE, DinoDisc +from utils import arg_util, misc, nan +from utils.amp_opt import AmpOptimizer +from utils.diffaug import DiffAug +from utils.loss import hinge_loss, linear_loss, softplus_loss +from utils.lpips import LPIPS + +# from memory_profiler import profile + +FTen = torch.Tensor +ITen = torch.LongTensor +BTen = torch.BoolTensor + +class VAETrainer(object): + def __init__( + self, is_visualizer: bool, + vae: DDP, vae_wo_ddp: VQVAE, disc: DDP, disc_wo_ddp: DinoDisc, ema_ratio: float, # decoder, en_de_lin=True, seg_embed=False, + dcrit: str, vae_opt: AmpOptimizer, disc_opt: AmpOptimizer, + daug=1.0, 
lpips_loss: LPIPS = None, lp_reso=64, wei_l1=1.0, wei_l2=0.0, wei_entropy=0.0, wei_lpips=0.5, wei_disc=0.6, adapt_type=1, bcr=5.0, bcr_cut=0.5, reg=0.0, reg_every=16, + disc_grad_ckpt=False, + dbg_unused=False, dbg_nan=False, + ): + super(VAETrainer, self).__init__() + self.dbg_unused, self.dbg_nan = dbg_unused, dbg_nan + if self.dbg_nan: + print('[dbg_nan mode on]') + nan.debug_nan_hook(vae) + nan.debug_nan_hook(disc) + + self.vae, self.disc = vae, disc + self.vae_opt, self.disc_opt = vae_opt, disc_opt + self.vae_wo_ddp: VQVAE = vae_wo_ddp # after torch.compile + self.disc_wo_ddp: DinoDisc = disc_wo_ddp # after torch.compile + self.vae_params: Tuple[nn.Parameter] = tuple(self.vae_wo_ddp.parameters()) + self.disc_params: Tuple[nn.Parameter] = tuple(self.disc_wo_ddp.parameters()) + + self.ema_ratio = ema_ratio + self.is_visualizer = is_visualizer + self.using_ema = is_visualizer + if self.using_ema: + self.vae_ema: VQVAE = deepcopy(vae_wo_ddp).eval() + else: + self.vae_ema: VQVAE = None + + self.cmap_sim: ListedColormap = sns.color_palette('viridis', as_cmap=True) + + self.dcrit = dcrit + self.d_criterion: Callable = { # 'hg' by default + 'hg': hinge_loss, 'hinge': hinge_loss, + 'sp': softplus_loss, 'softplus': softplus_loss, + 'ln': linear_loss, 'lin': linear_loss, 'linear': linear_loss + }[dcrit] + + self.daug = DiffAug(prob=daug, cutout=0.2) + self.wei_l1, self.wei_l2, self.wei_entropy = wei_l1, wei_l2, wei_entropy + self.lpips_loss: LPIPS = lpips_loss + self.lp_reso = lp_reso + self.adapt_wei_disc = wei_disc > 0 + self.adapt_type = adapt_type + self.ema_gada: torch.Tensor = None + self.wei_lpips, self.wei_disc = wei_lpips*2, abs(wei_disc) + self.reg = 0.5 * reg * reg_every + # balanced_consistency_regularization, 10.0 is used by StyleSwin + self.bcr = bcr * 2 # LEGACY *2: in the old version, bcr MSE losses on real/fake images are calculated separately and added up; so *2 in the new version + if self.bcr > 0: + self.bcr_strong_aug = DiffAug(prob=1, cutout=bcr_cut) + self.disc_grad_ckpt = disc_grad_ckpt + + # @profile(precision=4, stream=open('trainstep.log', 'w+')) + def train_step( + self, ep: int, it: int, g_it: int, stepping: bool, regularizing: bool, metric_lg: misc.MetricLogger, logging_params: bool, tb_lg: misc.TensorboardLogger, + inp: FTen, warmup_disc_schedule: float, fade_blur_schedule: float, + maybe_record_function: Callable, + args: arg_util.Args, + ) -> Tuple[torch.Tensor, Optional[float], Optional[torch.Tensor], Optional[float]]: + if warmup_disc_schedule < 1e-6: warmup_disc_schedule = 0 + if fade_blur_schedule < 1e-6: fade_blur_schedule = 0 + loggable = (g_it == 0 or (g_it + 1) % 600 == 0) and self.is_visualizer + + # [vae loss] + with maybe_record_function('VAE_rec'): + with self.vae_opt.amp_ctx: + self.vae_wo_ddp.forward + rec_B3HW, Lq, Le, usage = self.vae(inp, ret_usages=loggable) + B = rec_B3HW.shape[0] + inp_rec_no_grad = torch.cat((inp, rec_B3HW.data), dim=0) + + Lrec = F.l1_loss(rec_B3HW, inp) + Lrec_for_log = Lrec.data.clone() + Lrec *= self.wei_l1 + if self.wei_l2 > 0: + Lrec += F.mse_loss(rec_B3HW, inp).mul_(self.wei_l2) + # if self.wei_llaplace > 0: + # inp_01_09 = inp.mul(0.4).add_(0.5) + # dist = (rec_B3HW.sigmoid() - inp_01_09.sigmoid()).abs() + # # dist /= lnb.exp().square().mul_(inp_01_09.add(inp_01_09).mul_(1-inp_01_09)).add_(1).mul_(0.5) + # dist /= inp_01_09.add(inp_01_09).mul_(1-inp_01_09).add_(1).mul_(0.5) + # Lrec += dist.mean().mul_(self.wei_llaplace) + + using_lpips = inp.shape[-2] >= self.lp_reso and self.wei_lpips > 0 + if using_lpips: + 
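# note: bare `xxx.forward` expressions (like the one below) are no-ops; they recur throughout + # this file, presumably kept as IDE jump-to-definition hints; the real call uses __call__ + 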
self.lpips_loss.forward + Lpip = self.lpips_loss(inp, rec_B3HW) + Lnll = Lrec + self.wei_lpips * Lpip + else: + Lpip = torch.tensor(0.) + Lnll = Lrec + + if warmup_disc_schedule > 0: + with maybe_record_function('VAE_disc'): + for d in self.disc_params: d.requires_grad = False + self.disc_wo_ddp.eval() + with self.disc_opt.amp_ctx: + self.disc_wo_ddp.forward + Lg = -self.disc_wo_ddp(self.daug.aug(rec_B3HW, fade_blur_schedule), grad_ckpt=False).mean() # todo: aug or not? + self.disc_wo_ddp.train() + + wei_g = warmup_disc_schedule * self.wei_disc + if self.adapt_wei_disc: + # VQGAN-style adaptive GAN weight: ratio of the grad norms of Lnll and Lg at the decoder's last layer + last_layer = self.vae_wo_ddp.decoder.conv_out.weight + w = ( + torch.autograd.grad(Lnll, last_layer, retain_graph=True)[0].data.norm() + / (torch.autograd.grad(Lg, last_layer, retain_graph=True)[0].data.norm().add_(1e-6)) + ) + if self.adapt_type % 10 == 0: + w.clamp_(0.0, 1e4) + elif self.adapt_type % 10 == 1: + w.clamp_(0.015, 1e4) + elif self.adapt_type % 10 == 2: + w.clamp_(0.1, 10) + w = min(max(w, 0.1), 10) # (redundant with the in-place clamp above) + elif self.adapt_type % 10 == 3: + w.clamp_(0.0, 1e4).sqrt_() + + if self.adapt_type >= 10: + if self.ema_gada is None: + self.ema_gada = w + else: + self.ema_gada.mul_(0.9).add_(w, alpha=0.1) + w = self.ema_gada + wei_g = wei_g * w + + Lv = Lnll + Lq + self.wei_entropy * Le + wei_g * Lg + else: + Lv = Lnll + Lq + self.wei_entropy * Le + Lg = torch.tensor(0.) + wei_g = None + + # todo: G D backward together; less calling .item() + with maybe_record_function('VAE_backward'): + grad_norm_g, scale_log2_g = self.vae_opt.backward_clip_step(stepping=stepping, loss=Lv) + + # [discriminator loss] + if warmup_disc_schedule > 0: + with maybe_record_function('Disc_forward'): + for d in self.disc_params: d.requires_grad = True + with self.disc_opt.amp_ctx: + self.disc_wo_ddp.forward + logits = self.disc(self.daug.aug(inp_rec_no_grad, fade_blur_schedule), grad_ckpt=self.disc_grad_ckpt).float() + + logits_real, logits_fake = logits[:B], logits[B:] + acc_real, acc_fake = (logits_real.data > 0).float().mean().mul_(100), (logits_fake.data < 0).float().mean().mul_(100) + + Ld = self.d_criterion(logits_real) + self.d_criterion(-logits_fake) + + if self.bcr: + with maybe_record_function('Disc_bCR'): + with self.disc_opt.amp_ctx: + self.disc_wo_ddp.forward + logits2 = self.disc(self.bcr_strong_aug.aug(inp_rec_no_grad, 0.0), grad_ckpt=self.disc_grad_ckpt).float() + Lbcr = F.mse_loss(logits2, logits).mul_(self.bcr) + Ld += Lbcr + else: + Lbcr = torch.tensor(0.) + + if regularizing: + with maybe_record_function('Disc_reg'): + self.disc_wo_ddp.eval() + with torch.cuda.amp.autocast(enabled=False): # todo: why AMP is disabled in this disc forward? (likely: R1's double backward is numerically fragile in fp16) + inp.requires_grad_(True) + self.disc_wo_ddp.forward + grad_real = torch.autograd.grad(outputs=self.disc(self.daug.aug(inp, fade_blur_schedule), grad_ckpt=False).sum(), inputs=inp, create_graph=True)[0] + Lreg = grad_real.square().flatten(1).sum(dim=1).mean() + Ld += self.reg * Lreg + Lreg = Lreg.item() + inp.requires_grad_(False) + self.disc_wo_ddp.train() + else: + Lreg = 0. + + with maybe_record_function('Disc_backward'): + grad_norm_d, scale_log2_d = self.disc_opt.backward_clip_step(stepping=stepping, loss=Ld) + Ld = Ld.data.clone() + else: + Ld = acc_real = acc_fake = grad_norm_d = scale_log2_d = None + Lbcr = torch.tensor(0.)
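+ + # recap of one train_step: (1) VAE forward with weighted L1/L2 + LPIPS + VQ/entropy losses and, once the + # disc is warmed up, a GAN term whose weight adapts to the grad-norm ratio at the decoder's last layer; + # (2) VAE backward + step; (3) disc loss (hinge by default) on real/fake logits, plus bCR and optional R1, + # then disc backward + step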
+ + # [zero_grad] + if stepping: + if self.using_ema: + with maybe_record_function('EMA_upd'): + self.ema_update(g_it) + + if self.dbg_nan: + nan.debug_nan_grad(self.vae_wo_ddp), nan.debug_nan_grad(self.disc_wo_ddp) + nan.debug_nan_param(self.vae_wo_ddp), nan.debug_nan_param(self.disc_wo_ddp) + if self.dbg_unused: + ls = [] + for n, p in self.vae_wo_ddp.named_parameters(): + if p.grad is None and n not in {'quantize.embedding.weight'}: # or tuple(p.grad.shape) == (512, 512, 1, 1): + ls.append(n) + for n, p in self.disc_wo_ddp.named_parameters(): + if p.grad is None: # or tuple(p.grad.shape) == (512, 512, 1, 1): + ls.append(n) + if len(ls): + print(f'unused param: {ls}', flush=True, file=sys.stderr) + + with maybe_record_function('opt_step'): + self.vae_opt.optimizer.zero_grad(set_to_none=True) + self.disc_opt.optimizer.zero_grad(set_to_none=True) + + with maybe_record_function('trainer_log'): + # [metric logging] + if it == 0 or it in metric_lg.log_iters: + Lpip = Lpip.item() + Lnll = Lrec_for_log + Lpip + metric_lg.update(L1=Lrec_for_log, NLL=Lnll, Ld=Ld, Wg=wei_g, acc_real=acc_real, acc_fake=acc_fake, gnm=grad_norm_g, dnm=grad_norm_d) + + # [tensorboard logging] + if loggable: + Lbcr, Lq, Le, Lg = Lbcr.item(), Lq.item(), Le if isinstance(Le, (int, float)) else Le.item(), Lg.item() + + # vae_vocab_size = self.vae_wo_ddp.vocab_size + # prob_per_class_is_chosen = idx_N.bincount() + # prob_per_class_is_chosen = F.pad(prob_per_class_is_chosen, pad=(0, vae_vocab_size-prob_per_class_is_chosen.shape[0]), mode='constant', value=0).float() / prob_per_class_is_chosen.sum() + # log_perplexity = (-(prob_per_class_is_chosen * torch.log(prob_per_class_is_chosen + 1e-10)).sum()) + # cluster_usage = (prob_per_class_is_chosen > 0.05 / vae_vocab_size).float().mean() * 100 + kw = dict( + # total=Lnll + Lq + self.wei_disc * Lg, + Nll=Lnll, RecL1=Lrec_for_log, quant=Lq, + # z_log_perplex=log_perplexity, z_voc_usage=cluster_usage + ) + kw[f'z_voc_usage'] = usage + if Le > 1e-6: kw['entropy'] = Le + if Lpip > 1e-6: kw['Lpip'] = Lpip + tb_lg.update(head='PT_iter_V_loss', step=g_it, **kw) + + if warmup_disc_schedule > 0: + kw = dict(Disc=Ld-Lbcr-Lreg, bcr=Lbcr, give_vae=Lg) + if Lreg > 1e-6: kw['regR1'] = Lreg + tb_lg.update(head='PT_iter_D_loss', step=g_it, **kw) + tb_lg.update( + head='PT_iter_pred', + logits_real=logits_real.data.mean(), logits_fake=logits_fake.data.mean(), + logits_L1dis_normed=F.l1_loss(logits_real.data, logits_fake.data).mul_(3.0178) / (logits_real.data.abs().mean() + logits_fake.data.abs().mean()), + acc_real=acc_real, acc_fake=acc_fake, step=g_it + ) + + tb_lg.update(head='PT_iter_schedule', warm_disc=warmup_disc_schedule, fade_blur=fade_blur_schedule, step=g_it) + + return grad_norm_g, scale_log2_g, grad_norm_d, scale_log2_d + + def __repr__(self): + return ( + f'\n' + f'[{type(self).__name__}.config]: {pformat(self.get_config(), indent=2, width=250)}\n' + f'[{type(self).__name__}.structure]: {super(VAETrainer, self).__repr__().replace(VAETrainer.__name__, "")}' + ) + + # p_ema = p_ema*0.9 + p*0.1 <==> p_ema.lerp_(p, 0.1) + # p_ema.mul_(self.ema_ratio).add_(p.mul(self.ema_ratio_1)) + # @profile(precision=4, stream=open('ema_update.log', 'w+')) + def ema_update(self, g_it): + ema_ratio = min(self.ema_ratio, (g_it//2 + 1) / (g_it//2 + 10)) + for p_ema, p in zip(self.vae_ema.parameters(), self.vae_wo_ddp.parameters()): + if p.requires_grad: + p_ema.data.mul_(ema_ratio).add_(p.data, alpha=1-ema_ratio) + for p_ema, p in zip(self.vae_ema.buffers(), self.vae_wo_ddp.buffers()): + 
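# buffers (e.g., BN running stats) are copied verbatim rather than EMA-averaged + 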
p_ema.data.copy_(p.data) + quant, quant_ema = self.vae_wo_ddp.quantize, self.vae_ema.quantize + quant: VectorQuantizer + if hasattr(quant, 'using_ema') and quant.using_ema: # then embedding.weight requires no grad, thus is not in self.vae_ema_params; so need to update it manually + if hasattr(quant, 'using_restart') and quant.using_restart: + # cannot use ema, cuz quantize.embedding uses replacement (rand restart) + quant_ema.embedding.weight.data.copy_(quant.embedding.weight.data) + else: + quant_ema.embedding.weight.data.mul_(ema_ratio).add_(quant.embedding.weight.data, alpha=1-ema_ratio) + + def get_config(self): + return { + 'ema_ratio': self.ema_ratio, + 'dcrit': self.dcrit, + 'wei_l1': self.wei_l1, 'wei_l2': self.wei_l2, 'wei_lpips': self.wei_lpips, 'wei_disc': self.wei_disc, + 'bcr': self.bcr, 'reg': self.reg, + } + + def state_dict(self): + state = {'config': self.get_config()} + for k in ('vae_wo_ddp', 'vae_ema', 'disc_wo_ddp', 'vae_opt', 'disc_opt'): + m = getattr(self, k) + if m is not None: + if hasattr(m, '_orig_mod'): + m = m._orig_mod + state[k] = m.state_dict() + return state + + def load_state_dict(self, state, strict=True): + for k in ('vae_wo_ddp', 'vae_ema', 'disc_wo_ddp', 'vae_opt', 'disc_opt'): + m = getattr(self, k) + if m is not None: + if hasattr(m, '_orig_mod'): + m = m._orig_mod + ret = m.load_state_dict(state[k], strict=strict) + if ret is not None: + missing, unexpected = ret + print(f'[VAETr.load_state_dict] {k} missing: {missing}') + print(f'[VAETr.load_state_dict] {k} unexpected: {unexpected}') + config: dict = state.pop('config', None) + if config is not None: + for k, v in self.get_config().items(): + if config.get(k, None) != v: + err = f'[VAETr.load_state_dict] config mismatch: this.{k}={v} (ckpt.{k}={config.get(k, None)})' + if strict: + raise AttributeError(err) + else: + print(err) diff --git a/utils/amp_opt.py b/utils/amp_opt.py new file mode 100644 index 0000000..da82701 --- /dev/null +++ b/utils/amp_opt.py @@ -0,0 +1,115 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from utils import misc +from utils.log_param import get_param_for_log + + +class NullCtx: + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class AmpOptimizer: + def __init__( + self, + model_name_3letters: str, model_maybe_fsdp: Union[torch.nn.Module, FSDP], fp16: bool, bf16: bool, zero: int, + optimizer: torch.optim.Optimizer, grad_clip: float, n_gradient_accumulation: int = 1, + ): + self.model_name_3letters = model_name_3letters + self.model_maybe_fsdp = model_maybe_fsdp + self.zero = zero + self.enable_amp = fp16 or bf16 + self.using_fp16_rather_bf16 = fp16 + + if self.enable_amp: + self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=self.zero == 0) + self.scaler = torch.cuda.amp.GradScaler(init_scale=2. 
** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler + else: + self.amp_ctx = NullCtx() + self.scaler = None + + self.optimizer = optimizer + self.grad_clip = grad_clip + self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm') + self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm') + + self.r_accu = 1 / n_gradient_accumulation # r_accu == 1.0 / n_gradient_accumulation + + def backward_clip_step( + self, stepping: bool, loss: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], Optional[float]]: + # backward + loss = loss.mul(self.r_accu) # scale by 1/n_gradient_accumulation so accumulated gradients average rather than sum + orig_norm = scaler_sc = None + if self.scaler is not None: + self.scaler.scale(loss).backward(retain_graph=False, create_graph=False) + else: + loss.backward(retain_graph=False, create_graph=False) + # print('===' * 100) + # for n, p in self.model_maybe_fsdp.named_parameters(): + # if p.stride() != p.grad.stride(): + # print(n) + # print(p.stride(), p.grad.stride()) + # print(p.shape, p.grad.shape) + # print(p.is_contiguous(), p.grad.is_contiguous()) + # print('*' * 50) + # print('===' * 100) + if stepping: + if self.scaler is not None: self.scaler.unscale_(self.optimizer) + if self.early_clipping: + if self.zero: + orig_norm: Optional[torch.Tensor] = self.model_maybe_fsdp.clip_grad_norm_(self.grad_clip) + else: + orig_norm: Optional[torch.Tensor] = torch.nn.utils.clip_grad_norm_(self.model_maybe_fsdp.parameters(), self.grad_clip) + + if self.scaler is not None: + self.scaler.step(self.optimizer) + scaler_sc: Optional[float] = self.scaler.get_scale() + if scaler_sc > 65536.: # cap the loss scale at 65536: fp16 overflows easily once the scale grows beyond that + self.scaler.update(new_scale=65536.)
+ else: + self.scaler.update() + try: + scaler_sc = float(math.log2(scaler_sc)) + except Exception as e: + print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True) + raise e + else: + self.optimizer.step() + + if self.late_clipping: + orig_norm: Optional[torch.Tensor] = self.optimizer.global_grad_norm + + self.optimizer.zero_grad(set_to_none=True) + + return orig_norm, scaler_sc + + @torch.no_grad() + def log_param(self, ep: int, tb_lg: misc.TensorboardLogger): + if self.zero == 0: + for name, values in get_param_for_log(self.model_name_3letters, self.model_maybe_fsdp.named_parameters()).items(): + values: List[float] + if len(values) == 1: # e.g., cls token will only have one value + values.append(values[0]) + tb_lg.log_tensor_as_distri(name, torch.tensor(values, dtype=torch.float32), step=ep+1) + + def state_dict(self): + return { + 'optimizer': self.optimizer.state_dict() + } if self.scaler is None else { + 'scaler': self.scaler.state_dict(), + 'optimizer': self.optimizer.state_dict() + } + + def load_state_dict(self, state, strict=True): + if self.scaler is not None: + try: self.scaler.load_state_dict(state['scaler']) + except Exception as e: print(f'[fp16 load_state_dict err] {e}') + self.optimizer.load_state_dict(state['optimizer']) diff --git a/utils/arg_util.py b/utils/arg_util.py new file mode 100644 index 0000000..7109021 --- /dev/null +++ b/utils/arg_util.py @@ -0,0 +1,355 @@ +import json +import math +import os +import os.path as osp +import random +import re +import subprocess +import sys +import time +from collections import OrderedDict +from typing import Optional, Union + +import numpy as np +import torch + +try: + from tap import Tap +except ImportError as e: + print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True) + print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True) + time.sleep(5) + raise e + +import dist + +class Args(Tap): + exp_name: str # MUST BE specified as `-`, e.g., vlip-exp1_cnn_lr1e-4 + bed: str # MUST BE specified, Bytenas Experiment Directory + resume: str = '' # if specified, load this checkpoint; if not, load the latest checkpoint in bed (if existing) + lpips_path: str = '' # lpips VGG model weights + dino_path: str = '' # vit_small_patch16_224.pth model weights + val_img_pattern: str = '' + data: str = 'o_cc' # datasets, split by - or _, o: openimages, cc: cc12m, co: coco, fa: face data(ffhq+HumanArt+afhq+Internal), mj: midjourney, p: pinterest, px: (pexels+pixabay+unsplash) + + # speed-up: torch.compile + zero: int = 0 # todo: FSDP zero 2/3 + compile_vae: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune' + compile_disc: int = 0 # torch.compile discriminator; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune' + compile_lpips: int = 0 # torch.compile LPIPS; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune' + # speed-up: ddp + ddp_static: bool = False # whether to use static graph in DDP + # speed-up: large batch + vae_grad_ckpt: bool = False # gradient checkpointing + disc_grad_ckpt: bool = False # gradient checkpointing + grad_accu: int = 1 # gradient accumulation + prof: bool = False # whether to do profile + profall: bool = False # whether to do profile on all ranks + tos_profiler_file_prefix: str = '' + + # VAE: vitamin or cnn + vae: str = 'cnn' # 's', 'b', 'l' for using 
vitamin; 'cnn', 'conv', or '' for using CNN + drop_path: float = 0.1 # following https://github.com/Beckschen/ViTamin/blob/76f1b1524ce03fcaa3449c7db678711f0961ebc2/ViTamin/open_clip/model_configs/ViTamin-L.json#L9 + # VAE: CNN encoder and CNN decoder + ch: int = 160 + drop_out: float = 0.05 + # VAE: quantization layer + vocab_size: int = 4096 + vocab_width: int = 32 + vocab_norm: bool = False + vq_beta: float = 0.25 # commitment loss weight + + # DINO discriminator + dino_depth: int = 12 # 12: use all layers + dino_kernel_size: int = 9 # 9 is stylegan-T's setting + disc_norm: str = 'sbn' # gn: group norm, bn: batch norm, sbn: sync batch norm, hbn: hybrid sync batch norm + disc_spec_norm: bool = True # whether to use SpectralNorm on Conv1ds in discriminator + disc_aug_prob: float = 1.0 # discriminator augmentation probability (see models/vae/diffaug.py) + disc_start_ep: float = 0 # start using disc loss for VAE after dep epochs; =0: will be automatically set to 0.22 * args.ep + disc_warmup_ep: float = 0 # disc loss warm up epochs; =0: will be automatically set to 0.02 * args.ep + reg: float = 0.0 # [NOT IMPLEMENTED YET] float('KEVIN_LOCAL' in os.environ) # discriminator r1 regularization (grad penalty), =10 + reg_every: int = 4 # [NOT IMPLEMENTED YET] + + # initialization + vae_init: float = -0.5 # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init) + vocab_init: float = -1 # <0: uniform(-abs(init)*base, abs(init)*base), where base = 20/vocab_size; >0: trunc_normal_(std=init) + disc_init: float = 0.02 # <0: xavier_normal_(gain=abs(init)); >0: trunc_normal_(std=init) + + # optimization + fp16: bool = False + bf16: bool = False + vae_lr: float = 3e-4 # learning rate + disc_lr: float = 3e-4 # learning rate + vae_wd: float = 0.005 # weight decay + disc_wd: float = 0.0005 # weight decay + grad_clip: float = 10 # <=0 for not using grad clip + ema: float = 0.9999 # ema ratio + + warmup_ep: float = 0 # lr warmup: epochs + wp0: float = 0.005 # lr warmup: initial lr ratio + sche: str = 'cos' # lr schedule type + sche_end: float = 0.3 # lr schedule: final lr ratio + + ep: int = 250 # epochs + lbs: int = 0 # local batch size (exclusive to --bs) if this is specified, --bs will be ignored + bs: int = 768 # global batch size (exclusive to --lbs) + + opt: str = 'adamw' # adamw, lamb, or lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work + oeps: float = 0 + fuse_opt: bool = torch.cuda.is_available() # whether to use fused optimizer + vae_opt_beta: str = '0.5_0.9' # beta1, beta2 of optimizer + disc_opt_beta: str = '0.5_0.9' # beta1, beta2 of optimizer + + # gan optimization + l1: float = 0.2 # L1 rec loss weight + l2: float = 1.0 # L2 rec loss weight + lp: float = 0.5 # lpips loss weight (WOULD BE *2 TO ADAPT LEGACY) + lpr: int = 48 # only calculate lpips >= this image resolution + ld: float = 0.4 # discriminator loss weight; if <0: NO ADAPTIVE WEIGHT + le: float = 0.0 # VQ entropy loss weight + gada: int = 1 # 0: local, 1: local+clamp(0.015, 1e4), 2: local+clamp(0.1, 10), 3: local+sqrt; 10, 11, 12, 13: with ema + bcr: float = 4. 
# balanced Consistency Regularization, used on small dataset with low reso, StyleSwin: 10.0 + bcr_cut: float = 0.2# cutout ratio (0.5: 50% width) + dcrit: str = 'hg' # hg hinge, sp softplus, ln linear + # T: g=ln, d=hg; XL: g=ln, d=hg; Swin: g=sp, d=sp + + # other hps + flash_attn: bool = True # whether to use torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False) + + # data + subset: float = 1.0 # < 1.0 for use subset + img_size: int = 256 + mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso + hflip: bool = False # augmentation: horizontal flip + workers: int = 8 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader + + # debug + local_debug: bool = 'KEVIN_LOCAL' in os.environ + dbg_unused: bool = False + dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ + + # would be automatically set in runtime + cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this] + branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this] + commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this] + commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this] + + acc_all: float = None # [automatically set; don't specify this] + acc_real: float = None # [automatically set; don't specify this] + acc_fake: float = None # [automatically set; don't specify this] + last_Lnll: float = None # [automatically set; don't specify this] + last_L1: float = None # [automatically set; don't specify this] + last_Ld: float = None # [automatically set; don't specify this] + last_wei_g: float = None# [automatically set; don't specify this] + grad_boom: str = None # [automatically set; don't specify this] + diff: float = None # [automatically set; don't specify this] + diffs: str = '' # [automatically set; don't specify this] + diffs_ema: str = None # [automatically set; don't specify this] + cur_phase: str = '' # [automatically set; don't specify this] + cur_ep: str = '' # [automatically set; don't specify this] + cur_it: str = '' # [automatically set; don't specify this] + remain_time: str = '' # [automatically set; don't specify this] + finish_time: str = '' # [automatically set; don't specify this] + + iter_speed: float = None # [automatically set; don't specify this] + img_per_day: float = None # [automatically set; don't specify this] + max_nvidia_smi: float = 0 # [automatically set; don't specify this] + max_memory_allocated: float = None # [automatically set; don't specify this] + max_memory_reserved: float = None # [automatically set; don't specify this] + num_alloc_retries: int = None # [automatically set; don't specify this] + + # environment + local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this] + tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this] + tb_log_dir_online: str = '...tb-...'# [automatically set; don't specify this] + log_txt_path: str = '...' # [automatically set; don't specify this] + last_ckpt_pth_bnas: str = '...' 
# [automatically set; don't specify this] + + tf32: bool = True # whether to use TensorFloat32 + device: str = 'cpu' # [automatically set; don't specify this] + seed: int = None # seed + deterministic: bool = False + same_seed_for_all_ranks: int = 0 # this is only for distributed sampler + def seed_everything(self): + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + if self.seed is not None: + print(f'[in seed_everything] {self.deterministic=}', flush=True) + if self.deterministic: + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8' + seed = self.seed + dist.get_rank()*16384 + os.environ['PYTHONHASHSEED'] = str(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation + if self.seed is None: return None + g = torch.Generator() + g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank()) + return g + + def compile_model(self, m, fast): + if fast == 0 or self.local_debug or not hasattr(torch, 'compile'): + return m + mode = { + 1: 'reduce-overhead', + 2: 'max-autotune', + 3: 'default', + }[fast] + print(f'[TORCH.COMPILE: {mode=}] compile {type(m)} ...', end='', flush=True) + stt = time.perf_counter() + m = torch.compile(m, mode=mode) + print(f' finished! ({time.perf_counter()-stt:.2f}s)', flush=True, clean=True) + return m + + def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]: + d = (OrderedDict if key_ordered else dict)() + # self.as_dict() would contain methods, but we only need variables + for k in self.class_variables.keys(): + if k not in {'device'}: # these are not serializable + d[k] = getattr(self, k) + return d + + def load_state_dict(self, d: Union[OrderedDict, dict, str]): + if isinstance(d, str): # for compatibility with old version + d: dict = eval('\n'.join([l for l in d.splitlines() if ' 0: + print(f'======================================================================================') + print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}') + print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================') + print(f'======================================================================================\n\n') + + # init torch distributed + from utils import misc + os.makedirs(args.local_out_dir_path, exist_ok=True) + dist.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout_minutes=30) + + # set env + args.set_tf32(args.tf32) + args.seed_everything() + args.device = dist.get_device() + + if not torch.cuda.is_available() or (not args.bf16 and not args.fp16): + args.flash_attn = False + + # update args: paths + assert args.bed + if args.exp_name not in args.bed: + args.bed = osp.join(args.bed, f'{args.exp_name}') + args.bed = args.bed.rstrip(osp.sep) + os.makedirs(args.bed, exist_ok=True) + if not args.lpips_path: + args.lpips_path = f'{lyoko.BNAS_DATA}/ckpt_vae/lpips_with_vgg.pth' + if not args.dino_path: + args.dino_path = f'{lyoko.BNAS_DATA}/ckpt_vae/vit_small_patch16_224.pth' + if not args.val_img_pattern: + args.val_img_pattern = f'{lyoko.BNAS_DATA}/ckpt_vae/val_imgs/v*' + if not args.tos_profiler_file_prefix.endswith('/'): + args.tos_profiler_file_prefix 
+= '/' + + # update args: bs, lr, wd + if args.lbs == 0: + args.lbs = max(1, round(args.bs / args.grad_accu / dist.get_world_size())) + args.bs = args.lbs * dist.get_world_size() + args.workers = min(args.workers, args.lbs) + + # args.lr = args.grad_accu * args.base_lr * args.glb_batch_size / 256 + + # update args: warmup + if args.warmup_ep == 0: + args.warmup_ep = args.ep * 0.01 + if args.disc_start_ep == 0: + args.disc_start_ep = args.ep * 0.2 + if args.disc_warmup_ep == 0: + args.disc_warmup_ep = args.ep * 0.02 + + # update args: paths + args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt') + args.last_ckpt_pth_bnas = os.path.join(args.bed, f'ckpt-last.pth') + + _reg_valid_name = re.compile(r'[^\w\-+,.]') + tb_name = _reg_valid_name.sub( + '_', + f'tb-{args.exp_name}' + f'__{args.vae}' + f'__b{args.bs}ep{args.ep}{args.opt[:4]}vlr{args.vae_lr:g}wd{args.vae_wd:g}dlr{args.disc_lr:g}wd{args.disc_wd:g}' + ) + + if dist.is_master(): + os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_dir_path, "ready-node*")}') + + args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name) + + return args diff --git a/utils/data.py b/utils/data.py new file mode 100644 index 0000000..2ba1d81 --- /dev/null +++ b/utils/data.py @@ -0,0 +1,70 @@ +import PIL.Image as PImage +from PIL import ImageFile +from torchvision.transforms import InterpolationMode, transforms + +PImage.MAX_IMAGE_PIXELS = (1024 * 1024 * 1024 // 4 // 3) * 5 +ImageFile.LOAD_TRUNCATED_IMAGES = False + + +def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1 + return x.add(x).add_(-1) + + +def pil_load(path: str, proposal_size): + with open(path, 'rb') as f: + img: PImage.Image = PImage.open(f) + w: int = img.width + h: int = img.height + sh: int = min(h, w) + if sh > proposal_size: + ratio: float = proposal_size / sh + w = round(ratio * w) + h = round(ratio * h) + img.draft('RGB', (w, h)) + img = img.convert('RGB') + return img + + +def build_dataset( + datasets_str: str, subset_ratio: float, final_reso: int, mid_reso=1.125, hflip=False, +): + # build augmentations + mid_reso = round(min(mid_reso, 2) * final_reso) # first resize to mid_reso, then crop to final_reso + train_aug, val_aug = [ + transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso + transforms.RandomCrop((final_reso, final_reso)), + transforms.ToTensor(), normalize_01_into_pm1, + ], [ + transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso + transforms.CenterCrop((final_reso, final_reso)), + transforms.ToTensor(), normalize_01_into_pm1, + ] + if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip()) + train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug) + + train_set = UnlabeledImageFolders(datasets_str=datasets_str, subset_ratio=subset_ratio, transform=train_aug) # todo: junfeng; only `train_set` required, no need to create a 'validation_set' + + # log dataset + print(f'[Dataset] {len(train_set)=}') + print_aug(train_aug, '[train]') + print_aug(val_aug, '[val]') + return train_set, val_aug + + +def pil_loader(path): + with open(path, 'rb') as f: + img: PImage.Image = PImage.open(f).convert('RGB') + return img + + +def no_transform(x): return x + + +def print_aug(transform, label): + print(f'Transform {label} = ') + if hasattr(transform, 'transforms'): + for t in transform.transforms: + print(t) + else: + 
print(transform)
+    print('---------------------------\n')
diff --git a/utils/data_sampler.py b/utils/data_sampler.py
new file mode 100644
index 0000000..a4c6abf
--- /dev/null
+++ b/utils/data_sampler.py
@@ -0,0 +1,103 @@
+import numpy as np
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EvalDistributedSampler(Sampler):
+    def __init__(self, dataset, num_replicas, rank):
+        seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
+        beg, end = seps[:-1], seps[1:]
+        beg, end = beg[rank], end[rank]
+        self.indices = tuple(range(beg, end))
+
+    def __iter__(self):
+        return iter(self.indices)
+
+    def __len__(self) -> int:
+        return len(self.indices)
+
+
+class InfiniteBatchSampler(Sampler):
+    def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
+        self.dataset_len = dataset_len
+        self.batch_size = batch_size
+        self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
+        self.max_p = self.iters_per_ep * batch_size
+        self.fill_last = fill_last
+        self.shuffle = shuffle
+        self.epoch = start_ep
+        self.same_seed_for_all_ranks = seed_for_all_rank
+        self.indices = self.gener_indices()
+        self.start_ep, self.start_it = start_ep, start_it
+
+    def gener_indices(self):
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
+            indices = torch.randperm(self.dataset_len, generator=g).numpy()
+        else:
+            indices = torch.arange(self.dataset_len).numpy()
+
+        tails = self.batch_size - (self.dataset_len % self.batch_size)
+        if tails != self.batch_size and self.fill_last:
+            tails = indices[:tails]
+            np.random.shuffle(indices)
+            indices = np.concatenate((indices, tails))
+
+        # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
+        # noinspection PyTypeChecker
+        return tuple(indices.tolist())
+
+    def __iter__(self):
+        self.epoch = self.start_ep
+        while True:
+            # on the resumed (first) epoch, skip the first `start_it` iterations;
+            # the epoch counter is advanced at the end of each pass so this check can fire
+            p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0
+            while p < self.max_p:
+                q = p + self.batch_size
+                yield self.indices[p:q]
+                p = q
+            self.epoch += 1
+            if self.shuffle:
+                self.indices = self.gener_indices()
+
+    def __len__(self):
+        return self.iters_per_ep
+
+
+class DistInfiniteBatchSampler(InfiniteBatchSampler):
+    def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
+        assert glb_batch_size % world_size == 0
+        self.world_size, self.rank = world_size, rank
+        self.dataset_len = dataset_len
+        self.glb_batch_size = glb_batch_size
+        self.batch_size = glb_batch_size // world_size
+
+        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
+        self.fill_last = fill_last
+        self.shuffle = shuffle
+        self.repeated_aug = repeated_aug
+        self.epoch = start_ep
+        self.same_seed_for_all_ranks = same_seed_for_all_ranks
+        self.indices = self.gener_indices()
+        self.start_ep, self.start_it = start_ep, start_it
+
+    def gener_indices(self):
+        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
+        # print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')
+        if self.shuffle:
+            g = torch.Generator()
+            g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
+            global_indices = torch.randperm(self.dataset_len, generator=g)
+            if self.repeated_aug > 1:
+                global_indices = 
global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p] + else: + global_indices = torch.arange(self.dataset_len) + filling = global_max_p - global_indices.shape[0] + if filling > 0 and self.fill_last: + global_indices = torch.cat((global_indices, global_indices[:filling])) + # global_indices = tuple(global_indices.numpy().tolist()) + + seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int) + local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist() + self.max_p = len(local_indices) + return local_indices diff --git a/utils/diffaug.py b/utils/diffaug.py new file mode 100644 index 0000000..eb0ef96 --- /dev/null +++ b/utils/diffaug.py @@ -0,0 +1,115 @@ +# this file is taken from https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/training/diffaug.py +import math + +import torch +import torch.nn.functional as F + + +def load_png(file_name: str): + from torchvision.io import read_image + return read_image(file_name).float().div_(255).mul_(2).sub_(1).unsqueeze(0) # to [-1, 1] +def show(tensor): # from [-1, 1] + from torchvision.utils import make_grid + from torchvision.transforms.functional import to_pil_image + if tensor.shape[0] == 1: tensor = tensor[0] + if tensor.ndim == 3: + to_pil_image(tensor.add(1).div_(2).clamp_(0, 1).detach().cpu()).convert('RGB').show() + else: + to_pil_image(make_grid(tensor.add(1).div_(2).clamp_(0, 1).detach().cpu())).convert('RGB').show() + + +class DiffAug(object): + def __init__(self, prob=1.0, cutout=0.2): # todo: swin ratio = 0.5, T&XL = 0.2 + self.grids = {} + self.prob = abs(prob) + self.using_cutout = prob > 0 + self.cutout = cutout + self.img_channels = -1 + self.last_blur_radius = -1 + self.last_blur_kernel_h = self.last_blur_kernel_w = None + + def get_grids(self, B, x, y, dev): + if (B, x, y) in self.grids: + return self.grids[(B, x, y)] + + self.grids[(B, x, y)] = ret = torch.meshgrid( + torch.arange(B, dtype=torch.long, device=dev), + torch.arange(x, dtype=torch.long, device=dev), + torch.arange(y, dtype=torch.long, device=dev), + indexing='ij' + ) + return ret + + def aug(self, BCHW: torch.Tensor, warmup_blur_schedule: float = 0) -> torch.Tensor: + # warmup blurring + if BCHW.dtype != torch.float32: + BCHW = BCHW.float() + if warmup_blur_schedule > 0: + self.img_channels = BCHW.shape[1] + sigma0 = (BCHW.shape[-2] * 0.5) ** 0.5 + sigma = sigma0 * warmup_blur_schedule + blur_radius = math.floor(sigma * 3) # 3-sigma is enough for Gaussian + if blur_radius >= 1: + if self.last_blur_radius != blur_radius: + self.last_blur_radius = blur_radius + gaussian = torch.arange(-blur_radius, blur_radius + 1, dtype=torch.float32, device=BCHW.device) + gaussian = gaussian.mul_(1/sigma).square_().neg_().exp2_() + gaussian.div_(gaussian.sum()) # normalize + self.last_blur_kernel_h = gaussian.view(1, 1, 2*blur_radius+1, 1).repeat(self.img_channels, 1, 1, 1).contiguous() + self.last_blur_kernel_w = gaussian.view(1, 1, 1, 2*blur_radius+1).repeat(self.img_channels, 1, 1, 1).contiguous() + + BCHW = F.pad(BCHW, [blur_radius, blur_radius, blur_radius, blur_radius], mode='reflect') + BCHW = F.conv2d(input=BCHW, weight=self.last_blur_kernel_h, bias=None, groups=self.img_channels) + BCHW = F.conv2d(input=BCHW, weight=self.last_blur_kernel_w, bias=None, groups=self.img_channels) + # BCHW = filter2d(BCHW, f.div_(f.sum())) # no need to specify padding (filter2d will add padding in itself based 
on filter size) + + if self.prob < 1e-6: + return BCHW + trans, color, cut = torch.rand(3) <= self.prob + trans, color, cut = trans.item(), color.item(), cut.item() + B, dev = BCHW.shape[0], BCHW.device + rand01 = torch.rand(7, B, 1, 1, device=dev) if (trans or color or cut) else None + + raw_h, raw_w = BCHW.shape[-2:] + if trans: + ratio = 0.125 + delta_h = round(raw_h * ratio) + delta_w = round(raw_w * ratio) + translation_h = rand01[0].mul(delta_h+delta_h+1).floor().long() - delta_h + translation_w = rand01[1].mul(delta_w+delta_w+1).floor().long() - delta_w + # translation_h = torch.randint(-delta_h, delta_h+1, size=(B, 1, 1), device=dev) + # translation_w = torch.randint(-delta_w, delta_w+1, size=(B, 1, 1), device=dev) + + grid_B, grid_h, grid_w = self.get_grids(B, raw_h, raw_w, dev) + grid_h = (grid_h + translation_h).add_(1).clamp_(0, raw_h+1) + grid_w = (grid_w + translation_w).add_(1).clamp_(0, raw_w+1) + bchw_pad = F.pad(BCHW, [1, 1, 1, 1, 0, 0, 0, 0]) + BCHW = bchw_pad.permute(0, 2, 3, 1).contiguous()[grid_B, grid_h, grid_w].permute(0, 3, 1, 2).contiguous() + + if color: + BCHW = BCHW.add(rand01[2].unsqueeze(-1).sub(0.5)) + # BCHW.add_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).sub_(0.5)) + bchw_mean = BCHW.mean(dim=1, keepdim=True) + BCHW = BCHW.sub(bchw_mean).mul(rand01[3].unsqueeze(-1).mul(2)).add_(bchw_mean) + # BCHW.sub_(bchw_mean).mul_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).mul_(2)).add_(bchw_mean) + bchw_mean = BCHW.mean(dim=(1, 2, 3), keepdim=True) + BCHW = BCHW.sub(bchw_mean).mul(rand01[4].unsqueeze(-1).add(0.5)).add_(bchw_mean) + # BCHW.sub_(bchw_mean).mul_(torch.rand(B, 1, 1, 1, dtype=BCHW.dtype, device=dev).add_(0.5)).add_(bchw_mean) + + if self.using_cutout and cut: + ratio = self.cutout # todo: styleswin ratio = 0.5, T&XL = 0.2 + cutout_h = round(raw_h * ratio) + cutout_w = round(raw_w * ratio) + offset_h = rand01[5].mul(raw_h + (1 - cutout_h % 2)).floor().long() + offset_w = rand01[6].mul(raw_w + (1 - cutout_w % 2)).floor().long() + # offset_h = torch.randint(0, raw_h + (1 - cutout_h % 2), size=(B, 1, 1), device=dev) + # offset_w = torch.randint(0, raw_w + (1 - cutout_w % 2), size=(B, 1, 1), device=dev) + + grid_B, grid_h, grid_w = self.get_grids(B, cutout_h, cutout_w, dev) + grid_h = (grid_h + offset_h).sub_(cutout_h // 2).clamp(min=0, max=raw_h - 1) + grid_w = (grid_w + offset_w).sub_(cutout_w // 2).clamp(min=0, max=raw_w - 1) + mask = torch.ones(B, raw_h, raw_w, dtype=BCHW.dtype, device=dev) + mask[grid_B, grid_h, grid_w] = 0 + BCHW = BCHW.mul(mask.unsqueeze(1)) + + return BCHW diff --git a/utils/log_param.py b/utils/log_param.py new file mode 100644 index 0000000..52bd925 --- /dev/null +++ b/utils/log_param.py @@ -0,0 +1,119 @@ +from collections import defaultdict +from math import log10 +from typing import Dict, List + + +def get_param_for_log(model_name_3letters: str, named_parameters) -> Dict[str, List[float]]: + dists = defaultdict(list) + + for n, p in named_parameters: + n: str + if p.grad is None: continue + post = 'B' if ('.bias' in n or '_bias' in n) else 'W' + + if 'gpt' in model_name_3letters: + if 'word' in n: tag = '0-word' + elif 'norm0_ve' in n: tag = '0-norm0_ve' + elif 'norm0_cond' in n: tag = '0-norm0_cond' + elif 'start' in n: tag, post = '1-start', 'T' + elif 'class_emb' in n: tag, post = '1-cls_emb', 'W' + elif 'cls_token' in n: tag, post = '1-cls', 'T' + elif 'cfg_uncond' in n: tag, post = '1-cond_cfg', 'T' + elif 'cond_sos' in n: tag, post = '1-cond_sos', 'W' + elif 'text_proj_for_sos' in n: tag = '1-text_sos' + 
elif 'text_proj_for_ca' in n: tag = '1-text_ca'
+
+            elif 'ca_rpb' in n: tag, post = '2-ca_rpb', 'T'
+            elif 'sa_rpb' in n: tag, post = '2-sa_rpb', 'T'
+            elif 'start_p' in n or 'pos_start' in n: tag, post = '2-pos_st', 'T'
+            elif 'abs_pos_embed' in n: tag, post = '2-pos_abs', 'T'
+            elif 'pos_mlp' in n: tag = '2-pos_mlp'
+            elif 'lvl_embed' in n: tag, post = '2-pos_lvl', 'T'
+            elif 'pos_1LC' in n: tag, post = '2-pos_1LC', 'T'
+            elif 'pos_task' in n: tag, post = '2-pos_task', 'T'
+
+            elif 'get_affine_4num' in n: tag = '1-freq_aff'
+            elif 'freq_proj' in n: tag, post = '1-freq_prj', 'W'
+            elif 'task_token' in n: tag, post = '1-task', 'T'
+            elif 'adaIN_elin' in n: tag = '4-aIN_elin'
+            elif 'shared_ada_lin' in n: tag = '2-shared_ada_lin'
+            elif 'ada_lin' in n: tag = '4-ada_lin'
+            elif 'ada_gss' in n: tag, post = '4-ada_gss', 'T'
+            elif 'ada_gamma' in n: tag, post = '4-aIN_elin', 'GA'
+            elif 'ada_beta' in n: tag, post = '4-aIN_elin', 'BE'
+            elif 'moe_bias' in n: tag, post = '4-moe_bias', 'B'
+
+            elif 'scale_mul' in n: tag, post = '3-2-scale', 'LogMul'
+            elif 'norm1' in n: tag = '3-1-norm1'
+            elif 'sa.' in n or 'attn.' in n: tag = '3-2-sa'
+            elif 'ca.' in n: tag = '3-2-ca'
+            elif 'gamma1' in n: tag, post = '3-3-gam1', 'GA'
+            elif 'ca_norm' in n: tag = '3-2-ca_norm'
+            elif 'ca_gamma' in n: tag, post = '3-3-ca_gam', 'GA'
+
+            elif 'norm2' in n: tag = '4-1-norm1'
+            elif 'ffn.' in n: tag = '4-2-ffn'
+            elif 'gamma2_last' in n: tag, post = '4-3-gam2-last', 'GA'
+            elif 'gamma2' in n: tag, post = '4-3-gam2', 'GA'
+
+            elif 'head_nm' in n: tag = '5-headnm'
+            elif 'head0' in n: tag = '5-head0'
+            elif 'head_bias' in n: tag, post = '5-head_b', 'B'
+            elif 'head' in n: tag = '5-head'
+            elif 'up' in n: tag = '5-up'
+
+            else: tag = f'___{n}___'
+
+        elif 'vae' in model_name_3letters:
+            if 'encoder.' in n or 'decoder.' in n:
+                i, j = (0, 'enc') if 'encoder.' in n else (7, 'dec')
+                if 'conv_in' in n: tag = f'{0+i}-{j}_cin'
+                elif 'down.' in n and '.block' in n: tag = f'{1+i}-{j}_res'
+                elif 'down.' in n and '.downsample' in n: tag = f'{1+i}-{j}_cdown'
+                elif 'down.' in n and '.attn' in n: tag = f'{1+i}-{j}_attn'
+                elif 'up.' in n and '.block' in n: tag = f'{1+i}-{j}_res'
+                elif 'up.' in n and '.upsample' in n: tag = f'{1+i}-{j}_cup'
+                elif 'up.' in n and '.attn' in n: tag = f'{1+i}-{j}_attn'
+                elif 'mid.' in n and '.block' in n: tag = f'{2+i}-{j}_mid_res'
+                elif 'mid.' in n and '.attn' in n: tag = f'{2+i}-{j}_mid_at'
+                elif 'norm_out' in n: tag = f'{3+i}-{j}_nout'
+                elif 'conv_out' in n: tag = f'{3+i}-{j}_cout'
+                else: tag = f'3-enc___{n}___'
+            elif 'quant_conv' in n: tag = f'4-quan_pre'
+            elif 'post_quant_conv' in n: tag = f'6-quan_post'
+            elif 'quant_proj' in n: tag = f'5-0-quan_pre_proj'
+            elif 'quant_resi' in n: tag = f'5-2-quan_post_resi'
+            elif 'post_quant_proj' in n: tag = f'5-2-quan_post_proj'
+            elif 'quant' in n and 'norm_scale' in n: tag = f'5-1-quan_norm_scale'
+            elif 'quant' in n and 'embed' in n: tag = f'5-1-quan_emb'
+            else:
+                tag = f'uk___{n}___'
+
+        elif 'disc' in model_name_3letters or 'dsc' in model_name_3letters:  # discriminator
+            if 'dwt' in n: tag = '0-dwt'
+            elif 'from' in n: tag = '0-from'
+            elif 'resi' in n: tag = '0-resi'
+            elif 'fpn' in n: tag = '1-fpn'
+            elif 'down' in n: tag = '2-down'
+            elif 'head_conv' in n: tag = '3-head_conv'
+            elif 'head_cls' in n: tag = '4-head_cls'
+            elif 'norm.' in n: tag = 'x_norm'
+            elif 'head.' 
in n: # DinoDisc + tag = n.split('heads.')[-1][0] + if p.ndim == 3: tag += '.conv1d' + else: tag += '.other' + else: # StyleGanDisc + tag = n.rsplit('.', maxsplit=1)[0] + if p.ndim == 4: tag += '.conv' + else: tag += '.other' + + else: tag = f'uk___{n}___' + + m = p.grad.norm().item() + m = log10(m) if m > 1e-9 else -10 + dists[f'Gnorm_{model_name_3letters}.{tag}.{post}'].append(m) + m = p.data.abs().mean().item() + m = log10(m) if m > 1e-9 else -10 + dists[f'Para_{model_name_3letters}.{tag}.{post}'].append(m) + + return dists diff --git a/utils/loss.py b/utils/loss.py new file mode 100644 index 0000000..0fd84e5 --- /dev/null +++ b/utils/loss.py @@ -0,0 +1,63 @@ +import torch +from torch.nn import functional as F + +# usage: like tot_loss = hinge_loss(logits_real) + hinge_loss(-logits_fake) +def hinge_loss(logits: torch.Tensor): return (1 - logits).relu().mean() +def softplus_loss(logits: torch.Tensor): return F.softplus(-logits).mean() +def linear_loss(logits: torch.Tensor): return (-logits).mean() + + +def focal_l1_loss( + pred, target, reduction='none', + alpha=0.2, gamma=1.0, activate='sigmoid', residual=False, weight=None +): + r"""Calculate Focal L1 loss. + + Delving into Deep Imbalanced Regression. In ICML, 2021. + + + Args: + pred (torch.Tensor): The prediction with shape (N, \*). + target (torch.Tensor): The regression target with shape (N, \*). + alpha (float): A balanced form for Focal Loss. Defaults to 0.2. + gamma (float): The gamma for calculating the modulating factor. + Defaults to 1.0. + activate (str): activate methods in Focal loss in {'sigmoid', 'tanh'}. + Defaults to 'sigmoid'. + residual (bool): Whether to use the original l1_loss, i.e., l1 + focal_l1. + Defaults to False. + weight (tensor): Sample-wise reweight of (N, \*) or element-wise + reweight of (1, \*). Defaults to None. + reduction (str): The method used to reduce the loss. + + Returns: + torch.Tensor: The calculated loss + """ + _loss = F.l1_loss(pred, target, reduction='none') + if activate == 'tanh': + loss = _loss * (torch.tanh(alpha * _loss)) ** gamma + else: + loss = _loss * (2. * torch.sigmoid(alpha * _loss) - 1.) 
** gamma + if residual: + loss += _loss + + if weight is not None: + loss *= weight.expand_as(loss) + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + return loss + + +if __name__ == '__main__': + import matplotlib.pyplot as plt + import torch as tc + x = tc.linspace(-1.3, 1.3, 200) + gt = tc.zeros_like(x) + l1 = F.l1_loss(x, gt, reduction='none') + l2 = F.mse_loss(x, gt, reduction='none') + fl1 = focal_l1_loss(x, gt, reduction='none') + plt.plot(x, l1, 'r', x, l2, 'g', x, fl1, 'b') + plt.show() + \ No newline at end of file diff --git a/utils/lpips.py b/utils/lpips.py new file mode 100644 index 0000000..1625d2d --- /dev/null +++ b/utils/lpips.py @@ -0,0 +1,102 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, lpips_path, use_dropout=False): # do not use dropout by default because we use .eval mode by default + super().__init__() + # build models + self.net = Vgg16(requires_grad=False) + self.lins = nn.ModuleList([NetLinLayer(c, use_dropout=use_dropout) for c in [64, 128, 256, 512, 512]]) # c: vgg16 feature dimensions + + # detach parameters & set to eval mode + for param in self.parameters(): + param.requires_grad = False + self.eval() + + # load weights + self.load_state_dict(torch.load(lpips_path, map_location='cpu'), strict=True) + + # register helper tensors + self.register_buffer('shift', torch.tensor([-.030, -.088, -.188], dtype=torch.float32).view(1, 3, 1, 1).contiguous()) + self.register_buffer('scale_inv', 1. / torch.tensor([.458, .448, .450], dtype=torch.float32).view(1, 3, 1, 1).contiguous()) + + def forward(self, inp, rec): + """ + :param inp: image for calculating LPIPS loss, [-1, 1] + :param rec: image for calculating LPIPS loss, [-1, 1] + :return: lpips loss (scalar) + """ + B = inp.shape[0] + inp_and_recs = torch.cat((inp, rec), dim=0).sub(self.shift).mul_(self.scale_inv) # first use dataset_mean,std to denormalize to [-1, 1], then use vgg_inp_mean,std to normalize again + inp_and_recs = self.net(inp_and_recs) # inp_and_recs: List[Tensor], len(inp_and_recs) == 5 + diff = 0. 
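+        # per VGG stage: features are unit-normalized along channels, squared-differenced,
+        # weighted by the learned 1x1 conv in self.lins, then averaged over space and batch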
+        for inp_and_rec, lin in zip(inp_and_recs, self.lins):
+            diff += lin.model((normalize_tensor(inp_and_rec[:B]) - normalize_tensor(inp_and_rec[B:])).square_()).mean()
+        return diff
+
+
+class NetLinLayer(nn.Module):
+    """ A single linear layer which does a 1x1 conv """
+    def __init__(self, chn_in, chn_out=1, use_dropout=False):
+        super(NetLinLayer, self).__init__()
+        layers = [nn.Dropout(), ] if use_dropout else [nn.Identity()]
+        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+        self.model = nn.Sequential(*layers)
+
+
+class Vgg16(torch.nn.Module):
+    def __init__(self, requires_grad=False):
+        super(Vgg16, self).__init__()
+        vgg_pretrained_features = models.vgg16().features
+        self.slice1 = torch.nn.Sequential(*[vgg_pretrained_features[x] for x in range(4)])
+        self.slice2 = torch.nn.Sequential(*[vgg_pretrained_features[x] for x in range(4, 9)])
+        self.slice3 = torch.nn.Sequential(*[vgg_pretrained_features[x] for x in range(9, 16)])
+        self.slice4 = torch.nn.Sequential(*[vgg_pretrained_features[x] for x in range(16, 23)])
+        self.slice5 = torch.nn.Sequential(*[vgg_pretrained_features[x] for x in range(23, 30)])
+        self.N_slices = 5
+        if not requires_grad:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        h_relu1_2 = self.slice1(x)
+        h_relu2_2 = self.slice2(h_relu1_2)
+        h_relu3_3 = self.slice3(h_relu2_2)
+        h_relu4_3 = self.slice4(h_relu3_3)
+        h_relu5_3 = self.slice5(h_relu4_3)
+        return h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3
+        # vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+        # out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+        # return out
+
+
+def normalize_tensor(x, eps=1e-10):
+    norm_factor = torch.sum(x.square(), dim=1, keepdim=True).add_(1e-9).sqrt_()
+    return x / (norm_factor + eps)
+
+
+def main():
+    l = LPIPS(r'C:\Users\16333\Desktop\PyCharm\vgpt\_vqgan\lpips_with_vgg.pth', use_dropout=False)
+    # s = l.state_dict()
+    # for k in ['data_m', 'data_s', 'vgg_inp_m', 'vgg_inp_s_inv']:
+    #     s.pop(k)
+    # torch.save(s, r'C:\Users\16333\Desktop\PyCharm\vgpt\_vqgan\lpips_with_vgg.pth')
+    x, y = torch.load(r'C:\Users\16333\Desktop\PyCharm\vgpt\_vqgan\x.pth'), torch.load(r'C:\Users\16333\Desktop\PyCharm\vgpt\_vqgan\y.pth')
+    y.requires_grad_(True)
+    loss = l(x, y)
+    print(f'loss: {loss.item():.4f}')  # forward() returns a single batch-mean scalar
+    loss.backward()
+    assert y.grad is not None
+    # (the older per-sample forward gave 0.2965 and 0.3166 for this x/y pair, B=2)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/utils/lr_control.py b/utils/lr_control.py
new file mode 100644
index 0000000..2cf41ae
--- /dev/null
+++ b/utils/lr_control.py
@@ -0,0 +1,106 @@
+import math
+from pprint import pformat
+from typing import Tuple, List, Dict, Union
+
+import torch.nn
+
+import dist
+
+
+def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
+    """Decay the learning rate with half-cycle cosine after warmup"""
+    wp_it = round(wp_it)
+
+    if cur_it < wp_it:
+        cur_lr = wp0 + (1-wp0) * cur_it / wp_it
+    else:
+        pasd = (cur_it - wp_it) / (max_it-1 - wp_it)  # [0, 1]
+        rest = 1 - pasd  # [1, 0]
+        if sche_type == 'cos':
+            cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
+        elif sche_type == 'lin':
+            T = 0.15; max_rest = 1-T
+            if pasd < T: cur_lr = 1
+            else: cur_lr = wpe + (1-wpe) * rest / max_rest  # 1 to 
wpe + elif sche_type == 'lin0': + T = 0.05; max_rest = 1-T + if pasd < T: cur_lr = 1 + else: cur_lr = wpe + (1-wpe) * rest / max_rest + elif sche_type == 'lin00': + cur_lr = wpe + (1-wpe) * rest + elif sche_type.startswith('lin'): + T = float(sche_type[3:]); max_rest = 1-T + wpe_mid = wpe + (1-wpe) * max_rest + wpe_mid = (1 + wpe_mid) / 2 + if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T + else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest + elif sche_type == 'exp': + T = 0.15; max_rest = 1-T + if pasd < T: cur_lr = 1 + else: + expo = (pasd-T) / max_rest * math.log(wpe) + cur_lr = math.exp(expo) + else: + raise NotImplementedError(f'unknown sche_type {sche_type}') + + cur_lr *= peak_lr + inf = 1e6 + min_lr, max_lr = inf, -1 + min_wd, max_wd = inf, -1 + for param_group in optimizer.param_groups: + param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned + max_lr = max(max_lr, param_group['lr']) + min_lr = min(min_lr, param_group['lr']) + + param_group['weight_decay'] = wd * param_group.get('wd_sc', 1) + max_wd = max(max_wd, param_group['weight_decay']) + if param_group['weight_decay'] > 0: + min_wd = min(min_wd, param_group['weight_decay']) + + if min_lr == inf: min_lr = -1 + if min_wd == inf: min_wd = -1 + return min_lr, max_lr, min_wd, max_wd + + +def filter_params(model, ndim_dict, nowd_keys=()) -> Tuple[ + List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]] +]: + para_groups, para_groups_dbg = {}, {} + names, paras = [], [] + names_no_grad = [] + count, numel = 0, 0 + for name, para in model.named_parameters(): + name = name.replace('_fsdp_wrapped_module.', '') + if not para.requires_grad: + names_no_grad.append(name) + continue # frozen weights + count += 1 + numel += para.numel() + names.append(name) + paras.append(para) + + if ndim_dict.get(name, 0) == 1 or name.endswith('bias') or any(k in name for k in nowd_keys): + cur_wd_sc, group_name = 0., 'ND' + else: + cur_wd_sc, group_name = 1., 'D' + + if group_name not in para_groups: + para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc} + para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc} + para_groups[group_name]['params'].append(para) + para_groups_dbg[group_name]['params'].append(name) + + for g in para_groups_dbg.values(): + g['params'] = pformat(', '.join(g['params']), width=200) + + print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n') + + for rk in range(dist.get_world_size()): + dist.barrier() + if dist.get_rank() == rk: + print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True) + print('') + + assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n' + del ndim_dict + return names, paras, list(para_groups.values()) diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..3b1ceeb --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,330 @@ +import datetime +import functools +import glob +import os +import subprocess +import sys +import threading +import time +from collections import defaultdict, deque +from typing import Iterator, List, Tuple + +import numpy as np +import pytz +import torch +import torch.distributed as tdist + +import dist +from utils import arg_util + +os_system = functools.partial(subprocess.call, shell=True) +def echo(info): + os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, 
line{sys._getframe().f_back.f_lineno})=> {info}"') +def os_system_get_stdout(cmd): + return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8') +def os_system_get_stdout_stderr(cmd): + cnt = 0 + while True: + try: + sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30) + except subprocess.TimeoutExpired: + cnt += 1 + print(f'[fetch free_port file] timeout cnt={cnt}') + else: + return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8') + + +def time_str(fmt='[%m-%d %H:%M:%S]'): + return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt) + + +class DistLogger(object): + def __init__(self, lg): + self._lg = lg + + @staticmethod + def do_nothing(*args, **kwargs): + pass + + def __getattr__(self, attr: str): + return getattr(self._lg, attr) if self._lg is not None else DistLogger.do_nothing + + +class TensorboardLogger(object): + def __init__(self, log_dir, filename_suffix): + try: import tensorflow_io as tfio + except: pass + from torch.utils.tensorboard import SummaryWriter + self.log_dir = log_dir + self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix) + self.step = 0 + + def set_step(self, step=None): + if step is not None: + self.step = step + else: + self.step += 1 + + def loggable(self): + return self.step == 0 or (self.step + 1) % 500 == 0 + + def update(self, head='scalar', step=None, **kwargs): + if step is None: + step = self.step + if not self.loggable(): return + for k, v in kwargs.items(): + if v is None: continue + if hasattr(v, 'item'): v = v.item() + self.writer.add_scalar(f'{head}/{k}', v, step) + + def log_tensor_as_distri(self, tag, tensor1d, step=None): + if step is None: + step = self.step + if not self.loggable(): return + try: + self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step) + except Exception as e: + print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}') + + def log_image(self, tag, img_chw, step=None): + if step is None: + step = self.step + if not self.loggable(): return + self.writer.add_image(tag, img_chw, step, dataformats='CHW') + + def flush(self): + self.writer.flush() + + def close(self): + print(f'[{type(self).__name__}] file @ {self.log_dir} closed') + self.writer.close() + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=30, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! 
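+        Only `count` and `total` are all-reduced across ranks, so `global_avg` becomes global,
+        while the windowed stats (`median` / `avg` / `max` / `value`) stay rank-local.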
+ """ + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + tdist.barrier() + tdist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + return np.median(self.deque) if len(self.deque) else 0 + + @property + def avg(self): + return sum(self.deque) / (len(self.deque) or 1) + + @property + def global_avg(self): + return self.total / (self.count or 1) + + @property + def max(self): + return max(self.deque) if len(self.deque) else 0 + + @property + def value(self): + return self.deque[-1] if len(self.deque) else 0 + + def time_preds(self, counts) -> Tuple[float, str, str]: + remain_secs = counts * self.median + return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs)) + + def __str__(self): + return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter=" "): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.iter_end_t = time.time() + self.log_iters = set() + + def update(self, **kwargs): + # if it != 0 and it not in self.log_iters: return + for k, v in kwargs.items(): + if v is None: continue + if hasattr(v, 'item'): v = v.item() + # assert isinstance(v, (float, int)), type(v) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + if len(meter.deque): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, start_it, max_iters, itrt, print_freq, header=None): # also solve logging & skipping iterations before start_it + self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist()) + self.log_iters.add(start_it) + if not header: + header = '' + start_time = time.time() + self.iter_end_t = time.time() + self.iter_time = SmoothedValue(fmt='{avg:.4f}') + self.data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(max_iters))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + log_msg = self.delimiter.join(log_msg) + + if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'): + for it in range(start_it, max_iters): + obj = next(itrt) + self.data_time.update(time.time() - self.iter_end_t) + yield it, obj + self.iter_time.update(time.time() - self.iter_end_t) + if it in self.log_iters: + eta_seconds = self.iter_time.global_avg * (max_iters - it) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + print(log_msg.format(it, max_iters, eta=eta_string, meters=str(self), time=str(self.iter_time), data=str(self.data_time)), flush=True) + self.iter_end_t = time.time() + else: + if isinstance(itrt, int): itrt = range(itrt) + for it, obj in enumerate(itrt): + if it < start_it: + self.iter_end_t = time.time() + continue + self.data_time.update(time.time() - self.iter_end_t) + 
yield it, obj + self.iter_time.update(time.time() - self.iter_end_t) + if it in self.log_iters: + eta_seconds = self.iter_time.global_avg * (max_iters - it) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + print(log_msg.format(it, max_iters, eta=eta_string, meters=str(self), time=str(self.iter_time), data=str(self.data_time)), flush=True) + self.iter_end_t = time.time() + + cost = time.time() - start_time + cost_str = str(datetime.timedelta(seconds=int(cost))) + print(f'{header} Cost of this ep: {cost_str} ({cost / (max_iters-start_it):.3f} s / it)', flush=True) + + +class TouchingDaemonDontForgetToStartMe(threading.Thread): + def __init__(self, files: List[str], sleep_secs: int, verbose=False): + super().__init__(daemon=True) + self.files = tuple(files) + self.sleep_secs = sleep_secs + self.is_finished = False + self.verbose = verbose + + f_back = sys._getframe().f_back + file_desc = f'{f_back.f_code.co_filename:24s}'[-24:] + self.print_prefix = f' ({file_desc}, line{f_back.f_lineno:-4d}) @daemon@ ' + + def finishing(self): + self.is_finished = True + + def run(self) -> None: + # stt, logged = time.time(), False + kw = {} + if dist.initialized(): kw['clean'] = True + + stt = time.time() + if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] start touching {self.files} per {self.sleep_secs}s ...', **kw) + while not self.is_finished: + for f in self.files: + if os.path.exists(f): + try: + os.utime(f) # todo: ByteNAS oncall: change to open(...) for force-updating mtime (use strace to ensure an `open` system call) + fp = open(f, 'a') + fp.close() + except: pass + # else: + # if not logged and self.verbose and time.time() - stt > 180: + # logged = True + # print(f'[TouchingDaemon tid={threading.get_native_id()}] [still alive ...]') + time.sleep(self.sleep_secs) + + if self.verbose: print(f'{time_str()}{self.print_prefix}[TouchingDaemon tid={threading.get_native_id()}] finish touching after {time.time()-stt:.1f} secs {self.files} per {self.sleep_secs}s. ', **kw) + + +def glob_with_latest_modified_first(pattern, recursive=False): + return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True) + + +def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]: + info = [] + file = os.path.join(args.local_out_dir_path, pattern) + all_ckpt = glob_with_latest_modified_first(file) + if len(all_ckpt) == 0: + info.append(f'[auto_resume] no ckpt found @ {file}') + info.append(f'[auto_resume quit]') + return info, 0, 0, {}, {} + else: + info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...') + ckpt = torch.load(all_ckpt[0], map_location='cpu') + ep, it = ckpt['epoch'], ckpt['iter'] + info.append(f'[auto_resume success] resume from ep{ep}, it{it}') + return info, ep, it, ckpt['trainer'], ckpt['args'] + + +def create_npz_from_sample_folder(sample_folder: str): + """ + Builds a single .npz file from a folder of .png samples. Refer to DiT. 
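+    Expects exactly 50,000 .png files; the stacked uint8 array is saved under key 'arr_0' to '<sample_folder>.npz'.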
+ """ + import os, glob + import numpy as np + from tqdm import tqdm + from PIL import Image + + samples = [] + pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG')) + assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000' + for png in tqdm(pngs, desc='Building .npz file from samples (png only)'): + with Image.open(png) as sample_pil: + sample_np = np.asarray(sample_pil).astype(np.uint8) + samples.append(sample_np) + samples = np.stack(samples) + assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3) + npz_path = f'{sample_folder}.npz' + np.savez(npz_path, arr_0=samples) + print(f'Saved .npz file to {npz_path} [shape={samples.shape}].') + return npz_path diff --git a/utils/nan.py b/utils/nan.py new file mode 100644 index 0000000..f28807a --- /dev/null +++ b/utils/nan.py @@ -0,0 +1,111 @@ +import dist +import torch +from typing import Tuple +from utils import arg_util + + +def debug_nan_grad(model): + print('[debug_nan_grad opened]') + for n, p in list(model.named_parameters()) + list(model.named_buffers()): + if p.requires_grad and p.grad is not None: + if p.grad.isnan().any(): + err = f"[rk{dist.get_rank()}] [{n} (type={type(p)}, shape={tuple(p.shape)}, num={p.grad.isnan().sum().item()}/{p.numel()})] grad has NAN" + print(err, flush=True, force=True, deeper=True) + raise AttributeError(err) + if p.grad.isinf().any(): + err = f"[rk{dist.get_rank()}] [{n} (type={type(p)}, shape={tuple(p.shape)}, num={p.grad.isinf().sum().item()}/{p.numel()})] grad has INF" + print(err, flush=True, force=True, deeper=True) + raise AttributeError(err) + + +def debug_nan_param(model): + print('[debug_nan_param opened]') + for n, p in list(model.named_parameters()) + list(model.named_buffers()): + if p.data.isnan().any(): + err = f"[rk{dist.get_rank()}] [{n} (type={type(p)}, shape={tuple(p.shape)}, num={p.isnan().sum().item()}/{p.numel()})] param has NAN" + print(err, flush=True, force=True, deeper=True) + raise AttributeError(err) + if p.data.isinf().any() and 'attn_bias' not in n and 'attn_mask' not in n: + err = f"[rk{dist.get_rank()}] [{n} (type={type(p)}, shape={tuple(p.shape)}, num={p.isinf().sum().item()}/{p.numel()})] param has INF" + print(err, flush=True, force=True, deeper=True) + raise AttributeError(err) + + +def debug_nan_hook(model): + print('[debug_nan_hook opened]') + + Tensors = Tuple[torch.Tensor] + + def pre_f_hook(module, inps: Tensors): + if not module.training: + return + if inps is not None: + for x in inps: + if isinstance(x, torch.Tensor): + d = x.data + if d.isnan().any(): + err = f"[rk{dist.get_rank()}] [module={type(module)}] [==preforward==] inps has NAN (shape={tuple(d.shape)}, num={d.isnan().sum().item()}/{d.numel()})" + print(err, flush=True, force=True, deeper=True) + raise AttributeError(err) + if d.isinf().any(): + err = f"[rk{dist.get_rank()}] [module={type(module)}] [==preforward==] inps has INF (shape={tuple(d.shape)}, num={d.isinf().sum().item()}/{d.numel()})" + print(err, flush=True, force=True, deeper=True) + raise AttributeError(err) + # return inps + + def f_hook(module, inps: Tensors, oups: Tensors): + if not module.training: + return + if oups is not None: + for x in oups: + if isinstance(x, torch.Tensor): + d = x.data + if d.isnan().any(): + err = f"[rk{dist.get_rank()}] [module={type(module)}] [==forward==] oups has NAN (shape={tuple(d.shape)}, num={d.isnan().sum().item()}/{d.numel()})" + print(err, flush=True, force=True, deeper=True) + raise 
AttributeError(err)
+                    if d.isinf().any():
+                        err = f"[rk{dist.get_rank()}] [module={type(module)}] [==forward==] oups has INF (shape={tuple(d.shape)}, num={d.isinf().sum().item()}/{d.numel()})"
+                        print(err, flush=True, force=True, deeper=True)
+                        raise AttributeError(err)
+        # return oups
+
+    def b_hook(module, g_inps: Tensors, g_oups: Tensors):
+        if not module.training:
+            return
+        if g_inps is not None:
+            for x in g_inps:
+                if isinstance(x, torch.Tensor):
+                    d = x.data
+                    if d.isnan().any():
+                        err = f"[rk{dist.get_rank()}] [module={type(module)}] [==backward==] g_inps has NAN (shape={tuple(d.shape)}, num={d.isnan().sum().item()}/{d.numel()})"
+                        print(err, flush=True, force=True, deeper=True)
+                        raise AttributeError(err)
+                    if d.isinf().any():
+                        err = f"[rk{dist.get_rank()}] [module={type(module)}] [==backward==] g_inps has INF (shape={tuple(d.shape)}, num={d.isinf().sum().item()}/{d.numel()})"
+                        print(err, flush=True, force=True, deeper=True)
+                        raise AttributeError(err)
+        if g_oups is not None:
+            for x in g_oups:
+                if isinstance(x, torch.Tensor):
+                    d = x.data
+                    if d.isnan().any():
+                        err = f"[rk{dist.get_rank()}] [module={type(module)}] [==backward==] g_oups has NAN (shape={tuple(d.shape)}, num={d.isnan().sum().item()}/{d.numel()})"
+                        print(err, flush=True, force=True, deeper=True)
+                        raise AttributeError(err)
+                    if d.isinf().any():
+                        err = f"[rk{dist.get_rank()}] [module={type(module)}] [==backward==] g_oups has INF (shape={tuple(d.shape)}, num={d.isinf().sum().item()}/{d.numel()})"
+                        print(err, flush=True, force=True, deeper=True)
+                        raise AttributeError(err)
+        # return g_inps
+
+    for n, m in model.named_modules():
+        # if not isinstance(m, (torch.nn.Linear, torch.nn.LayerNorm, torch.nn.Conv2d, torch.nn.Identity, torch.nn.ModuleList, modules.DropPath)):
+        if not isinstance(m, (torch.nn.Identity, torch.nn.ModuleList)):
+            m.register_forward_pre_hook(pre_f_hook)
+            m.register_forward_hook(f_hook)
+    # [nan]: why can't these two module types use register_backward_hook? anyway, gradients are now checked right after they are computed, so a backward_hook is no longer necessary
+    # https://www.cnblogs.com/sddai/p/14412250.html
+
+    # if not isinstance(m, (modules.AttentionBlock, modules.ImageWiseCrossAttentionBlock)):
+    #     m.register_backward_hook(b_hook)
diff --git a/utils/optimizer.py b/utils/optimizer.py
new file mode 100644
index 0000000..0c05ac0
--- /dev/null
+++ b/utils/optimizer.py
@@ -0,0 +1,207 @@
+import math
+from typing import Tuple
+
+import torch
+# from deepspeed.ops import lamb
+from torch.optim.optimizer import Optimizer
+
+
+class LAMBtimm(Optimizer):
+    """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB
+    reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py
+
+    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining parameter groups.
+        lr (float, optional): learning rate. (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its norm. (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability. (default: 1e-7)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.01)
+        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when
+            calculating running averages of gradient. 
(default: True)
+        max_grad_norm (float, optional): value used to clip global grad norm (default: 2.0)
+        trust_clip (bool): enable LAMBC trust ratio clipping (default: False)
+        always_adapt (boolean, optional): Apply adaptive learning rate to 0.0
+            weight decay parameter (default: False)
+
+    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(
+            self, params, lr=1e-3, bias_correction=True, betas=(0.9, 0.999), eps=1e-7,
+            weight_decay=0.01, grad_averaging=True, max_grad_norm=2.0, trust_clip=False, always_adapt=False):
+        defaults = dict(
+            lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay,
+            grad_averaging=grad_averaging, max_grad_norm=max_grad_norm,
+            trust_clip=trust_clip, always_adapt=always_adapt)
+        super().__init__(params, defaults)
+        self.global_grad_norm = torch.tensor(0.1)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        device = self.param_groups[0]['params'][0].device
+        one_tensor = torch.tensor(1.0, dtype=torch.float32, device=device)  # because torch.where doesn't handle scalars correctly
+        global_grad_norm = torch.full(size=(1,), fill_value=1e-12, dtype=torch.float32, device=device)
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+                global_grad_norm.add_(grad.pow(2).sum())
+
+        global_grad_norm = torch.sqrt(global_grad_norm)
+        self.global_grad_norm = global_grad_norm
+        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], dtype=torch.float32, device=device)
+        clip_global_grad_norm = 1 / torch.where(
+            global_grad_norm > max_grad_norm,
+            global_grad_norm / max_grad_norm,
+            one_tensor)
+
+        for group in self.param_groups:
+            bias_correction = 1 if group['bias_correction'] else 0
+            beta1, beta2 = group['betas']
+            grad_averaging = 1 if group['grad_averaging'] else 0
+            beta3 = 1 - beta1 if grad_averaging else 1.0
+
+            # assume same step across group now to simplify things
+            # per-parameter step can easily be supported by making it a tensor, or by passing a list into the kernel
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            if bias_correction:
+                bias_correction1 = 1 - beta1 ** group['step']
+                bias_correction2 = 1 - beta2 ** group['step']
+            else:
+                bias_correction1, bias_correction2 = 1.0, 1.0
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.mul_(clip_global_grad_norm)
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t
+
+                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                update = (exp_avg / bias_correction1).div_(denom)
+
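+                # below: weight decay is added to the normalized update (AdamW-style, not to the raw grad),
+                # then the layer-wise trust ratio ||w|| / ||update|| rescales the whole step before p.add_(update, alpha=-lr)
+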
weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are + # excluded from weight decay, unless always_adapt == True, then always enabled. + w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) + + return loss + + +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + use_triton: bool = False + ): + assert lr > 0. + assert all([0. <= beta <= 1. for beta in betas]) + + defaults = dict( + lr=lr, + betas=betas, + weight_decay=weight_decay + ) + + super().__init__(params, defaults) + + def update_fn(self, p, grad, exp_avg, lr, wd, beta1, beta2): + # stepweight decay + p.data.mul_(1 - lr * wd) + + # weight update + update = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_() + p.add_(update, alpha=-lr) + + # decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + @torch.no_grad() + def step( + self, + closure=None + ): + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in filter(lambda p: p.grad is not None, group['params']): + + grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p] + + # init state - exponential moving average of gradient values + + if len(state) == 0: + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + + self.update_fn( + p, + grad, + exp_avg, + lr, + wd, + beta1, + beta2 + ) + + return loss diff --git a/vis.py b/vis.py new file mode 100644 index 0000000..dc5edd0 --- /dev/null +++ b/vis.py @@ -0,0 +1,157 @@ +import time +import warnings +from typing import List, Tuple + +import PIL.Image as PImage, PIL.ImageDraw as PImageDraw +import numpy as np +import seaborn as sns +import torch +import torch.nn.functional as F +import torchvision +from matplotlib.colors import ListedColormap + +from dist import for_visualize +from trainer import VAETrainer +from utils import misc + + +class Visualizer(object): + def __init__(self, enable: bool, device, trainer: VAETrainer): + self.enable = enable + if enable: + self.trainer: VAETrainer + self.device, self.trainer = device, trainer + # self.data_m = torch.tensor(dataset_mean, dtype=torch.float32, device=self.device).view(1, 3, 1, 1) + # self.data_s = torch.tensor(dataset_std, dtype=torch.float32, device=self.device).view(1, 3, 1, 1) + + self.inp_B3HW: torch.Tensor = ... + self.bound_mask: torch.Tensor = ... 
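+            # `...` (Ellipsis) is just a "not set yet" placeholder (inp_B3HW is filled in by vis_prologue);
+            # of the repeated palette assignments below, the last one to each attribute is the one in effect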
+            self.cmap_div: ListedColormap = sns.color_palette('mako', as_cmap=True)
+            self.cmap_div: ListedColormap = sns.color_palette('icefire', as_cmap=True)
+            self.cmap_seq = ListedColormap(sns.color_palette('ch:start=.2, rot=-.3', as_cmap=True).colors[::-1])
+            self.cmap_seq: ListedColormap = sns.color_palette('RdBu_r', as_cmap=True)
+            self.cmap_sim: ListedColormap = sns.color_palette('viridis', as_cmap=True)
+
+    @for_visualize
+    def vis_prologue(self, inp_B3HW: torch.Tensor) -> None:
+        if not self.enable: return
+        self.inp_B3HW = inp_B3HW
+
+        # self.bound_mask = get_boundary(self.patch_size, self.vis_needs_loss_BL)
+        # todo: multi scale log
+        # imgs = {}
+        # denormed_inp = self.vgpt_wo_ddp.denormalize(self.ls_inp_B3HW)
+        # bchw = denormed_inp
+        # # mean = (self.bound_mask * denormed_inp).sum(dim=(2, 3), keepdim=True) / self.bound_mask.sum(dim=(2, 3), keepdim=True) # BC11
+        # # self.bound_mask = self.bound_mask * (1 - mean * 0.99) # BCHW
+        # # bchw = torch.where(self.bound_mask > 0, self.bound_mask, denormed_inp)
+        # chw = torchvision.utils.make_grid(bchw, padding=2, pad_value=1, nrow=10)
+        # imgs[f'1_gt'] = chw
+        # if log_inp:
+        #     tb_lg.log_image(f'1_gt', chw, step=start_ep)
+        #     tb_lg.flush()
+        # return imgs
+
+    def denormalize(self, BCHW):
+        # BCHW = BCHW * self.data_s
+        # BCHW += self.data_m
+        return BCHW.add(1).mul_(0.5).clamp_(0, 1)
+
+    @for_visualize
+    def vis(self, tb_lg: misc.TensorboardLogger, ep: int, png_path: str) -> float:
+        if not self.enable: return -1.
+        vis_stt = time.time()
+        warnings.filterwarnings('ignore', category=DeprecationWarning)
+
+        # get recon
+        B = self.inp_B3HW.shape[0]
+        with torch.inference_mode():
+            rec_B3HW_ema = self.trainer.vae_ema.img_to_reconstructed_img(self.inp_B3HW)
+            training = self.trainer.vae_wo_ddp.training
+            self.trainer.vae_wo_ddp.eval()
+            rec_B3HW = self.trainer.vae_wo_ddp.img_to_reconstructed_img(self.inp_B3HW)
+            self.trainer.vae_wo_ddp.train(training)
+
+            L1_ema = F.l1_loss(rec_B3HW_ema, self.inp_B3HW).item()
+            L1 = F.l1_loss(rec_B3HW, self.inp_B3HW).item()
+            Lpip_ema = self.trainer.lpips_loss(rec_B3HW_ema, self.inp_B3HW).item()
+            Lpip = self.trainer.lpips_loss(rec_B3HW, self.inp_B3HW).item()
+            diff_ema = (L1_ema + Lpip_ema) / 2
+            diff = (L1 + Lpip) / 2
+            ema_better = diff_ema < diff
+
+        # calc loss for logging
+        tb_lg.update(
+            head='PT_viz', step=ep+1,
+            Diff=diff, Diff_ema=diff_ema,
+            L1rec=L1, L1rec_ema=L1_ema,
+            Lpips=Lpip, Lpips_ema=Lpip_ema,
+            z_ema_adv=diff - diff_ema
+        )
+
+        # viz
+        H, W = rec_B3HW.shape[-2], rec_B3HW.shape[-1]
+        cmp_grid = torchvision.utils.make_grid(self.denormalize(torch.cat((self.inp_B3HW, rec_B3HW_ema, rec_B3HW), dim=0)), padding=0, pad_value=1, nrow=B)
+        tb_lg.log_image('Raw_RecEMA_Rec', cmp_grid, step=ep+1)
+        if png_path:
+            chw = cmp_grid.permute(1, 2, 0).mul_(255).cpu().numpy()
+            chw = PImage.fromarray(chw.astype(np.uint8))
+            if not chw.mode == 'RGB':
+                chw = chw.convert('RGB')
+            PImageDraw.Draw(chw).text((H//10, W//10), (f'EMA {ep+1}' if ema_better else f'SELF {ep+1}'), (10, 10, 10))
+            chw.save(png_path)
+
+        # dt = self.trainer.disc_wo_ddp.training
+        # self.trainer.disc_wo_ddp.eval()
+        # todo: here the disc network is guaranteed to be in a no-grad state: at the start of each iter, disc grads are first enabled, then disabled again before the iter returns; in other words, the disc only requires grad inside its own forward pass
+        # todo: vis
+        # for (inp, rec, rec2) in zip(self.ls_inp_B3HW, ls_rec_B3HW, ls_rec_BCHW2): inp.requires_grad = rec.requires_grad = rec2.requires_grad = True
+        # self.trainer.d_criterion(self.trainer.disc_wo_ddp( torch.cat(ls_inp + ls_rec1 + ls_rec2, dim=0) )).backward()
+        # self.trainer.disc_wo_ddp.train(dt)
+
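+        # (disabled) earlier visualization, kept below for reference: per-scale input/recon grids with disc-gradient heatmaps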
+ # for rec in ls_rec_B3HW: + # # if inp.grad is not None: + # # grad_i, grad_r = inp.grad.mean(dim=1), rec.grad.mean(dim=1) + # # inp.requires_grad = rec.requires_grad = False + # # del inp.grad, rec.grad + # # inp.grad = rec.grad = None + # # grad_i = grad_i.sub(grad_i.mean()).div_(grad_i.std()+1e-5).mul_(0.3).add_(0.5) + # # grad_r = grad_r.sub(grad_r.mean()).div_(grad_r.std()+1e-5).mul_(0.3).add_(0.5) + # # grad_i = torch.from_numpy(self.cmap_div(grad_i.cpu().numpy())[:, :, :, :3]).to(device=inp.device, dtype=inp.dtype).permute(0, 3, 1, 2) + # # grad_r = torch.from_numpy(self.cmap_div(grad_r.cpu().numpy())[:, :, :, :3]).to(device=inp.device, dtype=inp.dtype).permute(0, 3, 1, 2) + # # ls = [self.denormalize(inp), self.denormalize(rec), grad_i, grad_r] + # # else: + # ls = [self.denormalize(inp), self.denormalize(rec)] + # + # tb_lg.log_image(f'A_{rec.shape[-2]}', torchvision.utils.make_grid(torch.cat(ls, dim=0), padding=1, pad_value=1, nrow=B), step=ep+1) + # if png_path: pngs.append(torchvision.utils.make_grid(torch.cat(( + # F.interpolate(self.denormalize(inp), final_reso, mode='nearest'), + # F.interpolate(self.denormalize(rec), final_reso, mode='nearest'), + # ), dim=0), padding=1, pad_value=1, nrow=B)) + + # self.trainer.vae_wo_ddp.vis_key_params(tb_lg, ep) + # self.trainer.disc_wo_ddp.vis_key_params(tb_lg, ep) + + print(f' [*] [vis] {L1=:.3f}, {Lpip=:.3f} | {L1_ema=:.3f}, {Lpip_ema=:.3f} cost={time.time()-vis_stt:.2f}s', force=True) + + warnings.resetwarnings() + return min(diff, diff_ema) + + +# import numba as nb +# @nb.jit(nopython=True, nogil=True, fastmath=True) +def get_boundary(patch_size, needs_loss, boundary_wid=3): # vis_img: BCHW, needs_loss: BL + """ + get the boundary of `False`-value connected components on given boolmap `needs_loss` + """ + B, L = needs_loss.shape + hw = round(L ** 0.5) + boolmap = (~needs_loss).view(B, 1, hw, hw) # BL => B1hw + boolmap = boolmap.repeat_interleave(patch_size, dim=2).repeat_interleave(patch_size, dim=3) # B1hw => B1HW + + k_size = boundary_wid * 2 + 1 + conv_kernel = torch.ones(1, 1, k_size, k_size).to(boolmap.device) + bound_mask = F.conv2d(boolmap.float(), conv_kernel, padding=boundary_wid) + bound_mask = ((bound_mask - k_size ** 2).abs() < 0.1) ^ boolmap # B1HW + + return bound_mask.float()
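+
+
+# A minimal usage sketch for get_boundary (hypothetical shapes; nothing in this file calls it):
+#   B, hw, patch_size = 2, 16, 16                                  # a 16x16 patch grid -> a 256x256 mask
+#   needs_loss = torch.zeros(B, hw * hw, dtype=torch.bool)
+#   needs_loss[:, :hw * hw // 2] = True                            # top half "needs loss"
+#   mask = get_boundary(patch_size, needs_loss, boundary_wid=3)    # float mask of shape (2, 1, 256, 256)
+#   # mask is 1.0 on a ~boundary_wid-pixel band just inside each False-valued region, 0.0 elsewhere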