From bc899000e5c92e121577b4c2ac35cb26615f2621 Mon Sep 17 00:00:00 2001 From: JordiCorbilla Date: Mon, 16 Dec 2019 12:18:00 +0000 Subject: [PATCH] Adding data augmentation classes --- odir.py | 8 +- odir_advance_plotting.py | 272 +++++++++++++++++++++++++++ odir_data_augmentation_display.py | 254 +++++++++++++++++++++++++ odir_data_augmentation_generator.py | 43 +++++ odir_data_augmentation_runner.py | 108 +++++++++++ odir_data_augmentation_strategies.py | 140 ++++++++++++++ odir_discarded_images.py | 54 ++++++ odir_eye_patient.py | 2 +- 8 files changed, 876 insertions(+), 5 deletions(-) create mode 100644 odir_advance_plotting.py create mode 100644 odir_data_augmentation_display.py create mode 100644 odir_data_augmentation_generator.py create mode 100644 odir_data_augmentation_runner.py create mode 100644 odir_data_augmentation_strategies.py create mode 100644 odir_discarded_images.py diff --git a/odir.py b/odir.py index ae90db0..c1b5561 100644 --- a/odir.py +++ b/odir.py @@ -1,4 +1,4 @@ -# Copyright 2019 Jordi Corbilla. All Rights Reserved. +# Copyright 2019-2020 Jordi Corbilla. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,15 +15,15 @@ import numpy as np -def load_data(image_size): +def load_data(image_size, index): """Loads the ODIR dataset. Returns: Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ - x_train = np.load('odir_training'+'_' + str(image_size)+'.npy') - y_train = np.load('odir_training_labels'+'_' + str(image_size)+'.npy') + x_train = np.load('odir_training'+'_' + str(image_size) + '_' + str(index)+'.npy') + y_train = np.load('odir_training_labels'+'_' + str(image_size) + '_' + str(index)+'.npy') x_test = np.load('odir_testing'+'_' + str(image_size)+'.npy') y_test = np.load('odir_testing_labels'+'_' + str(image_size)+'.npy') diff --git a/odir_advance_plotting.py b/odir_advance_plotting.py new file mode 100644 index 0000000..69c82d7 --- /dev/null +++ b/odir_advance_plotting.py @@ -0,0 +1,272 @@ +# Copyright 2019 Jordi Corbilla. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import, division, print_function, unicode_literals + +import sys + +import matplotlib.pyplot as plt +from sklearn.metrics import confusion_matrix +import numpy as np +import seaborn as sns +import matplotlib as mpl + + +class Plotter: + def __init__(self, class_names): + self.class_names = class_names + + def plot_metrics(self, history, test_run, index): + metrics2 = ['loss', 'auc', 'precision', 'recall'] + for n, metric in enumerate(metrics2): + name = metric.replace("_", " ").capitalize() + plt.subplot(2, 2, n + 1) + plt.plot(history.epoch, history.history[metric], color='green', label='Train') + plt.plot(history.epoch, history.history['val_' + metric], color='green', linestyle="--", label='Val') + plt.xlabel('Epoch') + plt.ylabel(name) + if metric == 'loss': + plt.ylim([0, plt.ylim()[1]]) + elif metric == 'auc': + plt.ylim([0.8, 1]) + else: + plt.ylim([0, 1]) + + plt.legend() + + plt.savefig('image_run' + str(index) + test_run + '.png') + plt.show() + plt.close() + + def plot_input_images(self, x_train, y_train): + plt.figure(figsize=(9, 9)) + for i in range(100): + plt.subplot(10, 10, i + 1) + plt.xticks([]) + plt.yticks([]) + plt.grid(False) + plt.imshow(x_train[i]) + classes = "" + for j in range(8): + if y_train[i][j] >= 0.5: + classes = classes + self.class_names[j] + "\n" + plt.xlabel(classes, fontsize=7, color='black', labelpad=1) + + plt.subplots_adjust(bottom=0.04, right=0.95, top=0.94, left=0.06, wspace=0.56, hspace=0.17) + plt.show() + + def plot_image(self, i, predictions_array, true_label, img): + predictions_array, true_label, img = predictions_array[i], true_label[i], img[i] + plt.grid(False) + plt.xticks([]) + plt.yticks([]) + + plt.imshow(img) + + ground = "" + count_true = 0 + predicted_true = 0 + + for index in range(8): + if true_label[index] >= 0.5: + count_true = count_true + 1 + ground = ground + self.class_names[index] + "\n" + if predictions_array[index] >= 0.5: + predicted_true = predicted_true + 1 + + if count_true == predicted_true: + color = 'green' + else: + color = 'red' + + first, second, third, i, j, k = self.calculate_3_largest(predictions_array, 8) + prediction = "{} {:2.0f}% \n".format(self.class_names[i], 100 * first) + if second > 0.1: + prediction = prediction + "{} {:2.0f}% \n".format(self.class_names[j], 100 * second) + if third > 0.1: + prediction = prediction + "{} {:2.0f}% \n".format(self.class_names[k], 100 * third) + plt.xlabel("Predicted: {} Ground Truth: {}".format(prediction, ground), color=color) + + def calculate_3_largest(self, arr, arr_size): + if arr_size < 3: + print(" Invalid Input ") + return + + third = first = second = -sys.maxsize + index_1 = 0 + index_2 = 0 + index_3 = 0 + + for i in range(0, arr_size): + if arr[i] > first: + third = second + second = first + first = arr[i] + elif arr[i] > second: + third = second + second = arr[i] + elif arr[i] > third: + third = arr[i] + + for i in range(0, arr_size): + if arr[i] == first: + index_1 = i + for i in range(0, arr_size): + if arr[i] == second and i != index_1: + index_2 = i + for i in range(0, arr_size): + if arr[i] == third and i != index_1 and i!= index_2: + index_3 = i + return first, second, third, index_1, index_2, index_3 + + def plot_value_array(self, i, predictions_array, true_label): + predictions_array, true_label = predictions_array[i], true_label[i] + plt.grid(False) + plt.xticks([]) + plt.yticks([]) + bar_plot = plt.bar(range(8), predictions_array, 
color="#777777") + plt.xticks(range(8), ('N', 'D', 'G', 'C', 'A', 'H', 'M', 'O')) + plt.ylim([0, 1]) + + for j in range(8): + if true_label[j] >= 0.5: + bar_plot[j].set_color('green') + + for j in range(8): + if predictions_array[j] >= 0.5 and true_label[j] < 0.5: + bar_plot[j].set_color('red') + + def bar_label(rects): + for rect in rects: + height = rect.get_height() + value = height * 100 + if value > 1: + plt.annotate('{:2.0f}%'.format(value), + xy=(rect.get_x() + rect.get_width() / 2, height), + xytext=(0, 3), # 3 points vertical offset + textcoords="offset points", + ha='center', va='bottom') + + bar_label(bar_plot) + + def ensure_test_prediction_exists(self, predictions): + exists = False + for j in range(8): + if predictions[j] >= 0.5: + exists = True + return exists + + def plot_output(self, test_predictions_baseline, y_test, x_test_drawing): + mpl.rcParams["font.size"] = 7 + num_rows = 5 + num_cols = 3 + num_images = num_rows * num_cols + plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows)) + j = 0 + i = 0 + while j < num_images: + if self.ensure_test_prediction_exists(test_predictions_baseline[i]): + plt.subplot(num_rows, 2 * num_cols, 2 * j + 1) + self.plot_image(i, test_predictions_baseline, y_test, x_test_drawing) + plt.subplot(num_rows, 2 * num_cols, 2 * j + 2) + self.plot_value_array(i, test_predictions_baseline, y_test) + j = j + 1 + i = i + 1 + if i > 400: + break + + plt.subplots_adjust(bottom=0.08, right=0.95, top=0.94, left=0.05, wspace=0.11, hspace=0.56) + plt.show() + + def plot_output_single(self, i, test_predictions_baseline, y_test, x_test_drawing): + plt.figure(figsize=(6, 3)) + plt.subplot(1, 2, 1) + self.plot_image(i, test_predictions_baseline, y_test, x_test_drawing) + plt.subplot(1, 2, 2) + self.plot_value_array(i, test_predictions_baseline, y_test) + plt.show() + + def plot_confusion_matrix(self, y_true, y_pred, classes, normalize=False, title=None, cmap=plt.cm.Blues): + """ + This function prints and plots the confusion matrix. + Normalization can be applied by setting `normalize=True`. + """ + if not title: + if normalize: + title = 'Normalized confusion matrix' + else: + title = 'Confusion matrix, without normalization' + + # Compute confusion matrix + cm = confusion_matrix(y_true.argmax(axis=1), y_pred.argmax(axis=1)) + # Only use the labels that appear in the data + if normalize: + cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] + print("Normalized confusion matrix") + else: + print('Confusion matrix, without normalization') + + print(cm) + + fig, ax = plt.subplots() + im = ax.imshow(cm, interpolation='nearest', cmap=cmap) + ax.figure.colorbar(im, ax=ax) + # We want to show all ticks... + ax.set(xticks=np.arange(cm.shape[1]), + yticks=np.arange(cm.shape[0]), + # ... and label them with the respective list entries + # xticklabels=classes, yticklabels=classes, + title=title, + ylabel='True label', + xlabel='Predicted label') + ax.set_ylim(8.0, -1.0) + # Rotate the tick labels and set their alignment. + plt.setp(ax.get_xticklabels(), rotation=45, ha="right", + rotation_mode="anchor") + + # Loop over data dimensions and create text annotations. + fmt = '.2f' if normalize else 'd' + thresh = cm.max() / 2. 
+ for i in range(cm.shape[0]): + for j in range(cm.shape[1]): + ax.text(j, i, format(cm[i, j], fmt), + ha="center", va="center", + color="white" if cm[i, j] > thresh else "black") + fig.tight_layout() + return ax + + def print_normalized_confusion_matrix(self, y_test, test_predictions_baseline): + np.set_printoptions(precision=2) + + # Plot non-normalized confusion matrix + self.plot_confusion_matrix(y_test, test_predictions_baseline, classes=self.class_names, + title='Confusion matrix, without normalization') + + # Plot normalized confusion matrix + self.plot_confusion_matrix(y_test, test_predictions_baseline, classes=self.class_names, normalize=True, + title='Normalized confusion matrix') + + plt.show() + + def plot_confusion_matrix_generic(self, labels2, predictions, test_run, p=0.5): + cm = confusion_matrix(labels2.argmax(axis=1), predictions.argmax(axis=1)) + plt.figure(figsize=(6, 6)) + ax = sns.heatmap(cm, annot=True, fmt="d") + ax.set_ylim(8.0, -1.0) + plt.title('Confusion matrix') + plt.ylabel('Actual label') + plt.xlabel('Predicted label') + plt.savefig('image_run3' + test_run + '.png') + plt.show() + plt.close() diff --git a/odir_data_augmentation_display.py b/odir_data_augmentation_display.py new file mode 100644 index 0000000..3b9aa90 --- /dev/null +++ b/odir_data_augmentation_display.py @@ -0,0 +1,254 @@ +# Copyright 2019-2020 Jordi Corbilla. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import logging.config
+import os
+
+import numpy as np
+from absl import app
+import cv2
+from odir_image_treatment import ImageTreatment
+import matplotlib.pyplot as plt
+
+def main(argv):
+    treatment = ImageTreatment(image_size)
+    file = '2_right.jpg'
+    file_path = r'C:\temp\ODIR-5K_Training_Dataset_treated_' + str(image_size)
+    saving_path = r'C:\temp\ODIR-5K_Training_Dataset_augmented_' + str(image_size)
+    file_id = file.replace('.jpg', '')
+
+    # Get the image in the correct format
+    eye_image = os.path.join(file_path, file)
+    image = cv2.imread(eye_image)
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    original_image = image
+
+    ## Generate brightness images
+    bright = treatment.brightness(image, 0.1)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(bright)
+    plt.title('Delta = 0.1')
+    plt.show()
+    plt.close()
+    bright = cv2.cvtColor(bright, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_a.jpg'), bright)
+    print("Image written to file-system : ", status)
+
+    ## Generate contrast images
+    contrast = treatment.contrast(image, 2)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(contrast)
+    plt.title('Contrast Factor = 2')
+    plt.show()
+    plt.close()
+    contrast = cv2.cvtColor(contrast, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_b.jpg'), contrast)
+    print("Image written to file-system : ", status)
+
+    ## Generate saturation images
+    saturation = treatment.saturation(image, 0.5)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(saturation)
+    plt.title('Saturation Factor = 0.5')
+    plt.show()
+    plt.close()
+    saturation = cv2.cvtColor(saturation, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_c.jpg'), saturation)
+    print("Image written to file-system : ", status)
+
+    ## Generate scaling images
+    vector = [0.90, 0.80, 0.70, 0.50]
+    newImages = treatment.scaling(image, vector)
+
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(newImages[0])
+    plt.title('Scale = 0.90')
+    plt.subplot(2, 2, 3)
+    plt.imshow(newImages[1])
+    plt.title('Scale = 0.80')
+    plt.subplot(2, 2, 4)
+    plt.imshow(newImages[2])
+    plt.title('Scale = 0.70')
+    plt.show()
+    plt.close()
+    for i in range(len(vector)):
+        saving_image = cv2.cvtColor(newImages[i], cv2.COLOR_BGR2RGB)
+        status = cv2.imwrite(os.path.join(saving_path, file_id + '_d'+str(i)+'.jpg'), saving_image)
+        print("Image written to file-system : ", status)
+
+    intensity = treatment.rescale_intensity(original_image)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(intensity)
+    plt.title('Rescale Intensity = 2-98%')
+    plt.show()
+    plt.close()
+    intensity = cv2.cvtColor(intensity, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_e.jpg'), intensity)
+    print("Image written to file-system : ", status)
+
+    gamma = treatment.gamma(original_image, 0.5)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(gamma)
+    plt.title('Gamma = 0.5')
+    plt.show()
+    plt.close()
+    gamma = cv2.cvtColor(gamma, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_f.jpg'), gamma)
+    print("Image written to file-system : ", status)
+
+    hue = treatment.hue(original_image, 0.2)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(hue)
+    plt.title('Hue = 0.2')
+    plt.show()
+    plt.close()
+    hue = cv2.cvtColor(hue, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_g.jpg'), hue)
+    print("Image written to file-system : ", status)
+
+    central = treatment.crop_to_bounding_box(original_image, 0, 0, 112, 112)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(central)
+    plt.title('Crop 112x112, offset (0, 0)')
+    plt.show()
+    plt.close()
+    central = cv2.cvtColor(central, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_h.jpg'), central)
+    print("Image written to file-system : ", status)
+
+    central = treatment.crop_to_bounding_box(original_image, 112, 0, 112, 112)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(central)
+    plt.title('Crop 112x112, offset (112, 0)')
+    plt.show()
+    plt.close()
+    central = cv2.cvtColor(central, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_i.jpg'), central)
+    print("Image written to file-system : ", status)
+
+    central = treatment.crop_to_bounding_box(original_image, 0, 112, 112, 112)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(central)
+    plt.title('Crop 112x112, offset (0, 112)')
+    plt.show()
+    plt.close()
+    central = cv2.cvtColor(central, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_j.jpg'), central)
+    print("Image written to file-system : ", status)
+
+    central = treatment.crop_to_bounding_box(original_image, 112, 112, 112, 112)
+    plt.subplots(figsize=(10, 10))
+    plt.subplot(2, 2, 1)
+    plt.imshow(original_image)
+    plt.title('Base Image')
+    plt.subplot(2, 2, 2)
+    plt.imshow(central)
+    plt.title('Crop 112x112, offset (112, 112)')
+    plt.show()
+    plt.close()
+    central = cv2.cvtColor(central, cv2.COLOR_BGR2RGB)
+    status = cv2.imwrite(os.path.join(saving_path, file_id + '_k.jpg'), central)
+    print("Image written to file-system : ", status)
+
+    # central = treatment.central_crop(original_image, 0.5)
+    # plt.subplots(figsize=(10, 10))
+    # plt.subplot(2, 2, 1)
+    # plt.imshow(original_image)
+    # plt.title('Base Image')
+    # plt.subplot(2, 2, 2)
+    # plt.imshow(central)
+    # plt.title('Central Crop = 0.5')
+    # plt.show()
+    # plt.close()
+    # central = cv2.cvtColor(central, cv2.COLOR_BGR2RGB)
+    # status = cv2.imwrite(os.path.join(saving_path, file_id + '_h.jpg'), central)
+    # print("Image written to file-system : ", status)
+
+    # hist = treatment.equalize_histogram(original_image)
+    # plt.subplots(figsize=(10, 10))
+    # plt.subplot(2, 2, 1)
+    # plt.imshow(original_image)
+    # plt.title('Base Image')
+    # plt.subplot(2, 2, 2)
+    # plt.imshow(hist)
+    # plt.title('Equalize Histogram')
+    # plt.show()
+    # plt.close()
+    # #hist = cv2.cvtColor(hist, cv2.COLOR_BGR2RGB)
+    # status = cv2.imwrite(os.path.join(saving_path, file_id + '_e.jpg'), hist)
+
# print("Image written to file-system : ", status) + # + # equalize = treatment.equalize_adapthist(original_image) + # plt.subplots(figsize=(10, 10)) + # plt.subplot(2, 2, 1) + # plt.imshow(original_image) + # plt.title('Base Image') + # plt.subplot(2, 2, 2) + # plt.imshow(equalize) + # plt.title('equalize adapt hist - 0.03') + # plt.show() + # plt.close() + # equalize = cv2.cvtColor(equalize, cv2.COLOR_BGR2RGB) + # status = cv2.imwrite(os.path.join(saving_path, file_id + '_f.jpg'), equalize) + # print("Image written to file-system : ", status) + + +if __name__ == '__main__': + # create logger + logging.config.fileConfig('logging.conf') + logger = logging.getLogger('odir') + image_size = 224 + app.run(main) diff --git a/odir_data_augmentation_generator.py b/odir_data_augmentation_generator.py new file mode 100644 index 0000000..61fd5c7 --- /dev/null +++ b/odir_data_augmentation_generator.py @@ -0,0 +1,43 @@ +# Copyright 2019 Jordi Corbilla. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import, division, print_function, unicode_literals +import numpy as np +from tensorflow.keras.preprocessing.image import ImageDataGenerator + + +class DataGenerator: + def data_augmentation(x_train, y_train, augment_size=25000): + image_generator = ImageDataGenerator( + rotation_range=10, + zoom_range=1.1, + width_shift_range=0.07, + height_shift_range=0.07, + brightness_range=[0.2,1.0], + shear_range=0.25, + horizontal_flip=False, + vertical_flip=False, + data_format="channels_last") + # fit data for zca whitening + image_generator.fit(x_train, augment=True) + # get transformed images + randidx = np.random.randint(x_train.shape[0], size=augment_size) + x_augmented = x_train[randidx].copy() + y_augmented = y_train[randidx].copy() + x_augmented = image_generator.flow(x_augmented, np.zeros(augment_size), + batch_size=augment_size, shuffle=False).next()[0] + # append augmented data to trainset + x_train2 = np.concatenate((x_train, x_augmented)) + y_train2 = np.concatenate((y_train, y_augmented)) + return x_train2, y_train2 \ No newline at end of file diff --git a/odir_data_augmentation_runner.py b/odir_data_augmentation_runner.py new file mode 100644 index 0000000..6d3d328 --- /dev/null +++ b/odir_data_augmentation_runner.py @@ -0,0 +1,108 @@ +# Copyright 2019-2020 Jordi Corbilla. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import csv +import logging.config +import os +from absl import app + +from odir_data_augmentation_strategies import DataAugmentationStrategy +from odir_load_ground_truth_files import GroundTruthFiles + + +def write_header(): + with open(r'ground_truth\odir_augmented.csv', 'w', newline='') as csv_file: + file_writer = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) + file_writer.writerow(['ID', 'Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', + 'Myopia', 'Others']) + return file_writer + + +def process_files(images, cache, files): + total = 0 + for strategy in range(len(images)): + images_to_process = images[strategy][0] + samples_per_image = images[strategy][1] + for image_index in range(images_to_process): + image_vector = files[image_index] + file_name = image_vector[0] + + # Only check during the first strategy + if strategy == 0: + if file_name not in cache: + cache[file_name] = 1 + else: + cache[file_name] = cache[file_name] * 20 + + # print('Processing: ' + file_name) + augment = DataAugmentationStrategy(image_size, file_name) + count = augment.generate_images(samples_per_image, image_vector, cache[file_name]) + total = total + count + return total + + +def main(argv): + # load the ground truth file + files = GroundTruthFiles() + files.populate_vectors(csv_path) + + print('files record count order by size ASC') + print('hypertension ' + str(len(files.hypertension))) + print('myopia ' + str(len(files.myopia))) + print('cataract ' + str(len(files.cataract))) + print('amd ' + str(len(files.amd))) + print('glaucoma ' + str(len(files.glaucoma))) + print('others ' + str(len(files.others))) + print('diabetes ' + str(len(files.diabetes))) + + images_hypertension = [[len(files.hypertension), 13], [128, 14]] + images_myopia = [[len(files.myopia), 9], [196, 14]] + images_cataract = [[len(files.cataract), 9], [66, 14]] + images_amd = [[len(files.amd), 9], [16, 14]] + images_glaucoma = [[len(files.glaucoma), 7], [312, 14]] + images_others = [[len(files.others), 1], [568, 14]] + images_diabetes = [[1038, 1]] + + # Delete previous file + exists = os.path.isfile(r'ground_truth\odir_augmented.csv') + if exists: + os.remove(r'ground_truth\odir_augmented.csv') + + write_header() + + images_processed = {} + + total_hypertension = process_files(images_hypertension, images_processed, files.hypertension) + total_myopia = process_files(images_myopia, images_processed, files.myopia) + total_cataract = process_files(images_cataract, images_processed, files.cataract) + total_amd = process_files(images_amd, images_processed, files.amd) + total_glaucoma = process_files(images_glaucoma, images_processed, files.glaucoma) + total_others = process_files(images_others, images_processed, files.others) + total_diabetes = process_files(images_diabetes, images_processed, files.diabetes) + + print("total generated hypertension: " + str(total_hypertension)) + print("total generated myopia: " + str(total_myopia)) + print("total generated cataract: " + str(total_cataract)) + print("total generated amd: " + str(total_amd)) + print("total generated glaucoma: " + str(total_glaucoma)) + print("total generated others: " + str(total_others)) + print("total generated diabetes: " + str(total_diabetes)) + +if __name__ == '__main__': + # create logger + logging.config.fileConfig('logging.conf') + logger = logging.getLogger('odir') + image_size = 224 + csv_path = 'ground_truth\odir.csv' + app.run(main) 
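For context, a minimal usage sketch (not part of the patch) of how the new pieces are meant to combine: the runner above materialises the augmented images plus ground_truth\odir_augmented.csv, while training code can load a numpy shard through the updated load_data(image_size, index) and optionally expand it in memory with DataGenerator. The image size, shard index and augment_size values below are illustrative assumptions.

# Usage sketch; assumes the 224-pixel training shard 1 and the testing .npy files exist.
import odir
from odir_data_augmentation_generator import DataGenerator

(x_train, y_train), (x_test, y_test) = odir.load_data(224, 1)

# data_augmentation is declared without `self`, so calling it through the class
# treats it as a plain function whose first argument is x_train.
x_train, y_train = DataGenerator.data_augmentation(x_train, y_train, augment_size=5000)
print(x_train.shape, y_train.shape)
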
diff --git a/odir_data_augmentation_strategies.py b/odir_data_augmentation_strategies.py new file mode 100644 index 0000000..2c18e02 --- /dev/null +++ b/odir_data_augmentation_strategies.py @@ -0,0 +1,140 @@ +# Copyright 2019-2020 Jordi Corbilla. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import csv +import os +import cv2 +from odir_image_treatment import ImageTreatment + + +class DataAugmentationStrategy: + def __init__(self, image_size, file_name): + self.base_image = file_name + self.treatment = ImageTreatment(image_size) + self.file_path = r'C:\temp\ODIR-5K_Training_Dataset_treated_' + str(image_size) + self.saving_path = r'C:\temp\ODIR-5K_Training_Dataset_augmented_' + str(image_size) + self.file_id = file_name.replace('.jpg', '') + + def save_image(self, original_vector, image, sample): + central = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + file = self.file_id + '_'+str(sample)+'.jpg' + file_name = os.path.join(self.saving_path, file) + exists = os.path.isfile(file_name) + if exists: + print("duplicate file found: " + file_name) + + status = cv2.imwrite(file_name, central) + + with open(r'ground_truth\odir_augmented.csv', 'a', newline='') as csv_file: + file_writer = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) + file_writer.writerow([file, original_vector[1], original_vector[2], original_vector[3], original_vector[4], + original_vector[5], original_vector[6], original_vector[7], original_vector[8]]) + + #print(file_name + " written to file-system : ", status) + + def generate_images(self, number_samples, original_vector, weights): + eye_image = os.path.join(self.file_path, self.base_image) + image = cv2.imread(eye_image) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + original_image = image + saved = 0 + + # For any repeating elements, just give the other output + # We are only expecting up to 3 repetitions + if weights == 20: + original_image = self.treatment.rot90(original_image, 2) + if weights == 400: + original_image = self.treatment.rot90(original_image, 3) + if weights > 401: + print(str(self.file_id) + ' samples:' + str(number_samples)) + raise ValueError('this cannot happen') + + # for the sample type 14, just generate 1 image and leave the method + if number_samples == 14: + central = self.treatment.rot90(original_image, 1) + self.save_image(original_vector, central, weights+14) + saved = saved +1 + return saved + + if number_samples > 0: + central = self.treatment.crop_to_bounding_box(original_image, 0, 0, 112, 112) + self.save_image(original_vector, central, weights+0) + saved = saved + 1 + + if number_samples > 1: + central = self.treatment.crop_to_bounding_box(original_image, 112, 0, 112, 112) + self.save_image(original_vector, central, weights+1) + saved = saved + 1 + + if number_samples > 2: + central = self.treatment.crop_to_bounding_box(original_image, 0, 112, 112, 112) + self.save_image(original_vector, central, weights+2) + 
saved = saved + 1 + + if number_samples > 3: + central = self.treatment.crop_to_bounding_box(original_image, 112, 112, 112, 112) + self.save_image(original_vector, central, weights+3) + saved = saved + 1 + + if number_samples > 4: + vector = [0.50] + central = self.treatment.scaling(original_image, vector) + self.save_image(original_vector, central[0], weights+4) + saved = saved + 1 + + if number_samples > 5: + vector = [0.70] + central = self.treatment.scaling(original_image, vector) + self.save_image(original_vector, central[0], weights+5) + saved = saved + 1 + + if number_samples > 6: + vector = [0.80] + central = self.treatment.scaling(original_image, vector) + self.save_image(original_vector, central[0], weights+6) + saved = saved + 1 + + if number_samples > 7: + vector = [0.90] + central = self.treatment.scaling(original_image, vector) + self.save_image(original_vector, central[0], weights+7) + saved = saved + 1 + + if number_samples > 8: + central = self.treatment.rescale_intensity(original_image) + self.save_image(original_vector, central, weights+8) + saved = saved + 1 + + if number_samples > 9: + central = self.treatment.contrast(original_image, 2) + self.save_image(original_vector, central, weights+9) + saved = saved + 1 + + if number_samples > 10: + central = self.treatment.saturation(original_image, 0.5) + self.save_image(original_vector, central, weights+10) + saved = saved + 1 + + if number_samples > 11: + central = self.treatment.gamma(original_image, 0.5) + self.save_image(original_vector, central, weights+11) + saved = saved + 1 + + if number_samples > 12: + central = self.treatment.hue(original_image, 0.2) + self.save_image(original_vector, central, weights+12) + saved = saved + 1 + + return saved + diff --git a/odir_discarded_images.py b/odir_discarded_images.py new file mode 100644 index 0000000..307d50f --- /dev/null +++ b/odir_discarded_images.py @@ -0,0 +1,54 @@ +# Copyright 2019-2020 Jordi Corbilla. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import logging +import pandas as pd +import xlrd as x +import csv + +spreadsheet = r"C:\Users\thund\Source\Repos\TFM-ODIR\models\image_classification\DiscardedImages.xlsx" +sheet = pd.read_excel(spreadsheet, sheet_name="Sheet1") +Patients = {} + +with open('discarded.csv', 'w', newline='') as csv_file: + file_writer = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) + file_writer.writerow( + ['ID', 'Fundus', 'Diagnostic', 'Normal', 'Diabetes', 'Glaucoma', 'Cataract', 'AMD', 'Hypertension', + 'Myopia', 'Others']) + for i in sheet.index: + # load the data from the excel sheet + patient_id = sheet['ID'][i] + left_fundus = sheet['Left-Fundus'][i] + right_fundus = sheet['Right-Fundus'][i] + left_keywords = sheet['Left-Diagnostic Keywords'][i] + right_keywords = sheet['Right-Diagnostic Keywords'][i] + normal = sheet['N'][i] + diabetes = sheet['D'][i] + glaucoma = sheet['G'][i] + cataract = sheet['C'][i] + amd = sheet['A'][i] + hypertension = sheet['H'][i] + myopia = sheet['M'][i] + others = sheet['O'][i] + left_keywords = left_keywords.replace(",", "|") + right_keywords = right_keywords.replace(",", "|") + left_keywords = left_keywords.replace(",", "|") + right_keywords = right_keywords.replace(",", "|") + print(patient_id) + file_writer.writerow([patient_id, left_fundus, left_keywords, normal, diabetes, glaucoma, cataract, amd, hypertension, myopia, + others]) + file_writer.writerow([patient_id, right_fundus, right_keywords, normal, diabetes, glaucoma, cataract, amd, hypertension, myopia, + others]) + diff --git a/odir_eye_patient.py b/odir_eye_patient.py index 8708175..1935160 100644 --- a/odir_eye_patient.py +++ b/odir_eye_patient.py @@ -1,4 +1,4 @@ -# Copyright 2019 Jordi Corbilla. All Rights Reserved. +# Copyright 2019-2020 Jordi Corbilla. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
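As a companion to the runner, a minimal sketch (not part of the patch) of driving DataAugmentationStrategy directly for a single ground-truth row, mirroring the per-image call issued by process_files. The file name and one-hot label row are hypothetical, and the treated-image folder plus the ground_truth\odir_augmented.csv header (written by write_header) are assumed to exist already.

# Usage sketch with a hypothetical cataract row: [file, N, D, G, C, A, H, M, O].
from odir_data_augmentation_strategies import DataAugmentationStrategy

image_vector = ['2_right.jpg', 0, 0, 0, 1, 0, 0, 0, 0]
augment = DataAugmentationStrategy(224, image_vector[0])

# 9 samples per image with a cache weight of 1, as in the first cataract strategy.
generated = augment.generate_images(9, image_vector, 1)
print('images generated:', generated)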