diff --git a/tests/nn/modules/test_img_patch_embed.py b/tests/nn/modules/test_img_patch_embed.py new file mode 100644 index 00000000..2f38d2d3 --- /dev/null +++ b/tests/nn/modules/test_img_patch_embed.py @@ -0,0 +1,76 @@ +# FILEPATH: /Users/defalt/Desktop/Athena/research/zeta/tests/nn/modules/test_img_patch_embed.py + +import pytest +from torch import nn +import torch +from zeta.nn.modules.img_patch_embed import ImgPatchEmbed + + +def test_class_init(): + model = ImgPatchEmbed() + + assert isinstance(model.proj, nn.Conv2d) + assert model.img_size == 224 + assert model.patch_size == 16 + assert model.num_patches == 196 + + +def test_class_init_with_args(): + model = ImgPatchEmbed( + img_size=448, patch_size=32, in_chans=1, embed_dim=512 + ) + + assert isinstance(model.proj, nn.Conv2d) + assert model.img_size == 448 + assert model.patch_size == 32 + assert model.num_patches == 196 + assert model.proj.in_channels == 1 + assert model.proj.out_channels == 512 + + +def test_forward(): + model = ImgPatchEmbed() + x = torch.randn(1, 3, 224, 224) + out = model(x) + + assert out.shape == torch.Size([1, 196, 768]) + + +def test_forward_with_different_input(): + model = ImgPatchEmbed() + x = torch.randn(2, 3, 224, 224) + out = model(x) + + assert out.shape == torch.Size([2, 196, 768]) + + +def test_forward_with_different_img_size(): + model = ImgPatchEmbed(img_size=448) + x = torch.randn(1, 3, 448, 448) + out = model(x) + + assert out.shape == torch.Size([1, 196, 768]) + + +def test_forward_with_different_patch_size(): + model = ImgPatchEmbed(patch_size=32) + x = torch.randn(1, 3, 224, 224) + out = model(x) + + assert out.shape == torch.Size([1, 49, 768]) + + +def test_forward_with_different_in_chans(): + model = ImgPatchEmbed(in_chans=1) + x = torch.randn(1, 1, 224, 224) + out = model(x) + + assert out.shape == torch.Size([1, 196, 768]) + + +def test_forward_with_different_embed_dim(): + model = ImgPatchEmbed(embed_dim=512) + x = torch.randn(1, 3, 224, 224) + out = model(x) + + assert out.shape == torch.Size([1, 196, 512]) diff --git a/tests/nn/modules/test_simple_mamba.py b/tests/nn/modules/test_simple_mamba.py index c6c90f35..bcf20cfd 100644 --- a/tests/nn/modules/test_simple_mamba.py +++ b/tests/nn/modules/test_simple_mamba.py @@ -5,6 +5,7 @@ from torch import nn from zeta.nn.modules.simple_mamba import Mamba, ResidualBlock, RMSNorm + def test_mamba_class_init(): model = Mamba(10000, 512, 6) @@ -13,6 +14,7 @@ def test_mamba_class_init(): assert isinstance(model.norm_f, RMSNorm) assert isinstance(model.lm_head, nn.Linear) + def test_mamba_forward(): model = Mamba(10000, 512, 6) x = torch.randint(0, 10000, (1, 50)) @@ -20,6 +22,7 @@ def test_mamba_forward(): assert out.shape == torch.Size([1, 50, 10000]) + def test_residual_block_class_init(): block = ResidualBlock(512) @@ -28,6 +31,7 @@ def test_residual_block_class_init(): assert isinstance(block.fc1, nn.Linear) assert isinstance(block.fc2, nn.Linear) + def test_residual_block_forward(): block = ResidualBlock(512) x = torch.randn(1, 50, 512) @@ -35,6 +39,7 @@ def test_residual_block_forward(): assert out.shape == torch.Size([1, 50, 512]) + def test_mamba_different_vocab_size(): model = Mamba(20000, 512, 6) x = torch.randint(0, 20000, (1, 50)) @@ -42,6 +47,7 @@ def test_mamba_different_vocab_size(): assert out.shape == torch.Size([1, 50, 20000]) + def test_mamba_different_dim(): model = Mamba(10000, 1024, 6) x = torch.randint(0, 10000, (1, 50)) @@ -49,6 +55,7 @@ def test_mamba_different_dim(): assert out.shape == torch.Size([1, 50, 10000]) + def test_mamba_different_depth(): model = Mamba(10000, 512, 12) x = torch.randint(0, 10000, (1, 50)) @@ -56,6 +63,7 @@ def test_mamba_different_depth(): assert out.shape == torch.Size([1, 50, 10000]) + def test_residual_block_different_dim(): block = ResidualBlock(1024) x = torch.randn(1, 50, 1024) @@ -63,6 +71,7 @@ def test_residual_block_different_dim(): assert out.shape == torch.Size([1, 50, 1024]) + def test_mamba_with_dropout(): model = Mamba(10000, 512, 6, dropout=0.5) x = torch.randint(0, 10000, (1, 50)) @@ -70,6 +79,7 @@ def test_mamba_with_dropout(): assert out.shape == torch.Size([1, 50, 10000]) + def test_residual_block_with_dropout(): block = ResidualBlock(512, dropout=0.5) x = torch.randn(1, 50, 512) @@ -77,6 +87,7 @@ def test_residual_block_with_dropout(): assert out.shape == torch.Size([1, 50, 512]) + def test_mamba_with_custom_layer(): class CustomLayer(nn.Module): def forward(self, x): @@ -86,4 +97,4 @@ def forward(self, x): x = torch.randint(0, 10000, (1, 50)) out = model(x) - assert out.shape == torch.Size([1, 50, 10000]) \ No newline at end of file + assert out.shape == torch.Size([1, 50, 10000]) diff --git a/zeta/nn/biases/relative_position_bias.py b/zeta/nn/biases/relative_position_bias.py index aae02239..d5110cb5 100644 --- a/zeta/nn/biases/relative_position_bias.py +++ b/zeta/nn/biases/relative_position_bias.py @@ -6,6 +6,7 @@ import torch from torch import nn + class RelativePositionBias(nn.Module): def __init__( self, diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py index a94e436f..3f33195e 100644 --- a/zeta/nn/modules/__init__.py +++ b/zeta/nn/modules/__init__.py @@ -46,6 +46,7 @@ from zeta.nn.modules.visual_expert import VisualExpert from zeta.nn.modules.yolo import yolo from zeta.nn.modules.swiglu import SwiGLU, SwiGLUStacked +from zeta.nn.modules.img_patch_embed import ImgPatchEmbed # from zeta.nn.modules.img_reshape import image_reshape # from zeta.nn.modules.flatten_features import flatten_features @@ -111,4 +112,5 @@ "AdaptiveLayerNorm", "SwiGLU", "SwiGLUStacked", + "ImgPatchEmbed", ] diff --git a/zeta/nn/modules/img_patch_embed.py b/zeta/nn/modules/img_patch_embed.py new file mode 100644 index 00000000..dcfd7e68 --- /dev/null +++ b/zeta/nn/modules/img_patch_embed.py @@ -0,0 +1,45 @@ +from torch import nn + + +class ImgPatchEmbed(nn.Module): + """patch embedding module + + + Args: + img_size (int, optional): image size. Defaults to 224. + patch_size (int, optional): patch size. Defaults to 16. + in_chans (int, optional): input channels. Defaults to 3. + embed_dim (int, optional): embedding dimension. Defaults to 768. + + Examples: + >>> x = torch.randn(1, 3, 224, 224) + >>> model = ImgPatchEmbed() + >>> model(x).shape + torch.Size([1, 196, 768]) + + + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + num_patches = (img_size // patch_size) * (img_size // patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + + def forward(self, x): + """Forward + + Args: + x (_type_): _description_ + + Returns: + _type_: _description_ + """ + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x diff --git a/zeta/nn/modules/simple_mamba.py b/zeta/nn/modules/simple_mamba.py index 7f0c60fc..27d21e3c 100644 --- a/zeta/nn/modules/simple_mamba.py +++ b/zeta/nn/modules/simple_mamba.py @@ -6,7 +6,6 @@ from typing import Optional, Union - # [HELPERS] ---------------------------------------------------------------------------------------- class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): @@ -57,8 +56,6 @@ def forward(self, x): return output - - class Mamba(nn.Module): def __init__( self, vocab_size: int = None, dim: int = None, depth: int = None @@ -98,7 +95,6 @@ class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ss return logits - class MambaBlock(nn.Module): def __init__( self, @@ -107,7 +103,7 @@ def __init__( depth: int, d_state: int = 16, expand: int = 2, - dt_rank: Union[int, str] = 'auto', + dt_rank: Union[int, str] = "auto", d_conv: int = 4, conv_bias: bool = True, bias: bool = False, @@ -136,7 +132,6 @@ def __init__( self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(dim_inner)) self.out_proj = nn.Linear(dim_inner, dim, bias=bias) - def forward(self, x): """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1]. @@ -260,4 +255,3 @@ def selective_scan(self, u, delta, A, B, C, D): y = y + u * rearrange(D, "d_in -> d_in 1") return y - diff --git a/zeta/utils/disable_logging.py b/zeta/utils/disable_logging.py index c4bcc12c..4e9eb8df 100644 --- a/zeta/utils/disable_logging.py +++ b/zeta/utils/disable_logging.py @@ -1,13 +1,55 @@ +# import logging +# import os +# import warnings + + +# def disable_warnings_and_logs(): +# """ +# Disables various warnings and logs. +# """ +# # disable warnings +# warnings.filterwarnings("ignore") + +# # disable tensorflow warnings +# os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +# # disable bnb warnings and others +# logging.getLogger().setLevel(logging.WARNING) + +# class CustomFilter(logging.Filter): +# def filter(self, record): +# unwanted_logs = [ +# "Setting ds_accelerator to mps (auto detect)", +# ( +# "NOTE: Redirects are currently not supported in Windows or" +# " MacOs." +# ), +# ] +# return not any(log in record.getMessage() for log in unwanted_logs) + +# # add custom filter to root logger +# logger = logging.getLogger() +# f = CustomFilter() +# logger.addFilter(f) + +# # disable specific loggers +# loggers = [ +# "real_accelerator", +# "torch.distributed.elastic.multiprocessing.redirects", +# ] + +# for logger_name in loggers: +# logger = logging.getLogger(logger_name) +# logger.setLevel(logging.CRITICAL) + + import logging import os import warnings - def disable_warnings_and_logs(): - """Disable warnings and logs. - - Returns: - _type_: _description_ + """ + Disables various warnings and logs. """ # disable warnings warnings.filterwarnings("ignore") @@ -20,12 +62,19 @@ def disable_warnings_and_logs(): class CustomFilter(logging.Filter): def filter(self, record): - msg = "Created a temporary directory at" - return msg not in record.getMessage() + unwanted_logs = [ + "Setting ds_accelerator to mps (auto detect)", + ( + "NOTE: Redirects are currently not supported in Windows or" + " MacOs." + ), + ] + return not any(log in record.getMessage() for log in unwanted_logs) + # add custom filter to root logger logger = logging.getLogger() f = CustomFilter() logger.addFilter(f) - -disable_warnings_and_logs() + # disable all loggers + logging.disable(logging.CRITICAL) \ No newline at end of file