Skip to content

Commit

Permalink
Merge pull request #15 from IvanKuchin/development
Browse files Browse the repository at this point in the history
save checkpoint on each epoch
  • Loading branch information
IvanKuchin committed Jul 25, 2024
2 parents 5f9c4a1 + 623353a commit 2362b38
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 11 deletions.
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
# https://radiopaedia.org/articles/windowing-ct?lang=us
# Option 2) 3D Slicer preset for abdominal CT
# W/L: 350/40, which makes the pancreas range from -310 to 390
PANCREAS_MIN_HU = -524
PANCREAS_MAX_HU = 1024
PANCREAS_MIN_HU = -310 # -512
PANCREAS_MAX_HU = 400 # 1024

IMAGE_DIMENSION_X = 160
IMAGE_DIMENSION_Y = IMAGE_DIMENSION_X
Expand Down
2 changes: 1 addition & 1 deletion dataset/pomc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def consistency_check(self, data, label, data_metadata, label_metadata):

if data.shape != label.shape:
print("ERROR: data shape(", data.shape, ") is not equal to the label shape(", label.shape, ")")
# return False
return False

if not self._point_inside_box(data_metadata["min"], data_metadata["max"], label_metadata["space origin"]):
print("ERROR: label space origin(", label_metadata["space origin"], ") is outside the data box(", data_metadata["min"], data_metadata["max"], ")")
Expand Down
24 changes: 18 additions & 6 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ def __read_dcm_slices(self, dcm_folder):
print("ERROR: no dcim files found")
return slices

# Stored Values (SV) are the values stored in the image pixel data attribute.
# Representation value should be calculated as:
# Rescaled value = SV * Rescale Slope + Rescale Intercept
# https://dicom.innolitics.com/ciods/digital-x-ray-image/dx-image/00281052
def __dcim_slice_stored_value_to_rescaled_value(self, slice):
rescale_intercept = slice.RescaleIntercept if hasattr(slice, "RescaleIntercept") else 0
rescale_slope = slice.RescaleSlope if hasattr(slice, "RescaleSlope") else 1
return slice.pixel_array * rescale_slope + rescale_intercept

def __get_pixel_data(self, dcm_slices):
result = np.array([])
if len(dcm_slices):
result = np.stack([_.pixel_array for _ in dcm_slices], axis = -1)
result = np.stack([self.__dcim_slice_stored_value_to_rescaled_value(_) for _ in dcm_slices], axis = -1)
else:
print("ERROR: dcim list is empty")

Expand Down Expand Up @@ -85,6 +94,12 @@ def __get_affine_matrix(self, dcm_slices):
affine[1, 1] = dcm_patient_orientation[4] * dcm_pixel_spacing[1]
affine[2, 1] = dcm_patient_orientation[5] * dcm_pixel_spacing[1]

affine[0, 3] = dcm_patient_position[0]
affine[1, 3] = dcm_patient_position[1]
affine[2, 3] = dcm_patient_position[2]

affine[2, 2] = dcm_slice_thickness

# --- inverse axes X and Y. This was found experimental way
# --- could be wrong ...
affine[0, 0] = -affine[0, 0]
Expand All @@ -95,11 +110,8 @@ def __get_affine_matrix(self, dcm_slices):
affine[1, 1] = -affine[1, 1]
affine[2, 1] = -affine[2, 1]

affine[2, 2] = dcm_slice_thickness

affine[0, 3] = dcm_patient_position[0]
affine[1, 3] = dcm_patient_position[1]
affine[2, 3] = dcm_patient_position[2]
affine[0, 3] = -affine[0, 3]
affine[1, 3] = -affine[1, 3]

return affine

Expand Down
9 changes: 7 additions & 2 deletions train_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,13 @@ def main():
model = craft_network(config.MODEL_CHECKPOINT)
# predict_on_random_data(model)

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(config.MODEL_CHECKPOINT, monitor = config.MONITOR_METRIC,
mode = config.MONITOR_MODE, verbose = 2, save_best_only = True)
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
config.MODEL_CHECKPOINT,
monitor = config.MONITOR_METRIC,
mode = config.MONITOR_MODE,
verbose = 2,
# save_best_only = True
)
csv_logger = tf.keras.callbacks.CSVLogger(get_csv_dir(), separator = ',', append = True)
tensorboard_cb = tf.keras.callbacks.TensorBoard(get_tensorboard_log_dir())
reduce_lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(factor = 0.1,
Expand Down

0 comments on commit 2362b38

Please sign in to comment.