Merge pull request #1 from mk314k/development
Development
Showing 28 changed files with 1,035 additions and 151 deletions.
@@ -1,38 +1,71 @@
""" | ||
This Module includes general implementation of Multiheaded Attention | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from torch import nn | ||
from fancy_einsum import einsum | ||
from einops import rearrange | ||
|
||
|
||
class R3DAttention(nn.Module): | ||
""" | ||
R3DAttention module performs multi-head attention computation on 3D data. | ||
Args: | ||
hidden_size (int): The dimensionality of the input embeddings. | ||
num_heads (int): The number of attention heads to use. | ||
dropout (float, optional): Dropout rate to prevent overfitting (default: 0.1). | ||
""" | ||
|
||
def __init__(self, hidden_size: int, num_heads: int, dropout=0.1): | ||
super(R3DAttention, self).__init__() | ||
self.hidden_size = hidden_size | ||
assert hidden_size % num_heads == 0 | ||
self.num_heads = num_heads | ||
head_size = hidden_size // num_heads | ||
super().__init__() | ||
|
||
# Calculating the size of each attention head | ||
head_size = int(hidden_size / num_heads) | ||
self.n_head = num_heads | ||
self.head_size = head_size | ||
self.qkv_proj = nn.Linear(hidden_size,3*hidden_size) | ||
|
||
# Projection layers for Q, K, and V | ||
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size) | ||
self.output_proj = nn.Linear(hidden_size, hidden_size) | ||
|
||
# Dropout layers | ||
self.attn_dropout = nn.Dropout(dropout) | ||
self.resid_dropout = nn.Dropout(dropout) | ||
|
||
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor: | ||
""" | ||
Perform multi-head scaled dot-product attention on the input. | ||
Args: | ||
x (torch.Tensor): Input tensor of shape (batch*view, patch, embedding). | ||
mask (torch.Tensor, optional): Mask tensor for masking attention scores (default: None). | ||
def forward(self, x: torch.Tensor, cache = None) -> torch.Tensor: | ||
xshap = x.shape | ||
if len(xshap)==3: | ||
x = x.reshape((1,*x.shape)) | ||
b, c, w, h = x.shape | ||
x = x.permute((0,2,3,1)).reshape((b, -1, c)) | ||
qkv = self.qkv_proj(x) | ||
n_shape = qkv.shape | ||
q,k,v = rearrange(qkv,"b s (c n h)-> c b n s h", c=3, n=self.num_heads) | ||
attn_score = einsum("b n sq h, b n sk h->b n sq sk",q,k)/(self.head_size**0.5) | ||
mask = 1e4*torch.triu(torch.ones_like(attn_score[0,0]),diagonal=1) | ||
attn_prob = torch.softmax(attn_score-mask,dim=-1,dtype=x.dtype)#type: ignore | ||
Returns: | ||
torch.Tensor: Output tensor after attention computation (batch*view, patch, embedding). | ||
""" | ||
b, p, _ = x.shape | ||
|
||
# Projecting input into query, key, and value representations | ||
qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size) | ||
q, k, v = qkv.chunk(dim=2) | ||
|
||
# Calculating attention scores | ||
attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (self.head_size ** 0.5) | ||
|
||
# Applying optional mask to attention scores | ||
if mask is not None: | ||
attn_score -= mask | ||
|
||
# Computing attention probabilities and apply dropout | ||
attn_prob = attn_score.softmax(dim=-1) | ||
attn_prob = self.attn_dropout(attn_prob) | ||
z = einsum("b n sq sk, b n sk h ->b sq n h",attn_prob,v) | ||
z = torch.reshape(z,(z.shape[0],z.shape[1],-1)) | ||
|
||
# Weighted sum of values using attention probabilities | ||
z = einsum("b n pq pk, b n pk s ->b pq n s", attn_prob, v) | ||
z = z.reshape((z.shape[0], z.shape[1], -1)) | ||
|
||
# Projecting back to the original space and applying residual dropout | ||
out = self.output_proj(z) | ||
out = self.resid_dropout(out) | ||
out = out.reshape((b, w, h,c)).permute((0,3,1,2)) | ||
return out | ||
|
||
return out |
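A minimal usage sketch of R3DAttention on patch-token input. The tensor sizes (2 views, 128 patches, 64-dim embeddings) are illustrative assumptions, and the import path assumes this file is saved as attention.py, which is how the encoder module below imports it.

import torch
from attention import R3DAttention

attn = R3DAttention(hidden_size=64, num_heads=4, dropout=0.1)
x = torch.randn(2, 128, 64)   # (batch*view, patch, embedding), illustrative sizes
out = attn(x)                 # no mask: plain bidirectional attention over patches
print(out.shape)              # torch.Size([2, 128, 64]), same shape as the input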
@@ -1,31 +1,58 @@
""" | ||
This module includes 3D Discriminator module | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from torch import nn | ||
|
||
conv3d = lambda channel: nn.Sequential( | ||
nn.Conv3d(channel[0], channel[1], kernel_size=4,stride=2, padding=1, bias=False), | ||
def conv3d(channel): | ||
"""_summary_ | ||
Args: | ||
channel (_type_): _description_ | ||
Returns: | ||
_type_: _description_ | ||
""" | ||
return nn.Sequential( | ||
nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False), | ||
nn.BatchNorm3d(channel[1]), | ||
nn.LeakyReLU(0.2, inplace=True) | ||
nn.LeakyReLU(0.2, inplace=True), | ||
) | ||
|
||
|
||
class R3Discriminator(nn.Module): | ||
"""_summary_ | ||
Args: | ||
nn (_type_): _description_ | ||
""" | ||
def __init__(self): | ||
super().__init__() | ||
self.out_channels = 512 | ||
self.out_dim = 4 | ||
self.conv1 = conv3d((1,64)) | ||
self.conv2 = conv3d((64,128)) | ||
self.conv3 = conv3d((128,256)) | ||
self.conv4 = conv3d((256,512)) | ||
self.conv1 = conv3d((1, 64)) | ||
self.conv2 = conv3d((64, 128)) | ||
self.conv3 = conv3d((128, 256)) | ||
self.conv4 = conv3d((256, 512)) | ||
self.out = nn.Sequential( | ||
nn.Linear(512 * self.out_dim * self.out_dim * self.out_dim, 1), | ||
nn.Sigmoid(), | ||
) | ||
|
||
def forward(self, x): | ||
def forward(self, x:torch.Tensor)->torch.Tensor: | ||
"""_summary_ | ||
Args: | ||
x (torch.Tensor): _description_ | ||
Returns: | ||
torch.Tensor: _description_ | ||
""" | ||
x = self.conv1(x) | ||
x = self.conv2(x) | ||
x = self.conv3(x) | ||
x = self.conv4(x) | ||
# Flatten and apply linear + sigmoid | ||
x = x.view(-1, self.out_channels * self.out_dim * self.out_dim * self.out_dim) | ||
x = self.out(x) | ||
return x | ||
return x |
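As a sanity check, a minimal sketch of running the discriminator on a random voxel batch, assuming R3Discriminator is in scope; the 64**3 input size follows from the four stride-2 convolutions and out_dim = 4.

import torch

disc = R3Discriminator()
voxels = torch.rand(2, 1, 64, 64, 64)   # (batch, channel, depth, height, width)
scores = disc(voxels)
print(scores.shape)                      # torch.Size([2, 1]), probabilities from the Sigmoid head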
@@ -1,52 +1,92 @@
""" | ||
Variational Autoencoder Architecture | ||
Author: mk314k | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from torch import nn | ||
from attention import R3DAttention | ||
|
||
class R3DEncoder(nn.Module): | ||
def __init__(self, inChannel =1, imSize = (192,256), latent_dim=1024): | ||
|
||
class PatchEmbed(nn.Module): | ||
""" Image to Patch Embedding | ||
""" | ||
def __init__(self, patch_size: int, img_channel: int, embedding_channel: int): | ||
"""_summary_ | ||
Args: | ||
patch_size (int): _description_ | ||
in_channel (int): _description_ | ||
out_channel (int): _description_ | ||
""" | ||
|
||
super().__init__() | ||
self.conv1_channel = 32 | ||
self.conv1 = nn.Sequential( | ||
nn.Conv2d(inChannel, self.conv1_channel, kernel_size=3, stride=1, padding=1), | ||
nn.MaxPool2d(2), | ||
nn.ReLU() | ||
) | ||
self.attention1 = R3DAttention(self.conv1_channel, 4) | ||
self.conv2 = nn.Sequential( | ||
nn.Conv2d(self.conv1_channel, 2*self.conv1_channel, kernel_size=3, stride=1, padding=1), | ||
nn.MaxPool2d(4), | ||
nn.ReLU() | ||
) | ||
self.attention2 = R3DAttention(2*self.conv1_channel, 8) | ||
self.conv3 = nn.Sequential( | ||
nn.Conv2d(2*self.conv1_channel, 4*self.conv1_channel, kernel_size=3, stride=1, padding=1), | ||
nn.MaxPool2d(4), | ||
nn.ReLU() | ||
self.patch_embedding = nn.Conv2d(img_channel, embedding_channel, kernel_size=patch_size) | ||
|
||
def forward(self, x: torch.Tensor): | ||
"""_summary_ | ||
Args: | ||
x (torch.Tensor): (batch*view, channel, width, height) | ||
Returns: | ||
torch.Tensor: (batch*view, patch, embedding) | ||
""" | ||
out = self.patch_embedding(x).view(x.size(0), -1, self.num_patches).permute(0, 2, 1) | ||
return out | ||
|
||
|
||
class R3DEncoder(nn.Module): # pylint: disable=too-many-instance-attributes | ||
"""_summary_ | ||
Args: | ||
nn (_type_): _description_ | ||
""" | ||
def __init__( # pylint: disable=too-many-arguments | ||
self, | ||
in_channel=1, | ||
num_patches=128, | ||
embedding_dim=64, | ||
embedding_kernel=3, | ||
attention_head=4, | ||
latent_dim=1024 | ||
): | ||
super().__init__() | ||
self.patch_embedding = PatchEmbed( | ||
img_channel=in_channel, | ||
embedding_channel=embedding_dim, | ||
patch_size=embedding_kernel | ||
) | ||
self.fc1 = nn.Linear(4*self.conv1_channel*6*8, 1024) | ||
self.fc2 = nn.Linear(1024, 512) | ||
self.fc3 = nn.Linear(512, 256) | ||
self.pos_embedding = nn.Embedding(num_patches, embedding_dim) | ||
self.batch_attention = R3DAttention(embedding_dim, attention_head) | ||
# self.view_attention = R3DAttention(2 * conv1_channel, 8) | ||
self.fc = nn.Linear(embedding_dim, 256) | ||
self.fc_mu = nn.Linear(256, latent_dim) | ||
self.fc_logvar = nn.Linear(256, latent_dim) | ||
self.N = torch.distributions.Normal(0, 1) | ||
self.N.loc = self.N.loc.cuda() | ||
self.N.scale = self.N.scale.cuda() | ||
self.kl = 0 | ||
self.dist = torch.distributions.Normal(0, 1) | ||
self.dist.loc = self.N.loc.cuda() | ||
self.dist.scale = self.N.scale.cuda() | ||
self.kl_val = 0 | ||
|
||
def forward(self, x: torch.Tensor): | ||
""" | ||
Args: | ||
x (torch.Tensor): input tensor | ||
Each image can have multiple view | ||
def forward(self, x): | ||
Returns: | ||
torch.Tensor: _description_ | ||
""" | ||
# apply encoder network to input image | ||
x = self.conv1(x) | ||
x = x + self.attention1(x) | ||
x = self.conv2(x) | ||
x = x + self.attention2(x) | ||
x = self.conv3(x) | ||
# x = x.reshape((4*self.conv1_channel,6, 32, 8, 32)).permute((2, 4, 0, 1, 3)).reshape((1024, -1)) | ||
x = torch.flatten(x) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
x = F.relu(self.fc3(x)) | ||
mu = self.fc_mu(x) | ||
sigma = torch.exp(self.fc_logvar(x)) | ||
z = mu + sigma*self.N.sample(mu.shape) | ||
self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum() | ||
return z | ||
# assert(len(x.shape) == 5, "input must be of shape (batch, view, channel, width, height)") | ||
b, v, c, w, h = x.shape | ||
x = x.view(b * v, c, w, h) | ||
x = self.patch_embedding(x) | ||
x = x + self.pos_embedding() #fix it | ||
x = x + self.patch_attention(x) | ||
x = self.mlp(x) | ||
mu_val = self.fc_mu(x) | ||
sigma = self.fc_logvar(x).exp() | ||
z_val = mu_val + sigma * self.dist.sample(mu_val.shape) | ||
self.kl_val = (sigma ** 2 + mu_val**2 - sigma.log() - 1 / 2).sum() | ||
return z_val |
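The last four lines of the encoder's forward pass are the standard VAE reparameterization and KL penalty. Below is a minimal standalone sketch of the same computation, treating the fc_logvar output as log sigma (as this code does); the shapes are illustrative.

import torch

mu = torch.zeros(3, 1024)                # stand-in for self.fc_mu(x)
sigma = torch.full((3, 1024), 0.5)       # stand-in for self.fc_logvar(x).exp()
dist = torch.distributions.Normal(0, 1)

z = mu + sigma * dist.sample(mu.shape)   # reparameterization: z = mu + sigma * eps, eps ~ N(0, 1)
kl = (sigma ** 2 + mu ** 2 - sigma.log() - 0.5).sum()   # KL(N(mu, sigma) || N(0, 1)), summed over dimensions
print(z.shape, kl.item())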
@@ -1,35 +1,51 @@
""" | ||
This module includes code for Generating 3D images | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from torch import nn | ||
|
||
class R3DGenerator(nn.Module): | ||
"""_summary_ | ||
Args: | ||
nn (_type_): _description_ | ||
""" | ||
def __init__(self, z_dim): | ||
super().__init__() | ||
self.z_dim = z_dim | ||
self.linear = nn.Linear(z_dim, 64 * 4 * 4 * 4) | ||
self.conv1 = nn.Sequential( | ||
nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1), | ||
nn.BatchNorm3d(32), | ||
nn.ReLU() | ||
nn.ReLU(), | ||
) | ||
self.conv2 = nn.Sequential( | ||
nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1), | ||
nn.BatchNorm3d(16), | ||
nn.ReLU() | ||
nn.ReLU(), | ||
) | ||
self.conv3 = nn.Sequential( | ||
nn.ConvTranspose3d(16, 8, kernel_size=4, stride=2, padding=1), | ||
nn.BatchNorm3d(8), | ||
nn.ReLU() | ||
nn.ReLU(), | ||
) | ||
self.conv4 = nn.Sequential( | ||
nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=1), | ||
nn.Sigmoid() | ||
nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=1), nn.Sigmoid() | ||
) | ||
def forward(self, z): | ||
|
||
def forward(self, z:torch.Tensor)->torch.Tensor: | ||
"""_summary_ | ||
Args: | ||
z (_type_): _description_ | ||
Returns: | ||
_type_: _description_ | ||
""" | ||
out = self.linear(z) | ||
out = out.view(-1, 64, 4, 4, 4) | ||
out = self.conv1(out) | ||
out = self.conv2(out) | ||
out = self.conv3(out) | ||
out = self.conv4(out) | ||
return out | ||
return out |
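A minimal sketch of sampling from the generator, assuming R3DGenerator is in scope; the batch size and z_dim = 1024 (matching the encoder's default latent_dim) are assumptions.

import torch

gen = R3DGenerator(z_dim=1024)
z = torch.randn(2, 1024)      # (batch, z_dim)
volumes = gen(z)              # spatial size grows 4 -> 8 -> 16 -> 32 -> 64 through four stride-2 transposed convs
print(volumes.shape)          # torch.Size([2, 1, 64, 64, 64]), values in (0, 1) from the Sigmoid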
Empty file.