
Commit 0ab7de6
Merge pull request #304 from ellisdg/dgx
Various updates and improvements
ellisdg authored Jul 5, 2022
2 parents 668728b + 5210046 commit 0ab7de6
Showing 10 changed files with 346 additions and 229 deletions.
5 changes: 4 additions & 1 deletion unet3d/models/pytorch/build.py
@@ -36,7 +36,10 @@ def build_or_load_model(model_name, model_filename, n_features, n_outputs, n_gpu
     elif n_gpus > 0:
         model = model.cuda()
     if os.path.exists(model_filename):
-        state_dict = torch.load(model_filename)
+        if n_gpus > 0:
+            state_dict = torch.load(model_filename)
+        else:
+            state_dict = torch.load(model_filename, map_location=torch.device('cpu'))
         model = load_state_dict(model, state_dict, n_gpus=n_gpus, strict=strict)
     return model

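Note on the change above: torch.load restores tensors to the devices they were saved on, so a checkpoint written on a GPU machine fails to load on a CPU-only host unless map_location redirects it. A minimal sketch of the same guard, using a hypothetical helper and checkpoint filename rather than anything from this repository:

import os
import torch

def load_checkpoint(filename, n_gpus):
    # Hypothetical helper mirroring the diff above: GPU hosts load the
    # checkpoint as saved; CPU-only hosts remap every tensor to the CPU.
    if n_gpus > 0:
        return torch.load(filename)
    return torch.load(filename, map_location=torch.device('cpu'))

if os.path.exists("model_best.pth"):  # hypothetical checkpoint name
    state_dict = load_checkpoint("model_best.pth", n_gpus=0)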
2 changes: 2 additions & 0 deletions unet3d/predict/__init__.py
@@ -0,0 +1,2 @@
+from .predict import *
+from .volumetric import volumetric_predictions
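This new package __init__ keeps the public import path stable after predict.py is split into a package: the wildcard pulls in everything from the moved module, and volumetric_predictions is re-exported explicitly. Assuming callers previously imported from unet3d.predict, both of the following continue to resolve:

# Both resolve through the new unet3d/predict/__init__.py:
from unet3d.predict import volumetric_predictions
from unet3d.predict.volumetric import volumetric_predictions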
235 changes: 11 additions & 224 deletions unet3d/predict.py → unet3d/predict/predict.py
@@ -1,16 +1,18 @@
 import os
 import time
 import numpy as np
 import nibabel as nib
 import pandas as pd
-from nilearn.image import resample_to_img, new_img_like
-from .utils.utils import (load_json, get_nibabel_data, one_hot_image_to_label_map,
-                          break_down_volume_into_half_size_volumes, combine_half_size_volumes)
-from .utils.sequences import SubjectPredictionSequence
-from .utils.pytorch.dataset import HCPSubjectDataset
-from .utils.hcp import new_cifti_scalar_like, get_metric_data
-from .utils.filenames import generate_hcp_filenames, load_subject_ids
-from .utils.augment import generate_permutation_keys, permute_data, reverse_permute_data
+from nilearn.image import new_img_like
+
+from unet3d.predict.volumetric import load_volumetric_model_and_dataset, load_images_from_dataset, \
+    prediction_to_image, write_prediction_image_to_file
+from unet3d.predict.utils import pytorch_predict_batch_array, get_feature_filename_and_subject_id, pytorch_predict_batch
+from unet3d.utils.utils import (load_json, get_nibabel_data, break_down_volume_into_half_size_volumes, combine_half_size_volumes)
+from unet3d.utils.sequences import SubjectPredictionSequence
+from unet3d.utils.pytorch.dataset import HCPSubjectDataset
+from unet3d.utils.hcp import new_cifti_scalar_like, get_metric_data
+from unet3d.utils.filenames import generate_hcp_filenames, load_subject_ids
+from unet3d.utils.augment import generate_permutation_keys, permute_data, reverse_permute_data


def predict_data_loader(model, data_loader):
@@ -190,43 +192,6 @@ def whole_brain_scalar_predictions(model_filename, subject_ids, hcp_dir, output_
raise ValueError("Predictions not yet implemented for {}".format(package))


-def volumetric_predictions(model_filename, filenames, prediction_dir, model_name, n_features, window,
-                           criterion_name, package="keras", n_gpus=1, n_workers=1, batch_size=1,
-                           model_kwargs=None, n_outputs=None, sequence_kwargs=None, sequence=None,
-                           metric_names=None, evaluate_predictions=False, interpolation="linear",
-                           resample_predictions=True, output_template=None, segmentation=False,
-                           segmentation_labels=None, threshold=0.5, sum_then_threshold=True, label_hierarchy=None,
-                           write_input_images=False):
-    if package == "pytorch":
-        pytorch_volumetric_predictions(model_filename=model_filename,
-                                       model_name=model_name,
-                                       n_outputs=n_outputs,
-                                       n_features=n_features,
-                                       filenames=filenames,
-                                       prediction_dir=prediction_dir,
-                                       window=window,
-                                       criterion_name=criterion_name,
-                                       n_gpus=n_gpus,
-                                       n_workers=n_workers,
-                                       batch_size=batch_size,
-                                       model_kwargs=model_kwargs,
-                                       sequence_kwargs=sequence_kwargs,
-                                       sequence=sequence,
-                                       metric_names=metric_names,
-                                       evaluate_predictions=evaluate_predictions,
-                                       interpolation=interpolation,
-                                       resample_predictions=resample_predictions,
-                                       output_template=output_template,
-                                       segmentation=segmentation,
-                                       segmentation_labels=segmentation_labels,
-                                       threshold=threshold,
-                                       sum_then_threshold=sum_then_threshold,
-                                       label_hierarchy=label_hierarchy,
-                                       write_input_images=write_input_images)
-    else:
-        raise ValueError("Predictions not yet implemented for {}".format(package))


def pytorch_whole_brain_scalar_predictions(model_filename, model_name, n_outputs, n_features, filenames, window,
                                           criterion_name, metric_names, surface_names, prediction_dir=None,
                                           output_csv=None, reference=None, n_gpus=1, n_workers=1, batch_size=1,
@@ -295,184 +260,6 @@ def pytorch_whole_brain_scalar_predictions(model_filename, model_name, n_outputs
pd.DataFrame(results, columns=columns).to_csv(output_csv)


-def load_volumetric_model(model_name, model_filename, n_outputs, n_features, n_gpus, strict, **kwargs):
-    from unet3d.models.pytorch.build import build_or_load_model
-    model = build_or_load_model(model_name=model_name, model_filename=model_filename, n_outputs=n_outputs,
-                                n_features=n_features, n_gpus=n_gpus, strict=strict, **kwargs)
-    model.eval()
-    return model


-def load_volumetric_sequence(sequence, sequence_kwargs, filenames, window, spacing, metric_names, batch_size=1):
-    from .utils.pytorch.dataset import AEDataset
-    if sequence is None:
-        sequence = AEDataset
-    if sequence_kwargs is None:
-        sequence_kwargs = dict()
-    dataset = sequence(filenames=filenames, window=window, spacing=spacing, batch_size=batch_size,
-                       metric_names=metric_names,
-                       **sequence_kwargs)
-    return dataset


-def load_volumetric_model_and_dataset(model_name, model_filename, model_kwargs, n_outputs, n_features,
-                                      strict_model_loading, n_gpus, sequence, sequence_kwargs, filenames, window,
-                                      spacing, metric_names):
-    if model_kwargs is None:
-        model_kwargs = dict()
-
-    model = load_volumetric_model(model_name=model_name, model_filename=model_filename, n_outputs=n_outputs,
-                                  n_features=n_features, strict=strict_model_loading, n_gpus=n_gpus, **model_kwargs)
-    dataset = load_volumetric_sequence(sequence, sequence_kwargs, filenames, window, spacing, metric_names,
-                                       batch_size=1)
-    basename = os.path.basename(model_filename).split(".")[0]
-    return model, dataset, basename


-def load_images_from_dataset(dataset, idx, resample_predictions):
-    if resample_predictions:
-        x_image, ref_image = dataset.get_feature_image(idx, return_unmodified=True)
-    else:
-        x_image = dataset.get_feature_image(idx)
-        ref_image = None
-    return x_image, ref_image


-def get_feature_filename_and_subject_id(dataset, idx, verbose=False):
-    epoch_filenames = dataset.epoch_filenames[idx]
-    x_filename = epoch_filenames[dataset.feature_index]
-    if verbose:
-        print("Reading:", x_filename)
-    subject_id = epoch_filenames[-1]
-    return x_filename, subject_id


-def pytorch_predict_batch(batch_x, model, n_gpus):
-    if n_gpus > 0:
-        batch_x = batch_x.cuda()
-    if hasattr(model, "test"):
-        pred_x = model.test(batch_x)
-    else:
-        pred_x = model(batch_x)
-    return pred_x.cpu()


-def prediction_to_image(data, input_image, reference_image=None, interpolation="linear", segmentation=False,
-                        segmentation_labels=None, threshold=0.5, sum_then_threshold=False, label_hierarchy=False):
-    if data.dtype == np.float16:
-        data = np.asarray(data, dtype=np.float32)
-    pred_image = new_img_like(input_image, data=data)
-    if reference_image is not None:
-        pred_image = resample_to_img(pred_image, reference_image,
-                                     interpolation=interpolation)
-    if segmentation:
-        pred_image = one_hot_image_to_label_map(pred_image,
-                                                labels=segmentation_labels,
-                                                threshold=threshold,
-                                                sum_then_threshold=sum_then_threshold,
-                                                label_hierarchy=label_hierarchy)
-    return pred_image


-def write_prediction_image_to_file(pred_image, output_template, subject_id, x_filename, prediction_dir, basename,
-                                   verbose=False):
-    if output_template is None:
-        while type(x_filename) == list:
-            x_filename = x_filename[0]
-        pred_filename = os.path.join(prediction_dir,
-                                     "_".join([subject_id,
-                                               basename,
-                                               os.path.basename(x_filename)]))
-    else:
-        pred_filename = os.path.join(prediction_dir,
-                                     output_template.format(subject=subject_id))
-    if verbose:
-        print("Writing:", pred_filename)
-    pred_image.to_filename(pred_filename)


-def pytorch_predict_batch_array(model, batch, n_gpus=1):
-    import torch
-    batch_x = torch.tensor(np.moveaxis(np.asarray(batch), -1, 1)).float()
-    pred_x = pytorch_predict_batch(batch_x, model, n_gpus)
-    return np.moveaxis(pred_x.numpy(), 1, -1)


-def predict_volumetric_batch(model, batch, batch_references, batch_subjects, batch_filenames,
-                             basename, prediction_dir,
-                             segmentation, output_template, n_gpus, verbose, threshold, interpolation,
-                             segmentation_labels, sum_then_threshold, label_hierarchy, write_input_image=False):
-    pred_x = pytorch_predict_batch_array(model, batch, n_gpus=n_gpus)
-    for batch_idx in range(len(batch)):
-        pred_image = prediction_to_image(pred_x[batch_idx].squeeze(), input_image=batch_references[batch_idx][0],
-                                         reference_image=batch_references[batch_idx][1], interpolation=interpolation,
-                                         segmentation=segmentation, segmentation_labels=segmentation_labels,
-                                         threshold=threshold, sum_then_threshold=sum_then_threshold,
-                                         label_hierarchy=label_hierarchy)
-        write_prediction_image_to_file(pred_image, output_template,
-                                       subject_id=batch_subjects[batch_idx],
-                                       x_filename=batch_filenames[batch_idx],
-                                       prediction_dir=prediction_dir,
-                                       basename=basename,
-                                       verbose=verbose)
-        if write_input_image:
-            write_prediction_image_to_file(batch_references[batch_idx][0], output_template=output_template,
-                                           subject_id=batch_subjects[batch_idx] + "_input",
-                                           x_filename=batch_filenames[batch_idx],
-                                           prediction_dir=prediction_dir,
-                                           basename=basename,
-                                           verbose=verbose)


-def pytorch_volumetric_predictions(model_filename, model_name, n_features, filenames, window,
-                                   criterion_name, prediction_dir=None, output_csv=None, reference=None,
-                                   n_gpus=1, n_workers=1, batch_size=1, model_kwargs=None, n_outputs=None,
-                                   sequence_kwargs=None, spacing=None, sequence=None,
-                                   strict_model_loading=True, metric_names=None,
-                                   print_prediction_time=True, verbose=True,
-                                   evaluate_predictions=False, resample_predictions=False, interpolation="linear",
-                                   output_template=None, segmentation=False, segmentation_labels=None,
-                                   sum_then_threshold=True, threshold=0.7, label_hierarchy=None,
-                                   write_input_images=False):
-    import torch
-    # from .train.pytorch import load_criterion
-
-    model, dataset, basename = load_volumetric_model_and_dataset(model_name, model_filename, model_kwargs, n_outputs,
-                                                                 n_features, strict_model_loading, n_gpus, sequence,
-                                                                 sequence_kwargs, filenames, window, spacing,
-                                                                 metric_names)
-
-    # criterion = load_criterion(criterion_name, n_gpus=n_gpus)
-    results = list()
-    print("Dataset: ", len(dataset))
-    with torch.no_grad():
-        batch = list()
-        batch_references = list()
-        batch_subjects = list()
-        batch_filenames = list()
-        for idx in range(len(dataset)):
-            x_filename, subject_id = get_feature_filename_and_subject_id(dataset, idx, verbose=verbose)
-            x_image, ref_image = load_images_from_dataset(dataset, idx, resample_predictions)
-
-            batch.append(get_nibabel_data(x_image))
-            batch_references.append((x_image, ref_image))
-            batch_subjects.append(subject_id)
-            batch_filenames.append(x_filename)
-            if len(batch) >= batch_size or idx == (len(dataset) - 1):
-                predict_volumetric_batch(model=model, batch=batch, batch_references=batch_references,
-                                         batch_subjects=batch_subjects, batch_filenames=batch_filenames,
-                                         basename=basename, prediction_dir=prediction_dir,
-                                         segmentation=segmentation, output_template=output_template, n_gpus=n_gpus,
-                                         verbose=verbose, threshold=threshold, interpolation=interpolation,
-                                         segmentation_labels=segmentation_labels,
-                                         sum_then_threshold=sum_then_threshold, label_hierarchy=label_hierarchy,
-                                         write_input_image=write_input_images)
-                batch = list()
-                batch_references = list()
-                batch_subjects = list()
-                batch_filenames = list()


def save_predictions(prediction, args, basename, metric_names, surface_names, prediction_dir):
    ref_filename = args[2][0]
    subject_id = args[-1]
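The removed pytorch_volumetric_predictions (which, judging by the new imports, now lives under unet3d/predict/) accumulates cases into a list and flushes whenever the list reaches batch_size or the dataset is exhausted, so the trailing partial batch is never dropped. A stripped-down sketch of that accumulate-and-flush pattern, with hypothetical stand-ins for the dataset and the batch handler:

def process_batch(batch):
    # Stand-in for predict_volumetric_batch.
    print("Flushing batch of size", len(batch))

items = list(range(10))  # stand-in for the dataset
batch_size = 4
batch = list()
for idx, item in enumerate(items):
    batch.append(item)
    # Flush when the batch is full or we are on the final item.
    if len(batch) >= batch_size or idx == (len(items) - 1):
        process_batch(batch)
        batch = list()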
27 changes: 27 additions & 0 deletions unet3d/predict/utils.py
@@ -0,0 +1,27 @@
+import numpy as np
+
+
+def pytorch_predict_batch_array(model, batch, n_gpus=1):
+    import torch
+    batch_x = torch.tensor(np.moveaxis(np.asarray(batch), -1, 1)).float()
+    pred_x = pytorch_predict_batch(batch_x, model, n_gpus)
+    return np.moveaxis(pred_x.numpy(), 1, -1)
+
+
+def get_feature_filename_and_subject_id(dataset, idx, verbose=False):
+    epoch_filenames = dataset.epoch_filenames[idx]
+    x_filename = epoch_filenames[dataset.feature_index]
+    if verbose:
+        print("Reading:", x_filename)
+    subject_id = epoch_filenames[-1]
+    return x_filename, subject_id
+
+
+def pytorch_predict_batch(batch_x, model, n_gpus):
+    if n_gpus > 0:
+        batch_x = batch_x.cuda()
+    if hasattr(model, "test"):
+        pred_x = model.test(batch_x)
+    else:
+        pred_x = model(batch_x)
+    return pred_x.cpu()
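For orientation: pytorch_predict_batch_array takes a channels-last array batch, moves the channel axis to position 1 as PyTorch expects, runs the model, and moves the axis back. A self-contained sketch with n_gpus=0 and a stand-in identity model (not part of the repository):

import numpy as np
import torch

model = torch.nn.Identity()  # hypothetical stand-in for a trained network

# Batch of two volumes, shape (batch, x, y, z, channels), channels last.
batch = np.zeros((2, 8, 8, 8, 1), dtype=np.float32)

batch_x = torch.tensor(np.moveaxis(np.asarray(batch), -1, 1)).float()
with torch.no_grad():
    pred_x = model(batch_x)
pred = np.moveaxis(pred_x.cpu().numpy(), 1, -1)
print(pred.shape)  # (2, 8, 8, 8, 1)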
