-
-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT][ImgPatchEmbed] [chore][disable_warnings_and_logs]
- Loading branch information
Kye
committed
Dec 21, 2023
1 parent
bbb360a
commit 6a550fc
Showing
7 changed files
with
195 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# FILEPATH: /Users/defalt/Desktop/Athena/research/zeta/tests/nn/modules/test_img_patch_embed.py | ||
|
||
import pytest | ||
from torch import nn | ||
import torch | ||
from zeta.nn.modules.img_patch_embed import ImgPatchEmbed | ||
|
||
|
||
def test_class_init():
    """A default-constructed ImgPatchEmbed exposes the default settings."""
    embedder = ImgPatchEmbed()

    # The projection layer must be a 2-D convolution.
    assert isinstance(embedder.proj, nn.Conv2d)

    # Defaults: 224px image, 16px patches -> (224 // 16) ** 2 == 196 patches.
    observed = (embedder.img_size, embedder.patch_size, embedder.num_patches)
    assert observed == (224, 16, 196)
|
||
|
||
def test_class_init_with_args():
    """Explicit constructor arguments are stored and forwarded to the conv."""
    settings = {
        "img_size": 448,
        "patch_size": 32,
        "in_chans": 1,
        "embed_dim": 512,
    }
    embedder = ImgPatchEmbed(**settings)
    projection = embedder.proj

    assert isinstance(projection, nn.Conv2d)
    assert embedder.img_size == 448
    assert embedder.patch_size == 32
    # (448 // 32) ** 2 == 14 ** 2 == 196
    assert embedder.num_patches == 196
    assert projection.in_channels == 1
    assert projection.out_channels == 512
|
||
|
||
def test_forward():
    """A single default-sized image yields 196 patch embeddings of width 768."""
    embedder = ImgPatchEmbed()
    image = torch.randn(1, 3, 224, 224)

    patches = embedder(image)

    expected_shape = torch.Size([1, 196, 768])
    assert patches.shape == expected_shape
|
||
|
||
def test_forward_with_different_input():
    """The batch dimension of the input is carried through to the output."""
    embedder = ImgPatchEmbed()
    images = torch.randn(2, 3, 224, 224)

    patches = embedder(images)

    expected_shape = torch.Size([2, 196, 768])
    assert patches.shape == expected_shape
|
||
|
||
def test_forward_with_different_img_size():
    """A larger image produces proportionally more patches.

    With the default 16px patch size, a 448x448 input is split into
    (448 // 16) ** 2 == 28 ** 2 == 784 patches. The previous expectation of
    196 patches was the count for a 224x224 image and could never pass.
    """
    model = ImgPatchEmbed(img_size=448)
    x = torch.randn(1, 3, 448, 448)
    out = model(x)

    # The constructor's bookkeeping and the actual output must agree.
    assert model.num_patches == 784
    assert out.shape == torch.Size([1, 784, 768])
|
||
|
||
def test_forward_with_different_patch_size():
    """Larger patches reduce the patch count: (224 // 32) ** 2 == 49."""
    embedder = ImgPatchEmbed(patch_size=32)
    image = torch.randn(1, 3, 224, 224)

    patches = embedder(image)

    expected_shape = torch.Size([1, 49, 768])
    assert patches.shape == expected_shape
|
||
|
||
def test_forward_with_different_in_chans():
    """A single-channel input works when in_chans is set to match."""
    embedder = ImgPatchEmbed(in_chans=1)
    image = torch.randn(1, 1, 224, 224)

    patches = embedder(image)

    expected_shape = torch.Size([1, 196, 768])
    assert patches.shape == expected_shape
|
||
|
||
def test_forward_with_different_embed_dim():
    """The last output dimension follows the requested embed_dim."""
    embedder = ImgPatchEmbed(embed_dim=512)
    image = torch.randn(1, 3, 224, 224)

    patches = embedder(image)

    expected_shape = torch.Size([1, 196, 512])
    assert patches.shape == expected_shape
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
import torch | ||
from torch import nn | ||
|
||
|
||
class RelativePositionBias(nn.Module): | ||
def __init__( | ||
self, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from torch import nn | ||
|
||
|
||
class ImgPatchEmbed(nn.Module):
    """Patch embedding module.

    Projects an image into a sequence of patch embeddings using a single
    ``nn.Conv2d`` whose kernel size and stride both equal ``patch_size``, so
    each output position is computed from one non-overlapping patch.

    Args:
        img_size (int, optional): image size. Defaults to 224.
        patch_size (int, optional): patch size. Defaults to 16.
        in_chans (int, optional): input channels. Defaults to 3.
        embed_dim (int, optional): embedding dimension. Defaults to 768.

    Attributes:
        img_size: the ``img_size`` passed to the constructor.
        patch_size: the ``patch_size`` passed to the constructor.
        num_patches: ``(img_size // patch_size) * (img_size // patch_size)``.
        proj: the ``nn.Conv2d`` projection layer.

    Examples:
        >>> x = torch.randn(1, 3, 224, 224)
        >>> model = ImgPatchEmbed()
        >>> model(x).shape
        torch.Size([1, 196, 768])
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size

        # The same side length is used for both axes of the patch grid.
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # kernel_size == stride, so the convolution windows do not overlap.
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        """Embed a batch of images as a sequence of patch vectors.

        The spatial size of ``x`` is not checked against ``img_size``; the
        number of output patches follows from the actual input size.

        Args:
            x (torch.Tensor): 4-D input of shape
                ``(batch, in_chans, height, width)``.

        Returns:
            torch.Tensor: tensor of shape ``(batch, n_patches, embed_dim)``,
            obtained by flattening the spatial dimensions of the convolution
            output and moving the channel dimension last.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        # Explicit rank check in place of unpacking four unused names; the
        # exception type matches what a failed unpack would have raised.
        if len(x.shape) != 4:
            raise ValueError(
                "ImgPatchEmbed expects a 4-D input (batch, channels, height, "
                f"width), got shape {tuple(x.shape)}"
            )
        # (B, E, H', W') -> (B, E, H'*W') -> (B, H'*W', E)
        return self.proj(x).flatten(2).transpose(1, 2)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters