Skip to content

Commit

Permalink
requirement
Browse files Browse the repository at this point in the history
  • Loading branch information
mk314k committed Nov 21, 2023
1 parent 7369208 commit a7bbeb0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r ../../requirements.txt
pip install pylint
- name: Analysing the code with pylint
run: |
Expand Down
19 changes: 8 additions & 11 deletions models/attention.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""_summary_
Returns:
_type_: _description_
"""
import torch
import torch.nn as nn
from fancy_einsum import einsum
from einops import rearrange


class R3DAttention(nn.Module):
Expand Down Expand Up @@ -34,9 +30,10 @@ def forward(self, x: torch.Tensor, mask = None) -> torch.Tensor:
Returns:
torch.Tensor: (batch*view, patch, embedding)
"""
qkv = self.qkv_proj(x)
b, p, _ = x.shape
qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size)
q, k, v = qkv.chunk(dim=2)
#b-batch, p-patch, c-constant(3), n-num_heads, s-head_size
q, k, v = rearrange(qkv, "b p (c n s)-> c b n p s", c=3, n=self.n_head)
attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (
self.head_size**0.5
)
Expand All @@ -51,8 +48,8 @@ def forward(self, x: torch.Tensor, mask = None) -> torch.Tensor:
return out


if __name__ == "__main__":
attn = R3DAttention(64, 4)
def count_par(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_par(attn))
# if __name__ == "__main__":
# attn = R3DAttention(64, 4)
# def count_par(model):
# return sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(count_par(attn))
9 changes: 9 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
fancy_einsum
torch
torchvision
tqdm
scikit-learn
matplotlib
opencv-python
scipy
numpy

0 comments on commit a7bbeb0

Please sign in to comment.