
Commit

fixing linting errors
mk314k committed Nov 21, 2023
1 parent 5c2de7f commit 695ffcd
Showing 6 changed files with 195 additions and 94 deletions.
54 changes: 35 additions & 19 deletions models/attention.py
@@ -1,55 +1,71 @@
"""_summary_
"""
This Module includes general implementation of Multiheaded Attention
"""
import torch
import torch.nn as nn
from torch import nn
from fancy_einsum import einsum


class R3DAttention(nn.Module):
"""_summary_
"""
R3DAttention module performs multi-head attention computation on 3D data.
Args:
nn (_type_): _description_
hidden_size (int): The dimensionality of the input embeddings.
num_heads (int): The number of attention heads to use.
dropout (float, optional): Dropout rate to prevent overfitting (default: 0.1).
"""

def __init__(self, hidden_size: int, num_heads: int, dropout=0.1):
super(R3DAttention, self).__init__()

# Calculating the size of each attention head
head_size = int(hidden_size / num_heads)
self.n_head = num_heads
self.head_size = head_size

# Projection layers for Q, K, and V
self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
self.output_proj = nn.Linear(hidden_size, hidden_size)

# Dropout layers
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)

def forward(self, x: torch.Tensor, mask = None) -> torch.Tensor:
def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
"""
Perform multi-head scaled dot-product attention on the input.
Args:
x (torch.Tensor): (batch*view, patch, embedding)
x (torch.Tensor): Input tensor of shape (batch*view, patch, embedding).
mask (torch.Tensor, optional): Mask tensor for masking attention scores (default: None).
Returns:
torch.Tensor: (batch*view, patch, embedding)
torch.Tensor: Output tensor after attention computation (batch*view, patch, embedding).
"""
b, p, _ = x.shape

# Projecting input into query, key, and value representations
qkv = self.qkv_proj(x).reshape(b, p, 3, self.n_head, self.head_size)
# chunk needs an explicit count; squeeze and permute to (b, n_head, patch, head_size)
q, k, v = (t.squeeze(2).permute(0, 2, 1, 3) for t in qkv.chunk(3, dim=2))
#b-batch, p-patch, c-constant(3), n-num_heads, s-head_size
attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (
self.head_size**0.5
)
if mask:

# Calculating attention scores
attn_score = einsum("b n pq s, b n pk s->b n pq pk", q, k) / (self.head_size ** 0.5)

# Applying optional mask to attention scores
if mask is not None:
attn_score -= mask

# Computing attention probabilities and apply dropout
attn_prob = attn_score.softmax(dim=-1)
attn_prob = self.attn_dropout(attn_prob)

# Weighted sum of values using attention probabilities
z = einsum("b n pq pk, b n pk s ->b pq n s", attn_prob, v)
z = z.reshape((z.shape[0], z.shape[1], -1))

# Projecting back to the original space and applying residual dropout
out = self.output_proj(z)
out = self.resid_dropout(out)
return out


# if __name__ == "__main__":
# attn = R3DAttention(64, 4)
# def count_par(model):
# return sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(count_par(attn))
return out
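A minimal usage sketch for the attention module above (illustrative only: it assumes models/attention.py is importable as a package module, that the fancy_einsum dependency is installed, and that the chunk call in forward passes an explicit chunk count as noted above; the shapes are arbitrary):

    import torch
    from models.attention import R3DAttention

    attn = R3DAttention(hidden_size=64, num_heads=4)
    x = torch.randn(2 * 3, 16, 64)   # (batch*view, patch, embedding)
    out = attn(x)                    # attention preserves the input shape
    print(out.shape)                 # torch.Size([6, 16, 64])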
39 changes: 32 additions & 7 deletions models/discriminator.py
@@ -1,14 +1,31 @@
"""
This module includes 3D Discriminator module
"""
import torch
import torch.nn as nn
from torch import nn

conv3d = lambda channel: nn.Sequential(
nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm3d(channel[1]),
nn.LeakyReLU(0.2, inplace=True),
)
def conv3d(channel):
"""_summary_
Args:
channel (_type_): _description_
Returns:
_type_: _description_
"""
return nn.Sequential(
nn.Conv3d(channel[0], channel[1], kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm3d(channel[1]),
nn.LeakyReLU(0.2, inplace=True),
)


class R3Discriminator(nn.Module):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(self):
super().__init__()
self.out_channels = 512
@@ -22,7 +39,15 @@ def __init__(self):
nn.Sigmoid(),
)

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Score a batch of voxel volumes.
Args:
x (torch.Tensor): Input volume of shape (batch, channel, depth, height, width).
Returns:
torch.Tensor: Per-sample probability that the input is real.
"""
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
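A hedged usage sketch for the discriminator. The full definition is truncated above, so the single-channel 64x64x64 input size is an assumption based on the generator's output further down:

    import torch
    from models.discriminator import R3Discriminator

    disc = R3Discriminator()
    voxels = torch.randn(2, 1, 64, 64, 64)  # (batch, channel, depth, height, width), assumed size
    scores = disc(voxels)                   # probabilities in [0, 1] from the final Sigmoid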
13 changes: 3 additions & 10 deletions models/encoder.py
@@ -3,8 +3,7 @@
Author: mk314k
"""
import torch
import torch.nn as nn
import torch.nn.Functional as F
from torch import nn
from attention import R3DAttention


@@ -43,7 +42,7 @@ class R3DEncoder(nn.Module):
nn (_type_): _description_
"""
def __init__(
self,
self,
in_channel=1,
num_patches=128,
embedding_dim=64,
@@ -68,12 +67,6 @@ def __init__(
self.dist.scale = self.N.scale.cuda()
self.kl_val = 0

def get_kl(self):
"""
doc
"""
return self.kl_val

def forward(self, x: torch.Tensor):
"""
@@ -85,7 +78,7 @@ def forward(self, x: torch.Tensor):
torch.Tensor: _description_
"""
# apply encoder network to input image
assert(len(x.shape) == 5, "input must be of shape (batch, view, channel, width, height)")
# assert(len(x.shape) == 5, "input must be of shape (batch, view, channel, width, height)")
b, v, c, w, h = x.shape
x = x.view(b * v, c, w, h)
x = self.patch_embedding(x)
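The encoder samples from a unit normal and tracks a KL term (self.dist and self.kl_val above). As a generic sketch of the reparameterisation and KL computation such a VAE-style encoder typically performs (the names here are illustrative, not taken from the file):

    import torch

    def reparameterise(mu: torch.Tensor, log_sigma: torch.Tensor):
        sigma = torch.exp(log_sigma)
        eps = torch.randn_like(sigma)       # sample from N(0, 1)
        z = mu + sigma * eps                # differentiable sample from N(mu, sigma^2)
        # KL divergence between N(mu, sigma^2) and the unit normal prior
        kl = 0.5 * (sigma**2 + mu**2 - 2 * log_sigma - 1).sum()
        return z, kl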
21 changes: 18 additions & 3 deletions models/generator.py
@@ -1,8 +1,15 @@
"""
This module includes code for Generating 3D images
"""
import torch
import torch.nn as nn

from torch import nn

class R3DGenerator(nn.Module):
"""_summary_
Args:
nn (_type_): _description_
"""
def __init__(self, z_dim):
super().__init__()
self.z_dim = z_dim
@@ -26,7 +33,15 @@ def __init__(self, z_dim):
nn.ConvTranspose3d(8, 1, kernel_size=4, stride=2, padding=1), nn.Sigmoid()
)

def forward(self, z):
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""Decode latent vectors into voxel volumes.
Args:
z (torch.Tensor): Latent batch of shape (batch, z_dim).
Returns:
torch.Tensor: Generated volume with values in [0, 1] from the final Sigmoid.
"""
out = self.linear(z)
out = out.view(-1, 64, 4, 4, 4)
out = self.conv1(out)
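A short usage sketch (illustrative: assuming the truncated middle layers are stride-2 ConvTranspose3d stages, the 4x4x4 starting grid above is upsampled to a 1x64x64x64 volume):

    import torch
    from models.generator import R3DGenerator

    gen = R3DGenerator(z_dim=1024)
    z = torch.randn(2, 1024)
    voxels = gen(z)   # expected shape (2, 1, 64, 64, 64) under the assumption above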
91 changes: 62 additions & 29 deletions train.py
@@ -1,77 +1,110 @@
""".vscode/
"""
This module includes all the loss functions and other utilities necessary for training the model.
Author: mk314k
"""
import tqdm.auto as tqdm
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from models.encoder import R3DEncoder
from models.generator import R3DGenerator
from models.discriminator import R3Discriminator

# Define your GAN loss function
def gan_loss(discriminator_output, is_real):
"""
This is docstring

def gan_loss(discriminator_output: torch.Tensor, is_real: bool) -> torch.Tensor:
"""Binary cross-entropy measure separating real 3D images from generated ones.
Args:
discriminator_output (torch.Tensor): Discriminator probabilities for the batch.
is_real (bool): Whether the batch should be scored as real (True) or generated (False).
Returns:
torch.Tensor: Scalar binary cross-entropy loss.
"""
if is_real:
target = torch.ones_like(discriminator_output)
target = torch.ones_like(discriminator_output) # pylint: disable=no-member
else:
target = torch.zeros_like(discriminator_output)
target = torch.zeros_like(discriminator_output) # pylint: disable=no-member
loss = F.binary_cross_entropy(discriminator_output, target, reduction='mean')
return loss

# Define your VAE loss function

def vae_loss(recon_x, label):
"""
This is docstring
"""_summary_
Args:
recon_x (_type_): _description_
label (_type_): _description_
Returns:
_type_: _description_
"""
return F.binary_cross_entropy(recon_x[0], label, reduction='sum')

# Define your training function
def train(train_x, train_y, model_e, model_d, model_g, vae_optim, gan_optim, num_epochs=10):
"""
This is docstring

def train(
train_x:torch.Tensor,
train_y:torch.Tensor,
model_e:nn.Module,
model_d:nn.Module,
model_g:nn.Module,
optims,
num_epochs=10
):
"""_summary_
Args:
train_x (torch.Tensor): _description_
train_y (torch.Tensor): _description_
model_e (nn.Module): _description_
model_d (nn.Module): _description_
model_g (nn.Module): _description_
optims (_type_): _description_
num_epochs (int, optional): _description_. Defaults to 10.
Returns:
_type_: _description_
"""
device = train_x.device
vae_optimizer, gan_optimizer = optims
train_losses = []
for epoch in tqdm.tqdm(range(num_epochs)):
for _ in tqdm.tqdm(range(num_epochs)):
for i in range(70):
batch_label = train_y[i].to(torch.float).to(device)
batch_label = train_y[i].to(torch.float)
for j in [2]:
batch_data = train_x[i,j].to(torch.float).reshape(1,1,192,256).to(device)
batch_data = train_x[i,j].to(torch.float).reshape(1,1,192,256)
e_logit = model_e(batch_data)
g_logit = model_g(e_logit)
g_loss = model_e.kl + vae_loss(g_logit, batch_label)
vae_optim.zero_grad()
vae_optimizer.zero_grad()
g_loss.backward(retain_graph=True)
vae_optim.step()
vae_optimizer.step()
g_logit = model_g(e_logit.detach())
d_true = model_d(batch_label.reshape((1, *batch_label.shape)))
d_false = model_d(g_logit)
d_loss = gan_loss(d_false, False) + gan_loss(d_true, True)
gan_optim.zero_grad()
gan_optimizer.zero_grad()
d_loss.backward()
gan_optim.step()
gan_optimizer.step()
train_losses.append((g_loss.item(), d_loss.item()))
return train_losses

if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img2d = torch.zeros(10,10)
img3d = torch.zeros(10,10,10)
tt_split = train_test_split(img2d, img3d, test_size=0.3, random_state=500)
train_data, test_data, train_label, test_label = tt_split
# Define and initialize your models
# Intializing models
MODEL_E = R3DEncoder().to(device)
MODEL_G = R3DGenerator(1024).to(device)
MODEL_D = R3Discriminator().to(device)
# Set hyperparameters for your optimizers
# Setting hyperparameters for your optimizers
LR = 1e-3
WD = 0.2
betas=(0.9, 0.98)
# Initialize optimizers
# Initializing optimizers
vae_optim = torch.optim.AdamW(
list(MODEL_E.parameters())+list(MODEL_G.parameters()),
lr=LR,
@@ -84,8 +117,8 @@ def train(train_x, train_y, model_e, model_d, model_g, vae_optim, gan_optim, num
weight_decay=WD,
betas=betas
)
train_loss = train(train_data, train_label, MODEL_E, MODEL_G, MODEL_D, vae_optim, gan_optim)
# Plot training losses
train_loss = train(train_data, train_label, MODEL_E, MODEL_G, MODEL_D, (vae_optim, gan_optim))
# Plotting training losses
plt.figure(figsize=(16, 6))
plt.plot(range(len(train_loss)), [tloss[0] for tloss in train_loss], label='Training VAE loss')
plt.plot(range(len(train_loss)), [tloss[1] for tloss in train_loss], label='Training GAN loss')
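A quick sanity check of gan_loss (assuming train.py and its dependencies are importable; the module-level setup is guarded by if __name__ == '__main__', so importing it has no side effects):

    import torch
    from train import gan_loss

    probs = torch.tensor([0.9, 0.8])       # discriminator output for a batch scored as real
    print(gan_loss(probs, is_real=True))   # small loss: outputs are already close to 1
    print(gan_loss(probs, is_real=False))  # large loss: these samples fooled the discriminator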
