Skip to content

Commit

Permalink
Merge pull request #1237 from Junranus/dev-postgresql
Browse files Browse the repository at this point in the history
Healthcare - Diabetic Retinopathy Classification
  • Loading branch information
nudles authored Dec 1, 2024
2 parents 4646811 + da61a58 commit 2551198
Show file tree
Hide file tree
Showing 2 changed files with 348 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->

# Singa for Diabetic Retinopathy Classification

## Diabetic Retinopathy

Diabetic Retinopathy (DR) is a progressive eye disease caused by long-term diabetes, which damages the blood vessels in the retina, the light-sensitive tissue at the back of the eye. It typically develops in stages, starting with non-proliferative diabetic retinopathy (NPDR), where weakened blood vessels leak fluid or blood, causing swelling or the formation of deposits. If untreated, it can progress to proliferative diabetic retinopathy (PDR), characterized by the growth of abnormal blood vessels that can lead to severe vision loss or blindness. Symptoms may include blurred vision, dark spots, or difficulty seeing at night, although it is often asymptomatic in the early stages. Early diagnosis through regular eye exams and timely treatment, such as laser therapy or anti-VEGF injections, can help manage the condition and prevent vision impairment.

The dataset has 5 groups characterized by the severity of Diabetic Retinopathy (DR).

- 0: No DR
- 1: Mild Non-Proliferative DR
- 2: Moderate Non-Proliferative DR
- 3: Severe Non-Proliferative DR
- 4: Proliferative DR


To mitigate the problem, we use Singa to implement a machine learning model to help with Diabetic Retinopathy diagnosis. The dataset is from Kaggle https://www.kaggle.com/datasets/mohammadasimbluemoon/diabeticretinopathy-messidor-eyepac-preprocessed. Please download the dataset before running the scripts.

## Structure

* `data` includes the scripts for preprocessing DR image datasets.

* `model` includes the CNN model construction codes by creating
a subclass of `Module` to wrap the neural network operations
of each model.

* `train_cnn.py` is the training script, which controls the training flow by
doing BackPropagation and SGD update.

