Merge pull request #1 from mk314k/development

Development

mk314k authored Nov 21, 2023
2 parents 5e80682 + bf26541 commit 34b7192
Showing 28 changed files with 1,035 additions and 151 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/pylint.yml
@@ -14,9 +14,15 @@ jobs:
       uses: actions/setup-python@v3
       with:
         python-version: ${{ matrix.python-version }}
+    - name: Find requirements.txt
+      id: find-requirements
+      run: |
+        REQUIREMENTS=$(git ls-files '*.txt' | grep 'requirements.txt' || true)
+        echo "::set-output name=requirements_path::$REQUIREMENTS"
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
+        pip install -r ${{ steps.find-requirements.outputs.requirements_path }} # Install requirements.txt
         pip install pylint
     - name: Analysing the code with pylint
       run: |
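For reference, the two new steps can be approximated locally with a short Python sketch. This is not part of the PR; it assumes it is run from the repository root of a git checkout:

```python
# Rough local equivalent of the "Find requirements.txt" and install steps
# above: list tracked *.txt files, keep any requirements.txt, and install it.
import subprocess

tracked = subprocess.run(
    ["git", "ls-files", "*.txt"],
    capture_output=True, text=True, check=True,
).stdout.splitlines()

for path in (p for p in tracked if p.endswith("requirements.txt")):
    subprocess.run(["python", "-m", "pip", "install", "-r", path], check=True)
```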
Empty file added main.py
Empty file.
81 changes: 57 additions & 24 deletions models/attention.py
@@ -1,38 +1,71 @@
+"""
+This Module includes general implementation of Multiheaded Attention
+"""
 import torch
-import torch.nn as nn
+from torch import nn
 from fancy_einsum import einsum
 from einops import rearrange


 class R3DAttention(nn.Module):
+    """
+    R3DAttention module performs multi-head attention computation on 3D data.
+
+    Args:
+        hidden_size (int): The dimensionality of the input embeddings.
+        num_heads (int): The number of attention heads to use.
+        dropout (float, optional): Dropout rate to prevent overfitting (default: 0.1).
+    """
+
     def __init__(self, hidden_size: int, num_heads: int, dropout=0.1):
-        super(R3DAttention, self).__init__()
-        self.hidden_size = hidden_size
-        assert hidden_size % num_heads == 0
-        self.num_heads = num_heads
-        head_size = hidden_size // num_heads
+        super().__init__()
+
+        # Calculating the size of each attention head
+        head_size = int(hidden_size / num_heads)
+        self.n_head = num_heads
         self.head_size = head_size
-        self.qkv_proj = nn.Linear(hidden_size,3*hidden_size)
+
+        # Projection layers for Q, K, and V
+        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
         self.output_proj = nn.Linear(hidden_size, hidden_size)
+
+        # Dropout layers
         self.attn_dropout = nn.Dropout(dropout)
         self.resid_dropout = nn.Dropout(dropout)

-    def forward(self, x: torch.Tensor, cache = None) -> torch.Tensor:
-        xshap = x.shape
-        if len(xshap)==3:
-            x = x.reshape((1,*x.shape))
-        b, c, w, h = x.shape
-        x = x.permute((0,2,3,1)).reshape((b, -1, c))
-        qkv = self.qkv_proj(x)
-        n_shape = qkv.shape
-        q,k,v = rearrange(qkv,"b s (c n h)-> c b n s h", c=3, n=self.num_heads)
-        attn_score = einsum("b n sq h, b n sk h->b n sq sk",q,k)/(self.head_size**0.5)
-        mask = 1e4*torch.triu(torch.ones_like(attn_score[0,0]),diagonal=1)
-        attn_prob = torch.softmax(attn_score-mask,dim=-1,dtype=x.dtype)#type: ignore
+    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
+        """
+        Perform multi-head scaled dot-product attention on the input.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (batch*view, patch, embedding).
+            mask (torch.Tensor, optional): Mask tensor for masking attention scores (default: None).
+
+        Returns:
+            torch.Tensor: Output tensor after attention computation (batch*view, patch, embedding).
+        """
+        b, p, _ = x.shape
+
+        # Projecting input into query, key, and value representations
+        qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size)
+        # Split into per-head Q, K, V of shape (batch, head, patch, head_size)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
+
+        # Calculating attention scores
+        attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (self.head_size ** 0.5)
+
+        # Applying optional mask to attention scores
+        if mask is not None:
+            attn_score -= mask
+
+        # Computing attention probabilities and applying dropout
+        attn_prob = attn_score.softmax(dim=-1)
         attn_prob = self.attn_dropout(attn_prob)
-        z = einsum("b n sq sk, b n sk h ->b sq n h",attn_prob,v)
-        z = torch.reshape(z,(z.shape[0],z.shape[1],-1))
+
+        # Weighted sum of values using attention probabilities
+        z = einsum("b n pq pk, b n pk s ->b pq n s", attn_prob, v)
+        z = z.reshape((z.shape[0], z.shape[1], -1))
+
+        # Projecting back to the original space and applying residual dropout
         out = self.output_proj(z)
         out = self.resid_dropout(out)
-        out = out.reshape((b, w, h,c)).permute((0,3,1,2))
-        return out
+
+        return out
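A quick, hypothetical smoke test for the revised module (not part of this PR); it assumes the repository root is on PYTHONPATH so the file above imports as models.attention:

```python
import torch
from models.attention import R3DAttention

attn = R3DAttention(hidden_size=64, num_heads=4)
x = torch.randn(2, 16, 64)       # (batch*view, patch, embedding)
out = attn(x)                    # no mask: full bidirectional attention
assert out.shape == (2, 16, 64)  # output keeps the input shape
```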
47 changes: 37 additions & 10 deletions models/discriminator.py
@@ -1,31 +1,58 @@
+"""
+This module includes 3D Discriminator module
+"""
 import torch
-import torch.nn as nn
+from torch import nn

-conv3d = lambda channel: nn.Sequential(
-    nn.Conv3d(channel[0], channel[1], kernel_size=4,stride=2, padding=1, bias=False),
+def conv3d(channel):
+    """Builds a 3D convolutional downsampling block.
+
+    Args:
+        channel (tuple): (in_channels, out_channels) for the Conv3d layer.
+
+    Returns:
+        nn.Sequential: Conv3d -> BatchNorm3d -> LeakyReLU block.
+    """
+    return nn.Sequential(
+        nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False),
         nn.BatchNorm3d(channel[1]),
-        nn.LeakyReLU(0.2, inplace=True)
+        nn.LeakyReLU(0.2, inplace=True),
     )


 class R3Discriminator(nn.Module):
+    """3D convolutional discriminator producing a real/fake probability."""
     def __init__(self):
         super().__init__()
         self.out_channels = 512
         self.out_dim = 4
-        self.conv1 = conv3d((1,64))
-        self.conv2 = conv3d((64,128))
-        self.conv3 = conv3d((128,256))
-        self.conv4 = conv3d((256,512))
+        self.conv1 = conv3d((1, 64))
+        self.conv2 = conv3d((64, 128))
+        self.conv3 = conv3d((128, 256))
+        self.conv4 = conv3d((256, 512))
         self.out = nn.Sequential(
             nn.Linear(512 * self.out_dim * self.out_dim * self.out_dim, 1),
             nn.Sigmoid(),
         )

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Classifies a batch of 3D volumes as real or fake.
+
+        Args:
+            x (torch.Tensor): volumes of shape (batch, 1, 64, 64, 64), given out_dim = 4.
+
+        Returns:
+            torch.Tensor: probability of being real, shape (batch, 1).
+        """
         x = self.conv1(x)
         x = self.conv2(x)
         x = self.conv3(x)
         x = self.conv4(x)
         # Flatten and apply linear + sigmoid
         x = x.view(-1, self.out_channels * self.out_dim * self.out_dim * self.out_dim)
         x = self.out(x)
-        return x
\ No newline at end of file
+        return x
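A hypothetical usage sketch (not part of the PR); with out_dim = 4 and four stride-2 convolutions, the expected input is a (batch, 1, 64, 64, 64) volume, and the same PYTHONPATH assumption as above applies:

```python
import torch
from models.discriminator import R3Discriminator

disc = R3Discriminator()
volumes = torch.rand(2, 1, 64, 64, 64)  # stand-in batch of 3D occupancy grids
scores = disc(volumes)                  # sigmoid output in (0, 1)
assert scores.shape == (2, 1)
```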
126 changes: 83 additions & 43 deletions models/encoder.py
@@ -1,52 +1,92 @@
+"""
+Variational Autoencoder Architecture
+Author: mk314k
+"""
 import torch
-import torch.nn as nn
+from torch import nn
+from attention import R3DAttention


+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding."""
+    def __init__(self, patch_size: int, img_channel: int, embedding_channel: int):
+        """
+        Args:
+            patch_size (int): kernel size of the embedding convolution.
+            img_channel (int): number of channels in the input image.
+            embedding_channel (int): dimensionality of each patch embedding.
+        """
+        super().__init__()
+        self.patch_embedding = nn.Conv2d(img_channel, embedding_channel, kernel_size=patch_size)
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): (batch*view, channel, width, height)
+        Returns:
+            torch.Tensor: (batch*view, patch, embedding)
+        """
+        # Flatten the spatial grid into a patch axis: (b, emb, p) -> (b, p, emb)
+        out = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
+        return out
+
+
-class R3DEncoder(nn.Module):
-    def __init__(self, inChannel =1, imSize = (192,256), latent_dim=1024):
+class R3DEncoder(nn.Module):  # pylint: disable=too-many-instance-attributes
+    """Variational encoder: patch embedding, attention, and a reparameterized latent."""
+    def __init__(  # pylint: disable=too-many-arguments
+        self,
+        in_channel=1,
+        num_patches=128,
+        embedding_dim=64,
+        embedding_kernel=3,
+        attention_head=4,
+        latent_dim=1024
+    ):
         super().__init__()
-        self.conv1_channel = 32
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(inChannel, self.conv1_channel, kernel_size=3, stride=1, padding=1),
-            nn.MaxPool2d(2),
-            nn.ReLU()
-        )
-        self.attention1 = R3DAttention(self.conv1_channel, 4)
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(self.conv1_channel, 2*self.conv1_channel, kernel_size=3, stride=1, padding=1),
-            nn.MaxPool2d(4),
-            nn.ReLU()
-        )
-        self.attention2 = R3DAttention(2*self.conv1_channel, 8)
-        self.conv3 = nn.Sequential(
-            nn.Conv2d(2*self.conv1_channel, 4*self.conv1_channel, kernel_size=3, stride=1, padding=1),
-            nn.MaxPool2d(4),
-            nn.ReLU()
-        )
-        self.fc1 = nn.Linear(4*self.conv1_channel*6*8, 1024)
-        self.fc2 = nn.Linear(1024, 512)
-        self.fc3 = nn.Linear(512, 256)
+        self.patch_embedding = PatchEmbed(
+            img_channel=in_channel,
+            embedding_channel=embedding_dim,
+            patch_size=embedding_kernel
+        )
+        self.pos_embedding = nn.Embedding(num_patches, embedding_dim)
+        self.batch_attention = R3DAttention(embedding_dim, attention_head)
+        # self.view_attention = R3DAttention(2 * conv1_channel, 8)
+        self.fc = nn.Linear(embedding_dim, 256)
         self.fc_mu = nn.Linear(256, latent_dim)
         self.fc_logvar = nn.Linear(256, latent_dim)
-        self.N = torch.distributions.Normal(0, 1)
-        self.N.loc = self.N.loc.cuda()
-        self.N.scale = self.N.scale.cuda()
-        self.kl = 0
+        self.dist = torch.distributions.Normal(0, 1)
+        self.dist.loc = self.dist.loc.cuda()
+        self.dist.scale = self.dist.scale.cuda()
+        self.kl_val = 0

-    def forward(self, x):
-        # apply encoder network to input image
-        x = self.conv1(x)
-        x = x + self.attention1(x)
-        x = self.conv2(x)
-        x = x + self.attention2(x)
-        x = self.conv3(x)
-        # x = x.reshape((4*self.conv1_channel,6, 32, 8, 32)).permute((2, 4, 0, 1, 3)).reshape((1024, -1))
-        x = torch.flatten(x)
-        x = F.relu(self.fc1(x))
-        x = F.relu(self.fc2(x))
-        x = F.relu(self.fc3(x))
-        mu = self.fc_mu(x)
-        sigma = torch.exp(self.fc_logvar(x))
-        z = mu + sigma*self.N.sample(mu.shape)
-        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
-        return z
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): input tensor of shape (batch, view, channel, width, height);
+                each image can have multiple views.
+        Returns:
+            torch.Tensor: sampled latent of shape (batch*view, patch, latent_dim).
+        """
+        b, v, c, w, h = x.shape
+        x = x.view(b * v, c, w, h)
+        x = self.patch_embedding(x)
+        # Add a learned positional embedding per patch index
+        # (requires x.size(1) <= num_patches)
+        positions = torch.arange(x.size(1), device=x.device)
+        x = x + self.pos_embedding(positions)
+        x = x + self.batch_attention(x)
+        x = self.fc(x)
+        mu_val = self.fc_mu(x)
+        sigma = self.fc_logvar(x).exp()
+        z_val = mu_val + sigma * self.dist.sample(mu_val.shape)
+        self.kl_val = (sigma ** 2 + mu_val**2 - sigma.log() - 1 / 2).sum()
+        return z_val
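The last few lines of forward() implement the standard VAE reparameterization: z = mu + sigma * eps with eps ~ N(0, 1), plus a closed-form KL term against a unit Gaussian. A standalone sketch of just that step (illustrative shapes, CPU-only, independent of the .cuda() pinning above):

```python
# Mirrors the merged code, where fc_logvar's output is exponentiated
# directly, i.e. interpreted as log(sigma).
import torch

mu = torch.randn(8, 1024)            # stand-in for fc_mu(x)
sigma = torch.randn(8, 1024).exp()   # stand-in for fc_logvar(x).exp()

dist = torch.distributions.Normal(0, 1)
z = mu + sigma * dist.sample(mu.shape)                   # reparameterization trick
kl = (sigma ** 2 + mu ** 2 - sigma.log() - 1 / 2).sum()  # KL vs. N(0, 1)
```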
32 changes: 24 additions & 8 deletions models/generator.py
@@ -1,35 +1,51 @@
+"""
+This module includes code for Generating 3D images
+"""
 import torch
-import torch.nn as nn
+from torch import nn

 class R3DGenerator(nn.Module):
+    """Transposed-convolutional generator mapping a latent vector to a 3D volume."""
     def __init__(self, z_dim):
         super().__init__()
         self.z_dim = z_dim
         self.linear = nn.Linear(z_dim, 64 * 4 * 4 * 4)
         self.conv1 = nn.Sequential(
             nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1),
             nn.BatchNorm3d(32),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.conv2 = nn.Sequential(
             nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1),
             nn.BatchNorm3d(16),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.conv3 = nn.Sequential(
             nn.ConvTranspose3d(16, 8, kernel_size=4, stride=2, padding=1),
             nn.BatchNorm3d(8),
-            nn.ReLU()
+            nn.ReLU(),
         )
         self.conv4 = nn.Sequential(
-            nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=1),
-            nn.Sigmoid()
+            nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=1), nn.Sigmoid()
         )
-    def forward(self, z):
+
+    def forward(self, z: torch.Tensor) -> torch.Tensor:
+        """Maps latent vectors to generated volumes.
+
+        Args:
+            z (torch.Tensor): latent vectors of shape (batch, z_dim).
+
+        Returns:
+            torch.Tensor: generated volumes of shape (batch, 1, 64, 64, 64).
+        """
         out = self.linear(z)
         out = out.view(-1, 64, 4, 4, 4)
         out = self.conv1(out)
         out = self.conv2(out)
         out = self.conv3(out)
         out = self.conv4(out)
-        return out
\ No newline at end of file
+        return out
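A hypothetical end-to-end shape check wiring the generator to the discriminator (not part of the PR; the generator's four stride-2 upsamples of a 4^3 grid match the discriminator's expected 64^3 input, and module paths again assume the repo layout shown here):

```python
import torch
from models.generator import R3DGenerator
from models.discriminator import R3Discriminator

gen = R3DGenerator(z_dim=128)
disc = R3Discriminator()

z = torch.randn(4, 128)  # latent batch
fake = gen(z)            # (4, 1, 64, 64, 64)
score = disc(fake)       # (4, 1) real/fake probability
assert fake.shape == (4, 1, 64, 64, 64) and score.shape == (4, 1)
```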
Empty file added models/model.py
Empty file.
