Showing 6 changed files with 195 additions and 94 deletions.
"""_summary_ | ||
""" | ||
This Module includes general implementation of Multiheaded Attention | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from torch import nn | ||
from fancy_einsum import einsum | ||
|
||
|
||
class R3DAttention(nn.Module): | ||
"""_summary_ | ||
""" | ||
R3DAttention module performs multi-head attention computation on 3D data. | ||
Args: | ||
nn (_type_): _description_ | ||
hidden_size (int): The dimensionality of the input embeddings. | ||
num_heads (int): The number of attention heads to use. | ||
dropout (float, optional): Dropout rate to prevent overfitting (default: 0.1). | ||
""" | ||
|
||
def __init__(self, hidden_size: int, num_heads: int, dropout=0.1): | ||
super(R3DAttention, self).__init__() | ||
|
||
# Calculating the size of each attention head | ||
head_size = int(hidden_size / num_heads) | ||
self.n_head = num_heads | ||
self.head_size = head_size | ||
|
||
# Projection layers for Q, K, and V | ||
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size) | ||
self.output_proj = nn.Linear(hidden_size, hidden_size) | ||
|
||
# Dropout layers | ||
self.attn_dropout = nn.Dropout(dropout) | ||
self.resid_dropout = nn.Dropout(dropout) | ||
|
||
def forward(self, x: torch.Tensor, mask = None) -> torch.Tensor: | ||
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: | ||
""" | ||
Perform multi-head scaled dot-product attention on the input. | ||
Args: | ||
x (torch.Tensor): (batch*view, patch, embedding) | ||
x (torch.Tensor): Input tensor of shape (batch*view, patch, embedding). | ||
mask (torch.Tensor, optional): Mask tensor for masking attention scores (default: None). | ||
Returns: | ||
torch.Tensor: (batch*view, patch, embedding) | ||
torch.Tensor: Output tensor after attention computation (batch*view, patch, embedding). | ||
""" | ||
b, p, _ = x.shape | ||
|
||
# Projecting input into query, key, and value representations | ||
qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size) | ||
q, k, v = qkv.chunk(dim=2) | ||
#b-batch, p-patch, c-constant(3), n-num_heads, s-head_size | ||
attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / ( | ||
self.head_size**0.5 | ||
) | ||
if mask: | ||
|
||
# Calculating attention scores | ||
attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (self.head_size ** 0.5) | ||
|
||
# Applying optional mask to attention scores | ||
if mask is not None: | ||
attn_score -= mask | ||
|
||
# Computing attention probabilities and apply dropout | ||
attn_prob = attn_score.softmax(dim=-1) | ||
attn_prob = self.attn_dropout(attn_prob) | ||
|
||
# Weighted sum of values using attention probabilities | ||
z = einsum("b n pq pk, b n pk s ->b pq n s", attn_prob, v) | ||
z = z.reshape((z.shape[0], z.shape[1], -1)) | ||
|
||
# Projecting back to the original space and applying residual dropout | ||
out = self.output_proj(z) | ||
out = self.resid_dropout(out) | ||
return out | ||
|
||
|
||
# if __name__ == "__main__": | ||
# attn = R3DAttention(64, 4) | ||
# def count_par(model): | ||
# return sum(p.numel() for p in model.parameters() if p.requires_grad) | ||
# print(count_par(attn)) | ||
return out |
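For reference, a minimal usage sketch of the module above (the import path r3d_attention is a placeholder for wherever this file lives in the repository; the shapes are illustrative only):

import torch
from r3d_attention import R3DAttention  # placeholder path; adjust to the actual file name

# hidden_size must be divisible by num_heads (64 / 4 = 16-dim heads here)
attn = R3DAttention(hidden_size=64, num_heads=4, dropout=0.1)

# Input follows the documented (batch*view, patch, embedding) layout,
# e.g. 2 samples with 3 views each, 16 patches, 64-dim embeddings.
x = torch.randn(2 * 3, 16, 64)

out = attn(x)
print(out.shape)  # torch.Size([6, 16, 64]) -- same shape as the input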
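Note on masking: forward() subtracts the mask from the raw scores rather than adding a negative bias, so a mask holding 0 at visible positions and a large positive constant at blocked positions reproduces the usual additive-mask behaviour. This convention is an assumption inferred from the code, not stated in the commit. A sketch reusing attn and x from the example above:

# Hypothetical causal mask: 0 keeps a position, 1e9 suppresses it after
# the subtraction and softmax inside forward().
p = x.shape[1]
causal = torch.triu(torch.ones(p, p), diagonal=1) * 1e9  # (patch, patch)
mask = causal[None, None]                                 # broadcasts to (b, n, pq, pk)

out_masked = attn(x, mask=mask)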