Development #1

Merged · 6 commits · Nov 21, 2023
Changes from all commits
6 changes: 6 additions & 0 deletions .github/workflows/pylint.yml
@@ -14,9 +14,15 @@ jobs:
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Find requirements.txt
id: find-requirements
run: |
REQUIREMENTS=$(git ls-files '*.txt' | grep 'requirements.txt' || true)
echo "::set-output name=requirements_path::$REQUIREMENTS"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -n "${{ steps.find-requirements.outputs.requirements_path }}" ]; then pip install -r ${{ steps.find-requirements.outputs.requirements_path }}; fi # skip if no requirements.txt is tracked
pip install pylint
- name: Analysing the code with pylint
run: |
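For reference, a minimal local equivalent of the "Find requirements.txt" step (a sketch; it only mirrors the file lookup, not the full job):

```python
# Sketch: reproduce the "Find requirements.txt" lookup locally.
# Lists git-tracked .txt files and keeps those named requirements.txt.
import subprocess

tracked = subprocess.run(
    ["git", "ls-files", "*.txt"], capture_output=True, text=True, check=True
).stdout.splitlines()
requirements = [path for path in tracked if path.endswith("requirements.txt")]
print(requirements or "no requirements.txt tracked")
```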
Empty file added main.py
Empty file.
81 changes: 57 additions & 24 deletions models/attention.py
@@ -1,38 +1,71 @@
"""
This module includes a general implementation of multi-headed attention
"""
import torch
import torch.nn as nn
from torch import nn
from fancy_einsum import einsum
from einops import rearrange


class R3DAttention(nn.Module):
"""
R3DAttention module performs multi-head attention computation on 3D data.

Args:
hidden_size (int): The dimensionality of the input embeddings.
num_heads (int): The number of attention heads to use.
dropout (float, optional): Dropout rate to prevent overfitting (default: 0.1).
"""

def __init__(self, hidden_size: int, num_heads: int, dropout=0.1):
super(R3DAttention, self).__init__()
self.hidden_size = hidden_size
assert hidden_size % num_heads == 0
self.num_heads = num_heads
head_size = hidden_size // num_heads
super().__init__()

# Calculating the size of each attention head
head_size = hidden_size // num_heads
self.n_head = num_heads
self.head_size = head_size
self.qkv_proj = nn.Linear(hidden_size,3*hidden_size)

# Projection layers for Q, K, and V
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
self.output_proj = nn.Linear(hidden_size, hidden_size)

# Dropout layers
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)

def forward(self, x: torch.Tensor, cache = None) -> torch.Tensor:
xshap = x.shape
if len(xshap)==3:
x = x.reshape((1,*x.shape))
b, c, w, h = x.shape
x = x.permute((0,2,3,1)).reshape((b, -1, c))
qkv = self.qkv_proj(x)
n_shape = qkv.shape
q,k,v = rearrange(qkv,"b s (c n h)-> c b n s h", c=3, n=self.num_heads)
attn_score = einsum("b n sq h, b n sk h->b n sq sk",q,k)/(self.head_size**0.5)
mask = 1e4*torch.triu(torch.ones_like(attn_score[0,0]),diagonal=1)
attn_prob = torch.softmax(attn_score-mask,dim=-1,dtype=x.dtype)#type: ignore
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
"""
Perform multi-head scaled dot-product attention on the input.

Args:
x (torch.Tensor): Input tensor of shape (batch*view, patch, embedding).
mask (torch.Tensor, optional): Mask subtracted from the attention scores (default: None).

Returns:
torch.Tensor: Output tensor after attention computation (batch*view, patch, embedding).
"""
b, p, _ = x.shape

# Projecting input into query, key, and value representations
qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) # each (batch, head, patch, head_size)

# Calculating attention scores
attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (self.head_size ** 0.5)

# Applying optional mask to attention scores
if mask is not None:
attn_score -= mask

# Computing attention probabilities and applying dropout
attn_prob = attn_score.softmax(dim=-1)
attn_prob = self.attn_dropout(attn_prob)
z = einsum("b n sq sk, b n sk h ->b sq n h",attn_prob,v)
z = torch.reshape(z,(z.shape[0],z.shape[1],-1))

# Weighted sum of values using attention probabilities
z = einsum("b n pq pk, b n pk s ->b pq n s", attn_prob, v)
z = z.reshape((z.shape[0], z.shape[1], -1))

# Projecting back to the original space and applying residual dropout
out = self.output_proj(z)
out = self.resid_dropout(out)
out = out.reshape((b, w, h,c)).permute((0,3,1,2))
return out

return out
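A quick shape check for the rewritten attention block (a sketch; the sizes are illustrative, the fixes above are assumed, and the import path assumes the repo root is on sys.path):

```python
import torch
from models.attention import R3DAttention

attn = R3DAttention(hidden_size=64, num_heads=4)
x = torch.randn(2, 16, 64)  # (batch*view, patch, embedding)
out = attn(x)               # attention preserves the input shape
assert out.shape == (2, 16, 64)
```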
47 changes: 37 additions & 10 deletions models/discriminator.py
@@ -1,31 +1,58 @@
"""
This module includes the 3D discriminator
"""
import torch
import torch.nn as nn
from torch import nn

conv3d = lambda channel: nn.Sequential(
nn.Conv3d(channel[0], channel[1], kernel_size=4,stride=2, padding=1, bias=False),
def conv3d(channel):
"""Build a strided 3D convolution block.

Args:
channel (tuple): (in_channels, out_channels) for the Conv3d layer.

Returns:
nn.Sequential: Conv3d -> BatchNorm3d -> LeakyReLU block that halves each spatial dimension.
"""
return nn.Sequential(
nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm3d(channel[1]),
nn.LeakyReLU(0.2, inplace=True)
nn.LeakyReLU(0.2, inplace=True),
)


class R3Discriminator(nn.Module):
"""_summary_

Args:
nn (_type_): _description_
"""
def __init__(self):
super().__init__()
self.out_channels = 512
self.out_dim = 4
self.conv1 = conv3d((1,64))
self.conv2 = conv3d((64,128))
self.conv3 = conv3d((128,256))
self.conv4 = conv3d((256,512))
self.conv1 = conv3d((1, 64))
self.conv2 = conv3d((64, 128))
self.conv3 = conv3d((128, 256))
self.conv4 = conv3d((256, 512))
self.out = nn.Sequential(
nn.Linear(512 * self.out_dim * self.out_dim * self.out_dim, 1),
nn.Sigmoid(),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Classify voxel grids as real or fake.

Args:
x (torch.Tensor): Input of shape (batch, 1, 64, 64, 64).

Returns:
torch.Tensor: Probabilities of shape (batch, 1).
"""
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
# Flatten and apply linear + sigmoid
x = x.view(-1, self.out_channels * self.out_dim * self.out_dim * self.out_dim)
x = self.out(x)
return x
return x
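And a matching smoke test (a sketch; the 64³ input size is implied by the 512 * 4 * 4 * 4 linear head, since four stride-2 convolutions reduce 64 to 4):

```python
import torch
from models.discriminator import R3Discriminator

disc = R3Discriminator()
voxels = torch.randn(8, 1, 64, 64, 64)  # (batch, channel, depth, height, width)
scores = disc(voxels)                   # sigmoid output in (0, 1)
assert scores.shape == (8, 1)
```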
126 changes: 83 additions & 43 deletions models/encoder.py
@@ -1,52 +1,92 @@
"""
Variational Autoencoder Architecture
Author: mk314k
"""
import torch
import torch.nn as nn
from torch import nn
from attention import R3DAttention

class R3DEncoder(nn.Module):
def __init__(self, inChannel =1, imSize = (192,256), latent_dim=1024):

class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, patch_size: int, img_channel: int, embedding_channel: int):
"""
Args:
patch_size (int): Side length of each square patch (conv kernel and stride).
img_channel (int): Number of channels in the input image.
embedding_channel (int): Dimensionality of each patch embedding.
"""

super().__init__()
self.conv1_channel = 32
self.conv1 = nn.Sequential(
nn.Conv2d(inChannel, self.conv1_channel, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(2),
nn.ReLU()
)
self.attention1 = R3DAttention(self.conv1_channel, 4)
self.conv2 = nn.Sequential(
nn.Conv2d(self.conv1_channel, 2*self.conv1_channel, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(4),
nn.ReLU()
)
self.attention2 = R3DAttention(2*self.conv1_channel, 8)
self.conv3 = nn.Sequential(
nn.Conv2d(2*self.conv1_channel, 4*self.conv1_channel, kernel_size=3, stride=1, padding=1),
nn.MaxPool2d(4),
nn.ReLU()
self.patch_embedding = nn.Conv2d(img_channel, embedding_channel, kernel_size=patch_size, stride=patch_size) # stride = patch_size for non-overlapping patches

def forward(self, x: torch.Tensor):
"""Split a batch of images into per-patch embeddings.

Args:
x (torch.Tensor): (batch*view, channel, width, height)

Returns:
torch.Tensor: (batch*view, patch, embedding)
"""
# (b, embed, W', H') -> (b, embed, patch) -> (b, patch, embed)
out = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
return out
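A shape sketch for the patch embedding (illustrative numbers, assuming the stride fix above): a 24×48 single-channel image with patch_size=3 yields (24/3)·(48/3) = 128 patches.

```python
import torch

pe = PatchEmbed(patch_size=3, img_channel=1, embedding_channel=64)
img = torch.randn(2, 1, 24, 48)   # (batch*view, channel, width, height)
patches = pe(img)                 # (batch*view, 128, 64)
assert patches.shape == (2, 128, 64)
```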


class R3DEncoder(nn.Module): # pylint: disable=too-many-instance-attributes
"""Variational encoder for multi-view images.

Embeds image patches, adds positional embeddings, applies self-attention,
and maps the result to the mean and log-variance of a latent Gaussian.
"""
def __init__( # pylint: disable=too-many-arguments
self,
in_channel=1,
num_patches=128,
embedding_dim=64,
embedding_kernel=3,
attention_head=4,
latent_dim=1024
):
super().__init__()
self.patch_embedding = PatchEmbed(
img_channel=in_channel,
embedding_channel=embedding_dim,
patch_size=embedding_kernel
)
self.fc1 = nn.Linear(4*self.conv1_channel*6*8, 1024)
self.fc2 = nn.Linear(1024, 512)
self.fc3 = nn.Linear(512, 256)
self.pos_embedding = nn.Embedding(num_patches, embedding_dim)
self.batch_attention = R3DAttention(embedding_dim, attention_head)
# self.view_attention = R3DAttention(2 * conv1_channel, 8)
self.fc = nn.Linear(embedding_dim, 256)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)
self.N = torch.distributions.Normal(0, 1)
self.N.loc = self.N.loc.cuda()
self.N.scale = self.N.scale.cuda()
self.kl = 0
self.dist = torch.distributions.Normal(0, 1)
self.dist.loc = self.dist.loc.cuda()
self.dist.scale = self.dist.scale.cuda()
self.kl_val = 0

def forward(self, x):
# apply encoder network to input image
x = self.conv1(x)
x = x + self.attention1(x)
x = self.conv2(x)
x = x + self.attention2(x)
x = self.conv3(x)
# x = x.reshape((4*self.conv1_channel,6, 32, 8, 32)).permute((2, 4, 0, 1, 3)).reshape((1024, -1))
x = torch.flatten(x)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
mu = self.fc_mu(x)
sigma = torch.exp(self.fc_logvar(x))
z = mu + sigma*self.N.sample(mu.shape)
self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
return z
def forward(self, x: torch.Tensor):
"""
Encode a batch of multi-view images into latent samples.

Args:
x (torch.Tensor): input tensor of shape (batch, view, channel, width, height).
Each image can have multiple views.

Returns:
torch.Tensor: latent samples of shape (batch*view, patch, latent_dim).
"""
# assert(len(x.shape) == 5, "input must be of shape (batch, view, channel, width, height)")
b, v, c, w, h = x.shape
x = x.view(b * v, c, w, h)
x = self.patch_embedding(x)
# Add a learned positional embedding per patch index
positions = torch.arange(x.size(1), device=x.device)
x = x + self.pos_embedding(positions)
x = x + self.batch_attention(x)
x = self.fc(x)
mu_val = self.fc_mu(x)
sigma = self.fc_logvar(x).exp()
z_val = mu_val + sigma * self.dist.sample(mu_val.shape)
self.kl_val = (sigma ** 2 + mu_val**2 - sigma.log() - 1 / 2).sum()
return z_val
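A smoke test for the encoder (a sketch assuming the fixes above and that the module imports resolve; it needs a CUDA device because the prior's parameters are moved with .cuda(), and the 24×48 input is chosen so the patch count matches num_patches=128):

```python
import torch
from models.encoder import R3DEncoder

enc = R3DEncoder(in_channel=1, num_patches=128, embedding_dim=64, latent_dim=1024).cuda()
views = torch.randn(2, 4, 1, 24, 48, device="cuda")  # (batch, view, channel, width, height)
z = enc(views)    # (batch*view, patch, latent_dim) = (8, 128, 1024)
kl = enc.kl_val   # KL term accumulated during the forward pass
```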
32 changes: 24 additions & 8 deletions models/generator.py
@@ -1,35 +1,51 @@
"""
This module includes code for generating 3D images
"""
import torch
import torch.nn as nn
from torch import nn

class R3DGenerator(nn.Module):
"""_summary_

Args:
nn (_type_): _description_
"""
def __init__(self, z_dim):
super().__init__()
self.z_dim = z_dim
self.linear = nn.Linear(z_dim, 64 * 4 * 4 * 4)
self.conv1 = nn.Sequential(
nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1),
nn.BatchNorm3d(32),
nn.ReLU()
nn.ReLU(),
)
self.conv2 = nn.Sequential(
nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1),
nn.BatchNorm3d(16),
nn.ReLU()
nn.ReLU(),
)
self.conv3 = nn.Sequential(
nn.ConvTranspose3d(16, 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm3d(8),
nn.ReLU()
nn.ReLU(),
)
self.conv4 = nn.Sequential(
nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=1),
nn.Sigmoid()
nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=1), nn.Sigmoid()
)
def forward(self, z):

def forward(self, z: torch.Tensor) -> torch.Tensor:
"""Generate voxel grids from latent vectors.

Args:
z (torch.Tensor): Latent batch of shape (batch, z_dim).

Returns:
torch.Tensor: Generated volumes of shape (batch, 1, 64, 64, 64), in (0, 1).
"""
out = self.linear(z)
out = out.view(-1, 64, 4, 4, 4)
out = self.conv1(out)
out = self.conv2(out)
out = self.conv3(out)
out = self.conv4(out)
return out
return out
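A final shape check for the generator (a sketch; z_dim=128 is an arbitrary choice):

```python
import torch
from models.generator import R3DGenerator

gen = R3DGenerator(z_dim=128)
z = torch.randn(8, 128)
voxels = gen(z)  # four stride-2 upsamplings: 4^3 -> 64^3
assert voxels.shape == (8, 1, 64, 64, 64)
```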
Empty file added models/model.py
Empty file.