Merge pull request #2 from mk314k/development
complete attention
mk314k committed Nov 27, 2023
2 parents 34b7192 + 4ef745a commit 0d1bfa5
Showing 33 changed files with 360 additions and 861 deletions.
Binary file added .DS_Store
4 changes: 3 additions & 1 deletion .github/workflows/pylint.yml
@@ -1,6 +1,8 @@
name: Pylint

-on: [push]
+on:
+  push:
+    branches: [ "main" ]

jobs:
  build:
6 changes: 2 additions & 4 deletions README.md
@@ -8,13 +8,11 @@ This project focuses on the problem of reconstructing 3D shapes from a single 2D
## Dependencies
- Python
- Pytorch
-- einops
-- fancy_einsum
- sklearn

## Results
<img src="test/3dcar.png" width="300" height="300"/> | <img src="test/r3dcar.png" width="300" height="300"/> | <img src="test/car.png" width="300" height="300"/>
<img src="test/3dphone.png" width="300" height="300"/> | <img src="test/r3dphone.png" width="300" height="300"/> | <img src="test/phone.png" width="300" height="300"/>
<img src="test/3dcar.png" width="200" height="200"/> | <img src="test/r3dcar.png" width="200" height="200"/> | <img src="test/car.png" width="200" height="200"/>
<img src="test/3dphone.png" width="200" height="200"/> | <img src="test/r3dphone.png" width="200" height="200"/> | <img src="test/phone.png" width="200" height="200"/>

fig. 3D image (given), 3D image (generated), 2D image (given), from left to right

Binary file added data/.DS_Store
Binary file added data/shapenetcore/.DS_Store
70 changes: 70 additions & 0 deletions main.py
@@ -0,0 +1,70 @@
"""
_summary_
"""
import torch
from matplotlib import pyplot as plt
import tqdm.auto as tqdm
from utils import load_data
from train import train_epoch
from models.r3d import R3D


if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


TRAIN_DATA_PATH = "data/shapenetcore/train_imgs"
TEST_DATA_PATH = "data/shapenetcore/test_imgs"
VOXEL_SIZE = 64
pixel_shape = (192,256)

train_2d, train_3d = load_data(
TRAIN_DATA_PATH,
voxel_size = VOXEL_SIZE,
pixel_shape = pixel_shape,
device = device
)
test_2d, test_3d = load_data(
TEST_DATA_PATH,
voxel_size = VOXEL_SIZE,
pixel_shape = pixel_shape,
device = device
)


model = R3D(

).to(device)

# Setting hyperparameters for optimizers
LR = 1e-3
WD = 0.2
betas=(0.9, 0.98)
# Initializing optimizers
vae_optim = torch.optim.AdamW(
model.vae_parameters(),
lr=LR,
weight_decay=WD,
betas=betas
)
gan_optim = torch.optim.AdamW(
model.gan_parameters(),
lr=LR,
weight_decay=WD,
betas=betas
)
NUM_EPOCHS = 25
train_loss = []
for _ in tqdm.tqdm(NUM_EPOCHS):
train_loss.append(train_epoch(
train_2d, train_3d, model, vae_optim, gan_optim
))
# Plotting training losses
plt.figure(figsize=(16, 6))
plt.plot(range(len(train_loss)), [tloss[0] for tloss in train_loss], label='Training VAE loss')
plt.plot(range(len(train_loss)), [tloss[1] for tloss in train_loss], label='Training GAN loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Model Performance')
plt.legend()
plt.show()
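
train.py is not part of this diff, but the call site above pins down its contract: train_epoch must return a (vae_loss, gan_loss) pair, since the plotting code indexes tloss[0] and tloss[1]. A minimal sketch of a compatible implementation, assuming MSE reconstruction plus binary cross-entropy adversarial terms (the actual losses live in train.py and may differ):

```python
# Hypothetical train_epoch matching the call site in main.py.
# The loss terms are assumptions; only the signature and the
# (vae_loss, gan_loss) return pair are implied by this commit.
import torch
import torch.nn.functional as F

def train_epoch(imgs_2d, imgs_3d, model, vae_optim, gan_optim):
    # VAE/generator step: reconstruct the 3D target and fool the discriminator.
    recon = model(imgs_2d, imgs_3d)  # also caches model.gan, model.disc_false
    vae_loss = F.mse_loss(recon, imgs_3d) + F.binary_cross_entropy(
        model.disc_false, torch.ones_like(model.disc_false)
    )
    vae_optim.zero_grad()
    vae_loss.backward()
    vae_optim.step()

    # Discriminator step: real volumes -> 1, generated volumes -> 0.
    d_real = model.discriminator(imgs_3d.reshape(1, *imgs_3d.shape))
    d_fake = model.discriminator(model.gan.detach())
    gan_loss = F.binary_cross_entropy(d_real, torch.ones_like(d_real)) \
        + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake))
    gan_optim.zero_grad()
    gan_loss.backward()
    gan_optim.step()

    return vae_loss.item(), gan_loss.item()
```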
24 changes: 13 additions & 11 deletions models/attention.py
@@ -1,14 +1,14 @@
"""
This module includes a general implementation of multi-headed attention
Author: mk314k
"""
import torch
from torch import nn
from fancy_einsum import einsum


class R3DAttention(nn.Module):
    """
-    R3DAttention module performs multi-head attention computation on 3D data.
+    R3DAttention module performs multi-head attention computation on Image Patches.
    Args:
        hidden_size (int): The dimensionality of the input embeddings.
@@ -43,29 +43,31 @@ def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        Returns:
            torch.Tensor: Output tensor after attention computation (batch*view, patch, embedding).
        """
-        b, p, _ = x.shape
+        bv, p, _ = x.shape

        # Projecting input into query, key, and value representations
-        qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size)
-        q, k, v = qkv.chunk(dim=2)
+        qkv = self.qkv_proj(x).reshape(bv, p, 3, self.n_head, self.head_size)
+        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, bv, n, p, h)
+        q, k, v = qkv.chunk(3)

-        # Calculating attention scores
-        attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (self.head_size ** 0.5)
+        # Calculating attention scores (bv, n, p, p)
+        attn_score = q[0].matmul(k[0].transpose(-1, -2)) / (self.head_size ** 0.5)

        # Applying optional mask to attention scores
        if mask is not None:
            attn_score -= mask

        # Computing attention probabilities and apply dropout
-        attn_prob = attn_score.softmax(dim=-1)
+        attn_prob = attn_score.softmax(dim=-1)  # (bv, n, p, p)
        attn_prob = self.attn_dropout(attn_prob)

        # Weighted sum of values using attention probabilities
-        z = einsum("b n pq pk, b n pk s ->b pq n s", attn_prob, v)
-        z = z.reshape((z.shape[0], z.shape[1], -1))
+        z = attn_prob.matmul(v[0])  # (bv, n, p, h)
+        z = z.permute(0, 2, 1, 3)  # (bv, p, n, h)
+        z = z.reshape((bv, p, -1))  # (bv, p, e)

        # Projecting back to the original space and applying residual dropout
-        out = self.output_proj(z)
+        out = self.output_proj(z)  # (bv, p, e)
        out = self.resid_dropout(out)

        return out
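
The shape comments above imply that the module is shape-preserving over (batch*view, patch, embedding). A quick smoke test; the constructor is not shown in this hunk, so R3DAttention(hidden_size, num_heads) is an assumed signature:

```python
# Smoke test for the shape contract documented in forward().
# R3DAttention(hidden_size, num_heads) is an assumed constructor
# signature; only forward() appears in this diff.
import torch
from models.attention import R3DAttention

bv, p, e = 8, 128, 64           # batch*view, patches, embedding dim
attn = R3DAttention(e, 4)       # hypothetical: 64-dim embeddings, 4 heads
x = torch.randn(bv, p, e)
out = attn(x)
assert out.shape == (bv, p, e)  # attention preserves the input shape
```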
36 changes: 17 additions & 19 deletions models/discriminator.py
@@ -4,49 +4,47 @@
import torch
from torch import nn

-def conv3d(channel):
-    """_summary_
+def conv3d(in_channel: int, out_channel: int):
+    """
+    Conv3d structure with default hyperparameters
    Args:
-        channel (_type_): _description_
+        in_channel (int): channel of image before this layer
+        out_channel (int): channel of image after this layer
    Returns:
-        _type_: _description_
+        nn.Module: Conv3d Module
    """
    return nn.Sequential(
-        nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False),
-        nn.BatchNorm3d(channel[1]),
+        nn.Conv3d(in_channel, out_channel, kernel_size=4, stride=2, padding=1, bias=False),
+        nn.BatchNorm3d(out_channel),
        nn.LeakyReLU(0.2, inplace=True),
    )


class R3Discriminator(nn.Module):
-    """_summary_
-    Args:
-        nn (_type_): _description_
+    """
+    Discriminator Module
    """
    def __init__(self):
        super().__init__()
        self.out_channels = 512
        self.out_dim = 4
-        self.conv1 = conv3d((1, 64))
-        self.conv2 = conv3d((64, 128))
-        self.conv3 = conv3d((128, 256))
-        self.conv4 = conv3d((256, 512))
+        self.conv1 = conv3d(1, 64)
+        self.conv2 = conv3d(64, 128)
+        self.conv3 = conv3d(128, 256)
+        self.conv4 = conv3d(256, 512)
        self.out = nn.Sequential(
            nn.Linear(512 * self.out_dim * self.out_dim * self.out_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """_summary_
+        """
        Args:
-            x (torch.Tensor): _description_
+            x (torch.Tensor): 3d image
        Returns:
-            torch.Tensor: _description_
+            torch.Tensor: real/fake value of image
        """
        x = self.conv1(x)
        x = self.conv2(x)
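
Each conv3d block halves the spatial resolution (kernel 4, stride 2, padding 1), so a 64³ voxel grid (VOXEL_SIZE = 64 in main.py) reaches 4³ after the four blocks, which is where self.out_dim = 4 and the 512 * 4 * 4 * 4 input of the final linear layer come from. A quick shape trace using just the blocks shown above:

```python
# Shape trace: four stride-2 conv blocks take a 64^3 grid down to 4^3.
import torch
from models.discriminator import R3Discriminator

d = R3Discriminator()
x = torch.randn(1, 1, 64, 64, 64)  # (batch, channel, D, H, W)
x = d.conv1(x)  # -> (1,  64, 32, 32, 32)
x = d.conv2(x)  # -> (1, 128, 16, 16, 16)
x = d.conv3(x)  # -> (1, 256,  8,  8,  8)
x = d.conv4(x)  # -> (1, 512,  4,  4,  4)
print(x.shape)
```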
27 changes: 19 additions & 8 deletions models/encoder.py
@@ -1,26 +1,36 @@
"""
Variational Autoencoder Architecture
Variational Autoencoder Architecture using Multihead Attentions
Author: mk314k
"""
import math
import torch
from torch import nn
from attention import R3DAttention


class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, patch_size: int, img_channel: int, embedding_channel: int):
"""_summary_
def __init__(self, patch_size = 3, img_channel = 1, out_channel = 256):
"""
Args:
patch_size (int): _description_
in_channel (int): _description_
out_channel (int): _description_
"""

super().__init__()
self.patch_embedding = nn.Conv2d(img_channel, embedding_channel, kernel_size=patch_size)
self.embedding_channel = out_channel
pad = math.floor(patch_size / 2)
mid_channel = int(out_channel/8)
mid_kernel = 2*patch_size - 1
self.patch_embedding = nn.Sequential(
nn.Conv2d(img_channel, mid_channel, kernel_size=patch_size, padding=pad),
nn.MaxPool2d(kernel_size=pad),
nn.Conv2d(mid_channel, 4*mid_channel, kernel_size=mid_kernel, stride=mid_kernel),
nn.MaxPool2d(kernel_size=pad),
nn.Conv2d(4*mid_channel, out_channel, kernel_size=patch_size, stride=patch_size),
nn.MaxPool2d(kernel_size=pad)
)

def forward(self, x: torch.Tensor):
"""_summary_
@@ -31,7 +41,8 @@ def forward(self, x: torch.Tensor):
        Returns:
            torch.Tensor: (batch*view, patch, embedding)
        """
-        out = self.patch_embedding(x).view(x.size(0), -1, self.num_patches).permute(0, 2, 1)
+        out = self.patch_embedding(x)
+        out = out.view(*out.shape[:2], -1).permute(0, 2, 1)
        return out


@@ -53,7 +64,7 @@ def __init__( # pylint: disable=too-many-arguments
        super().__init__()
        self.patch_embedding = PatchEmbed(
            img_channel=in_channel,
-            embedding_channel=embedding_dim,
+            out_channel=embedding_dim,
            patch_size=embedding_kernel
        )
        self.pos_embedding = nn.Embedding(num_patches, embedding_dim)
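
With the defaults above (patch_size=3, so pad = 1), every MaxPool2d(kernel_size=1) is a no-op and the downsampling comes entirely from the two strided convolutions (stride 5, then stride 3). For the (192, 256) inputs configured in main.py this yields a 12 x 17 grid, i.e. 204 patches. A quick trace under those assumptions:

```python
# Shape trace of PatchEmbed with its default arguments on the
# (192, 256) inputs set up in main.py.
import torch
from models.encoder import PatchEmbed

embed = PatchEmbed()             # patch_size=3, img_channel=1, out_channel=256
x = torch.randn(2, 1, 192, 256)  # (batch*view, channel, H, W)
out = embed(x)
# Conv stride 5: 192x256 -> 38x51; Conv stride 3: 38x51 -> 12x17, i.e. 204 patches
print(out.shape)                 # torch.Size([2, 204, 256])
```

Note that 204 patches exceeds the num_patches=128 default that R3D passes to nn.Embedding for positional embeddings, so the positional-embedding sizing is worth double-checking against the input resolution.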
Empty file removed models/model.py
51 changes: 51 additions & 0 deletions models/r3d.py
@@ -0,0 +1,51 @@
"""
This is an optional wrapper for VAE_GAN Model
"""
import torch
from encoder import R3DEncoder
from generator import R3DGenerator
from discriminator import R3Discriminator

class R3D:
"""
The main model
"""
def __init__(self, # pylint: disable=too-many-arguments
in_channel=1,
num_patches=128,
embedding_dim=64,
embedding_kernel=3,
attention_head=4,
latent_dim=1024
):
self.encoder = R3DEncoder(
in_channel, num_patches, embedding_dim, embedding_kernel, attention_head, latent_dim
)
self.generator = R3DGenerator(latent_dim)
self.discriminator = R3Discriminator()
self.enc = None
self.gan = None
self.disc_true = None
self.disc_false = None

def __call__(self, x:torch.Tensor, y:torch.Tensor)->torch.Tensor:
self.enc = self.encoder(x)
self.gan = self.generator(self.enc)
self.disc_true = self.discriminator(y.reshape(1, *y.shape))
self.disc_false = self.discriminator(self.gan)

return self.gan

def vae_parameters(self):
"""
Returns:
list of parameters for encoding 2d images and generating 3d images
"""
return list(self.encoder.parameters()) + list(self.generator.parameters())

def gan_parameters(self):
"""
Returns:
list of parameters for 3d image generation and discrimination
"""
return list(self.generator.parameters()) + list(self.discriminator.parameters())
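
Note that the generator's parameters appear in both groups, so the two optimizers in main.py both update it: vae_optim through the reconstruction objective and gan_optim through the adversarial one. A quick way to see the overlap (assuming the default constructor arguments build cleanly):

```python
# Count the parameter tensors shared by the two optimizer groups;
# these are exactly the generator's parameters.
from models.r3d import R3D

model = R3D()
vae_ids = {id(p) for p in model.vae_parameters()}
gan_ids = {id(p) for p in model.gan_parameters()}
print(len(vae_ids & gan_ids))  # generator tensors updated by both optimizers
```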