
Commit

removed einsum dependency
mk314k committed Nov 25, 2023
1 parent bf26541 commit c9103bf
Showing 26 changed files with 84 additions and 732 deletions.
Binary file added .DS_Store
6 changes: 2 additions & 4 deletions README.md
@@ -8,13 +8,11 @@ This project focuses on the problem of reconstructing 3D shapes from a single 2D
## Dependencies
- Python
- Pytorch
-- einops
-- fancy_einsum
- sklearn

## Results
-<img src="test/3dcar.png" width="300" height="300"/> | <img src="test/r3dcar.png" width="300" height="300"/> | <img src="test/car.png" width="300" height="300"/>
-<img src="test/3dphone.png" width="300" height="300"/> | <img src="test/r3dphone.png" width="300" height="300"/> | <img src="test/phone.png" width="300" height="300"/>
+<img src="test/3dcar.png" width="200" height="200"/> | <img src="test/r3dcar.png" width="200" height="200"/> | <img src="test/car.png" width="200" height="200"/>
+<img src="test/3dphone.png" width="200" height="200"/> | <img src="test/r3dphone.png" width="200" height="200"/> | <img src="test/phone.png" width="200" height="200"/>

fig. 3d image (given), 3d image (generated), 2d image (given) from left to right

Binary file added data/.DS_Store
Binary file added data/shapenetcore/.DS_Store
24 changes: 13 additions & 11 deletions models/attention.py
@@ -1,14 +1,14 @@
"""
This Module includes general implementation of Multiheaded Attention
Author:mk314k
"""
import torch
from torch import nn
-from fancy_einsum import einsum


class R3DAttention(nn.Module):
"""
-R3DAttention module performs multi-head attention computation on 3D data.
+R3DAttention module performs multi-head attention computation on Image Patches.
Args:
hidden_size (int): The dimensionality of the input embeddings.
@@ -43,29 +43,31 @@ def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
Returns:
torch.Tensor: Output tensor after attention computation (batch*view, patch, embedding).
"""
-b, p, _ = x.shape
+bv, p, _ = x.shape

# Projecting input into query, key, and value representations
-qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size)
-q, k, v = qkv.chunk(dim=2)
+qkv = self.qkv_proj(x).reshape(bv, p, 3, self.n_head, self.head_size)
+qkv = qkv.permute(2,0,3,1,4) #(3, bv, n, p, h)
+q, k, v = qkv.chunk(3)

-# Calculating attention scores
-attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (self.head_size ** 0.5)
+# Calculating attention scores (bv, n, p, p)
+attn_score = q[0].matmul(k[0].transpose(-1,-2))/ (self.head_size ** 0.5)

# Applying optional mask to attention scores
if mask is not None:
attn_score -= mask

# Computing attention probabilities and apply dropout
-attn_prob = attn_score.softmax(dim=-1)
+attn_prob = attn_score.softmax(dim=-1) #(bv, n, p, p)
attn_prob = self.attn_dropout(attn_prob)

# Weighted sum of values using attention probabilities
-z = einsum("b n pq pk, b n pk s ->b pq n s", attn_prob, v)
-z = z.reshape((z.shape[0], z.shape[1], -1))
+z = attn_prob.matmul(v[0]) #(bv, n, p, h)
+z = z.permute(0, 2, 1, 3) #(bv, p, n, h)
+z = z.reshape((bv, p, -1)) #(bv, p, e)

# Projecting back to the original space and applying residual dropout
-out = self.output_proj(z)
+out = self.output_proj(z) #(bv, p, e)
out = self.resid_dropout(out)

return out
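
The hunk above replaces the two `fancy_einsum` calls with plain `matmul`/`permute`/`reshape`. As a review aid, here is a minimal standalone sketch of the equivalence; it is not part of the commit, the tensor sizes are arbitrary, and `torch.einsum` stands in for the removed `fancy_einsum` calls (same semantics, single-letter dimension names).

```python
import torch

# Illustration only: arbitrary sizes for batch*view, patches, heads, head size.
bv, p, n_head, head_size = 2, 16, 4, 8
q = torch.randn(bv, n_head, p, head_size)
k = torch.randn(bv, n_head, p, head_size)
v = torch.randn(bv, n_head, p, head_size)

# Old path: einsum over query/key patch pairs (torch.einsum standing in for fancy_einsum).
scores = torch.einsum("bnqs,bnks->bnqk", q, k) / head_size ** 0.5
z_einsum = torch.einsum("bnqk,bnks->bqns", scores.softmax(dim=-1), v).reshape(bv, p, -1)

# New path: the same computation with matmul/permute, as in the updated forward().
scores = q.matmul(k.transpose(-1, -2)) / head_size ** 0.5   # (bv, n, p, p)
z = scores.softmax(dim=-1).matmul(v)                        # (bv, n, p, h)
z_matmul = z.permute(0, 2, 1, 3).reshape(bv, p, -1)         # (bv, p, n*h)

print(torch.allclose(z_einsum, z_matmul, atol=1e-6))  # expected: True
```

Both paths produce the same attention output to floating-point tolerance, which is why the dependency could be dropped without changing model behaviour.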
32 changes: 15 additions & 17 deletions models/discriminator.py
@@ -4,14 +4,15 @@
import torch
from torch import nn

-def conv3d(channel):
-"""_summary_
+def conv3d(in_channel:int, out_channel:int):
+"""
+Conv3d structure with default hyperparameters
Args:
-channel (_type_): _description_
+in_channel (int): channel of image before this layer
+out_channel (int): channel of image after this layer
Returns:
-_type_: _description_
+nn.Module: Conv3d Module
"""
return nn.Sequential(
nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False),
@@ -21,32 +22,29 @@ def conv3d(channel):


class R3Discriminator(nn.Module):
-"""_summary_
-Args:
-nn (_type_): _description_
+"""
+Discriminator Module
"""
def __init__(self):
super().__init__()
self.out_channels = 512
self.out_dim = 4
-self.conv1 = conv3d((1, 64))
-self.conv2 = conv3d((64, 128))
-self.conv3 = conv3d((128, 256))
-self.conv4 = conv3d((256, 512))
+self.conv1 = conv3d(1, 64)
+self.conv2 = conv3d(64, 128)
+self.conv3 = conv3d(128, 256)
+self.conv4 = conv3d(256, 512)
self.out = nn.Sequential(
nn.Linear(512 * self.out_dim * self.out_dim * self.out_dim, 1),
nn.Sigmoid(),
)

def forward(self, x:torch.Tensor)->torch.Tensor:
-"""_summary_
+"""
Args:
-x (torch.Tensor): _description_
+x (torch.Tensor): 3d image
Returns:
-torch.Tensor: _description_
+torch.Tensor: real/fake value of image
"""
x = self.conv1(x)
x = self.conv2(x)
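
The diff above truncates both the tail of the `conv3d` helper and the rest of `forward`, so the following is only a hedged sketch of how the changed pieces fit together: the `conv3d_block` name and the `BatchNorm3d`/`LeakyReLU` layers are assumptions rather than lines from the commit, and the 64³ voxel input is inferred from `out_dim = 4` after four stride-2 convolutions.

```python
import torch
from torch import nn

def conv3d_block(in_channel: int, out_channel: int) -> nn.Module:
    # Mirrors the updated conv3d signature; the layers after nn.Conv3d are assumed,
    # since the diff cuts off inside nn.Sequential.
    return nn.Sequential(
        nn.Conv3d(in_channel, out_channel, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm3d(out_channel),      # assumption
        nn.LeakyReLU(0.2, inplace=True),  # assumption
    )

# Kernel 4, stride 2, padding 1 halves each spatial dimension,
# so out_dim = 4 after four blocks implies a 64**3 input volume.
blocks = nn.Sequential(
    conv3d_block(1, 64),
    conv3d_block(64, 128),
    conv3d_block(128, 256),
    conv3d_block(256, 512),
)
head = nn.Sequential(nn.Linear(512 * 4 * 4 * 4, 1), nn.Sigmoid())

x = torch.randn(8, 1, 64, 64, 64)          # (batch, channel, depth, height, width)
feat = blocks(x)                           # (8, 512, 4, 4, 4)
score = head(feat.flatten(start_dim=1))    # (8, 1) real/fake probability
print(feat.shape, score.shape)
```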
241 changes: 0 additions & 241 deletions renv/bin/Activate.ps1

This file was deleted.
