Skip to content

Commit

Permalink
Merge pull request #120 from Captain272/master
Browse files Browse the repository at this point in the history
add custom dataset training
  • Loading branch information
WuJunde authored Jun 12, 2023
2 parents c5070a9 + 03a6713 commit 752f999
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
61 changes: 61 additions & 0 deletions guided_diffusion/custom_dataset_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import sys
import pickle
import cv2
from skimage import io
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import pandas as pd
from skimage.transform import rotate
from glob import glob
from sklearn.model_selection import train_test_split

class CustomDataset(Dataset):
def __init__(self, args, data_path , transform = None, mode = 'Training',plane = False):

print("loading data from the directory :",data_path)
path=data_path
images = sorted(glob(os.path.join(path, "images/*.png")))
masks = sorted(glob(os.path.join(path, "masks/*.png")))

self.name_list = images[:2]
self.label_list = masks[:2]
self.data_path = path
self.mode = mode

self.transform = transform

def __len__(self):
return len(self.name_list)

def __getitem__(self, index):
"""Get the images"""
name = self.name_list[index]
img_path = os.path.join(name)

mask_name = self.label_list[index]
msk_path = os.path.join(mask_name)

img = Image.open(img_path).convert('RGB')
mask = Image.open(msk_path).convert('L')

if self.mode == 'Training':
label = 0 if self.label_list[index] == 'benign' else 1
else:
label = int(self.label_list[index])

if self.transform:
state = torch.get_rng_state()
img = self.transform(img)
torch.set_rng_state(state)
mask = self.transform(mask)

if self.mode == 'Training':
return (img, mask, name)
else:
return (img, mask, name)
12 changes: 10 additions & 2 deletions scripts/segmentation_train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@

import sys
import argparse
sys.path.append("..")
sys.path.append(".")
sys.path.append("../")
sys.path.append("./")
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.bratsloader import BRATSDataset, BRATSDataset3D
from guided_diffusion.isicloader import ISICDataset
from guided_diffusion.custom_dataset_loader import CustomDataset
from guided_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
Expand Down Expand Up @@ -39,6 +40,13 @@ def main():

ds = BRATSDataset3D(args.data_dir, transform_train, test_flag=False)
args.in_ch = 5
else :
tran_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor(),]
transform_train = transforms.Compose(tran_list)
print("Your current directory : ",args.data_dir)
ds = CustomDataset(args, args.data_dir, transform_train)
args.in_ch = 4

datal= th.utils.data.DataLoader(
ds,
batch_size=args.batch_size,
Expand Down

0 comments on commit 752f999

Please sign in to comment.