Skip to content

Commit

Permalink
[chore] adds functionality for DCGAN (#16) (#17)
Browse files Browse the repository at this point in the history
* Added 'cuda' functionality as well as streamlined
the dataloading process for segmentation tasks.

* [chore] adds the functionality for training DCGAN

---------

Co-authored-by: Divyanshu Suman <117962699+skillingshark@users.noreply.github.com>
  • Loading branch information
ishan121028 and enigmax86 authored Oct 9, 2023
1 parent f036fc3 commit e327f00
Show file tree
Hide file tree
Showing 7 changed files with 534 additions and 4 deletions.
37 changes: 37 additions & 0 deletions UniTrain/dataset/DCGAN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import os
import glob
from PIL import Image
import torch
from torch.utils.data import Dataset

class DCGANdataset(Dataset):
    """Map-style image dataset laid out as one sub-directory per class.

    Every immediate sub-directory of ``data_dir`` is treated as a class and
    every ``.jpg``/``.jpeg``/``.png`` file directly inside it as a sample.

    Args:
        data_dir: Root directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB PIL image.

    Attributes:
        classes: Class (sub-directory) names in sorted order.
        class_to_idx: Mapping from class name to its integer index.
        data: List of ``(file_path, class_idx)`` tuples.
    """

    # Glob patterns collected from each class directory.
    _PATTERNS = ('*.jpg', '*.jpeg', '*.png')

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists plain
        # files; keep only directories and sort them so that the
        # class -> index mapping is reproducible across runs and machines.
        self.classes = sorted(
            entry for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.data = self.load_data()

    def load_data(self):
        """Return a list of ``(file_path, class_idx)`` for every image file."""
        data = []
        for cls in self.classes:
            class_path = os.path.join(self.data_dir, cls)
            class_idx = self.class_to_idx[cls]
            file_paths = []
            for pattern in self._PATTERNS:
                file_paths.extend(glob.glob(os.path.join(class_path, pattern)))
            # glob order is filesystem dependent; sort for a stable sample order.
            data.extend((file_path, class_idx) for file_path in sorted(file_paths))
        return data

    def __len__(self):
        """Return the number of samples found under ``data_dir``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_idx)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied.
        """
        img_path, target = self.data[idx]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, target
3 changes: 2 additions & 1 deletion UniTrain/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .classification import ClassificationDataset
from .segmentation import SegmentationDataset
from .segmentation import SegmentationDataset
from .DCGAN import DCGANdataset
1 change: 0 additions & 1 deletion UniTrain/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,4 @@ def __getitem__(self, index):

# You may need to further preprocess the mask if required
# Example: Convert mask to tensor and perform class mapping

return image, mask
71 changes: 71 additions & 0 deletions UniTrain/models/DCGAN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from torchsummary import summary
import torch.nn as nn


class disc(nn.Module):
    """DCGAN discriminator mapping a 3 x 64 x 64 image to one probability.

    The network is kept as the class attribute ``discriminator`` so that
    existing code using ``disc.discriminator`` directly keeps working.
    Instances register that same network as a sub-module, which makes
    ``disc()`` a usable ``nn.Module`` (``forward``, ``parameters()``,
    ``state_dict()``, ``.to(device)``).

    NOTE: every instance shares the single class-level network and therefore
    its weights.
    """

    discriminator = nn.Sequential(
        # in: 3 x 64 x 64

        # Layer 01: converts a 3-channel image to a 64-channel feature map
        nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2, inplace=True),
        # out: 64 x 32 x 32

        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2, inplace=True),
        # out: 128 x 16 x 16

        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2, inplace=True),
        # out: 256 x 8 x 8

        nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2, inplace=True),
        # out: 512 x 4 x 4

        nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
        # out: 1 x 1 x 1

        nn.Flatten(),
        nn.Sigmoid(),
    )

    def __init__(self):
        super().__init__()
        # A class attribute is not registered by nn.Module; assigning it on
        # the instance records it in ``_modules`` so its parameters are
        # visible to optimizers, ``state_dict()`` and ``.to(device)``.
        self.discriminator = type(self).discriminator

    def forward(self, x):
        """Return the real/fake probability for each image in the batch.

        Args:
            x: Image batch of shape ``(N, 3, 64, 64)``.

        Returns:
            Tensor of shape ``(N, 1)`` with values in ``[0, 1]``.
        """
        return self.discriminator(x)


class gen(nn.Module):
    """DCGAN generator mapping a latent vector to a 3 x 64 x 64 image.

    ``latent_size`` and the network ``generator`` are kept as class
    attributes so that existing code using ``gen.latent_size`` and
    ``gen.generator`` directly keeps working. Instances register that same
    network as a sub-module, which makes ``gen()`` a usable ``nn.Module``
    (``forward``, ``parameters()``, ``state_dict()``, ``.to(device)``).

    NOTE: every instance shares the single class-level network and therefore
    its weights.
    """

    # Number of channels of the latent input (latent_size x 1 x 1).
    latent_size = 128

    generator = nn.Sequential(
        # in: latent_size x 1 x 1

        nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(512),
        nn.ReLU(True),
        # out: 512 x 4 x 4

        nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(True),
        # out: 256 x 8 x 8

        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),
        # out: 128 x 16 x 16

        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),
        # out: 64 x 32 x 32

        nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
        nn.Tanh(),
        # out: 3 x 64 x 64
    )

    def __init__(self):
        super().__init__()
        # A class attribute is not registered by nn.Module; assigning it on
        # the instance records it in ``_modules`` so its parameters are
        # visible to optimizers, ``state_dict()`` and ``.to(device)``.
        self.generator = type(self).generator

    def forward(self, z):
        """Generate images from a batch of latent vectors.

        Args:
            z: Latent batch of shape ``(N, latent_size, 1, 1)``.

        Returns:
            Tensor of shape ``(N, 3, 64, 64)`` with values in ``[-1, 1]``.
        """
        return self.generator(z)
3 changes: 2 additions & 1 deletion UniTrain/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .classification import ResNet9
from .segmentation import UNet
from .segmentation import UNet
from .DCGAN import *
Loading

0 comments on commit e327f00

Please sign in to comment.