## Command
```bash
python train_cnn.py cnn diaret -dir pathToDataset
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,297 @@
from singa import singa_wrap as singa
from singa import device
from singa import tensor
from singa import opt
import numpy as np
import time
import argparse
import sys
sys.path.append("../../..")

from PIL import Image

from healthcare.data import diaret
from healthcare.models import diabetic_retinopthy_net

np_dtype = {"float16": np.float16, "float32": np.float32}

singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}


# Data augmentation
def augmentation(x, batch_size):
xpad = np.pad(x, [[0, 0], [0, 0], [4, 4], [4, 4]], 'symmetric')
for data_num in range(0, batch_size):
offset = np.random.randint(8, size=2)
x[data_num, :, :, :] = xpad[data_num, :,
offset[0]:offset[0] + x.shape[2],
offset[1]:offset[1] + x.shape[2]]
if_flip = np.random.randint(2)
if (if_flip):
x[data_num, :, :, :] = x[data_num, :, :, ::-1]
return x


# Calculate accuracy
def accuracy(pred, target):
# y is network output to be compared with ground truth (int)
y = np.argmax(pred, axis=1)
a = y == target
correct = np.array(a, "int").sum()
return correct


# Data partition according to the rank
def partition(global_rank, world_size, train_x, train_y, val_x, val_y):
# Partition training data
data_per_rank = train_x.shape[0] // world_size
idx_start = global_rank * data_per_rank
idx_end = (global_rank + 1) * data_per_rank
train_x = train_x[idx_start:idx_end]
train_y = train_y[idx_start:idx_end]

# Partition evaluation data
data_per_rank = val_x.shape[0] // world_size
idx_start = global_rank * data_per_rank
idx_end = (global_rank + 1) * data_per_rank
val_x = val_x[idx_start:idx_end]
val_y = val_y[idx_start:idx_end]
return train_x, train_y, val_x, val_y


# Function to all reduce NUMPY accuracy and loss from multiple devices
def reduce_variable(variable, dist_opt, reducer):
reducer.copy_from_numpy(variable)
dist_opt.all_reduce(reducer.data)
dist_opt.wait()
output = tensor.to_numpy(reducer)
return output


def resize_dataset(x, image_size):
num_data = x.shape[0]
dim = x.shape[1]
X = np.zeros(shape=(num_data, dim, image_size, image_size),
dtype=np.float32)
for n in range(0, num_data):
for d in range(0, dim):
X[n, d, :, :] = np.array(Image.fromarray(x[n, d, :, :]).resize(
(image_size, image_size), Image.BILINEAR),
dtype=np.float32)
return X


def run(global_rank,
world_size,
dir_path,
max_epoch,
batch_size,
model,
data,
sgd,
graph,
verbosity,
dist_option='plain',
spars=None,
precision='float32'):
# now CPU version only, could change to GPU device for GPU-support machines
dev = device.get_default_device()
dev.SetRandSeed(0)
np.random.seed(0)
if data == 'diaret':
train_x, train_y, val_x, val_y = diaret.load(dir_path=dir_path)
else:
print(
'Wrong dataset!'
)
sys.exit(0)

num_channels = train_x.shape[1]
image_size = train_x.shape[2]
data_size = np.prod(train_x.shape[1:train_x.ndim]).item()
num_classes = (np.max(train_y) + 1).item()

if model == 'cnn':
model = diabetic_retinopthy_net.create_model(num_channels=num_channels,
num_classes=num_classes)
else:
print(
'Wrong model!'
)
sys.exit(0)

# For distributed training, sequential has better performance
if hasattr(sgd, "communicator"):
DIST = True
sequential = True
else:
DIST = False
sequential = False

if DIST:
train_x, train_y, val_x, val_y = partition(global_rank, world_size,
train_x, train_y, val_x,
val_y)

if model.dimension == 4:
tx = tensor.Tensor(
(batch_size, num_channels, model.input_size, model.input_size), dev,
singa_dtype[precision])
elif model.dimension == 2:
tx = tensor.Tensor((batch_size, data_size),
dev, singa_dtype[precision])
np.reshape(train_x, (train_x.shape[0], -1))
np.reshape(val_x, (val_x.shape[0], -1))

ty = tensor.Tensor((batch_size,), dev, tensor.int32)
num_train_batch = train_x.shape[0] // batch_size
num_val_batch = val_x.shape[0] // batch_size
idx = np.arange(train_x.shape[0], dtype=np.int32)

# Attach model to graph
model.set_optimizer(sgd)
model.compile([tx], is_train=True, use_graph=graph, sequential=sequential)
dev.SetVerbosity(verbosity)

# Training and evaluation loop
for epoch in range(max_epoch):
start_time = time.time()
np.random.shuffle(idx)

if global_rank == 0:
print('Starting Epoch %d:' % (epoch))

# Training phase
train_correct = np.zeros(shape=[1], dtype=np.float32)
test_correct = np.zeros(shape=[1], dtype=np.float32)
train_loss = np.zeros(shape=[1], dtype=np.float32)

model.train()
for b in range(num_train_batch):
# if b % 100 == 0:
# print ("b: \n", b)
# Generate the patch data in this iteration
x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
if model.dimension == 4:
x = augmentation(x, batch_size)
if (image_size != model.input_size):
x = resize_dataset(x, model.input_size)
x = x.astype(np_dtype[precision])
y = train_y[idx[b * batch_size:(b + 1) * batch_size]]

# Copy the patch data into input tensors
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)

# Train the model
out, loss = model(tx, ty, dist_option, spars)
train_correct += accuracy(tensor.to_numpy(out), y)
train_loss += tensor.to_numpy(loss)[0]

if DIST:
# Reduce the evaluation accuracy and loss from multiple devices
reducer = tensor.Tensor((1,), dev, tensor.float32)
train_correct = reduce_variable(train_correct, sgd, reducer)
train_loss = reduce_variable(train_loss, sgd, reducer)

if global_rank == 0:
print('Training loss = %f, training accuracy = %f' %
(train_loss, train_correct /
(num_train_batch * batch_size * world_size)),
flush=True)

# Evaluation phase
model.eval()
for b in range(num_val_batch):
x = val_x[b * batch_size:(b + 1) * batch_size]
if model.dimension == 4:
if (image_size != model.input_size):
x = resize_dataset(x, model.input_size)
x = x.astype(np_dtype[precision])
y = val_y[b * batch_size:(b + 1) * batch_size]
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)
out_test = model(tx)
test_correct += accuracy(tensor.to_numpy(out_test), y)

if DIST:
# Reduce the evaulation accuracy from multiple devices
test_correct = reduce_variable(test_correct, sgd, reducer)

# Output the evaluation accuracy
if global_rank == 0:
print('Evaluation accuracy = %f, Elapsed Time = %fs' %
(test_correct / (num_val_batch * batch_size * world_size),
time.time() - start_time),
flush=True)

dev.PrintTimeProfiling()


if __name__ == '__main__':
# Use argparse to get command config: max_epoch, model, data, etc., for single gpu training
parser = argparse.ArgumentParser(
description='Training using the autograd and graph.')
parser.add_argument(
'model',
choices=['cnn'],
default='cnn')
parser.add_argument('data',
choices=['diaret'],
default='diaret')
parser.add_argument('-p',
choices=['float32', 'float16'],
default='float32',
dest='precision')
parser.add_argument('-dir',
'--dir-path',
default="/tmp/diaret",
type=str,
help='the directory to store the Diabetic Retinopathy dataset',
dest='dir_path')
parser.add_argument('-m',
'--max-epoch',
default=300,
type=int,
help='maximum epochs',
dest='max_epoch')
parser.add_argument('-b',
'--batch-size',
default=64,
type=int,
help='batch size',
dest='batch_size')
parser.add_argument('-l',
'--learning-rate',
default=0.005,
type=float,
help='initial learning rate',
dest='lr')
parser.add_argument('-g',
'--disable-graph',
default='True',
action='store_false',
help='disable graph',
dest='graph')
parser.add_argument('-v',
'--log-verbosity',
default=0,
type=int,
help='logging verbosity',
dest='verbosity')

args = parser.parse_args()

sgd = opt.SGD(lr=args.lr, momentum=0.9, weight_decay=1e-5,
dtype=singa_dtype[args.precision])
run(0,
1,
args.dir_path,
args.max_epoch,
args.batch_size,
args.model,
args.data,
sgd,
args.graph,
args.verbosity,
precision=args.precision)

0 comments on commit 2551198

Please sign in to comment.