Merge pull request #300 from ellisdg/dgx
FIX: Updates Predict and Train Scripts

closes #298
ellisdg authored Mar 8, 2022
2 parents 92def62 + bc7ebbf commit 668728b
Showing 7 changed files with 38 additions and 24 deletions.
2 changes: 2 additions & 0 deletions unet3d/predict.py
@@ -359,6 +359,8 @@ def pytorch_predict_batch(batch_x, model, n_gpus):

def prediction_to_image(data, input_image, reference_image=None, interpolation="linear", segmentation=False,
segmentation_labels=None, threshold=0.5, sum_then_threshold=False, label_hierarchy=False):
if data.dtype == np.float16:
data = np.asarray(data, dtype=np.float32)
pred_image = new_img_like(input_image, data=data)
if reference_image is not None:
pred_image = resample_to_img(pred_image, reference_image,
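The added guard casts half-precision predictions (as produced by mixed-precision inference) to float32 before the NIfTI image is built, because the NIfTI-1 format defines no 16-bit float type. A minimal sketch of the same idea using nilearn's new_img_like; the wrapper name prediction_volume is illustrative, not part of the repository:

    import numpy as np
    from nilearn.image import new_img_like

    def prediction_volume(data, input_image):
        # NIfTI-1 has no 16-bit float type, so half-precision output
        # (e.g. from AMP inference) is promoted to float32 before wrapping
        # it in an image that reuses input_image's affine and header.
        if data.dtype == np.float16:
            data = np.asarray(data, dtype=np.float32)
        return new_img_like(input_image, data=data)
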
20 changes: 13 additions & 7 deletions unet3d/scripts/predict.py
@@ -56,7 +56,7 @@ def run_inference(namespace):
config = load_json(namespace.config_filename)
key = namespace.group + "_filenames"

machine_config = get_machine_config(namespace)
system_config = get_machine_config(namespace)

if namespace.filenames:
filenames = list()
@@ -77,10 +77,18 @@ def run_inference(namespace):
in
config["generate_filenames_kwargs"][_key]]
if namespace.directory_template is not None:
machine_config["directory"] = namespace.directory_template
directory = namespace.directory_template
elif "directory" in system_config and system_config["directory"]:
directory = system_config["directory"]
elif "directory" in config:
directory = config["directory"]
else:
directory = ""
if namespace.subjects_config_filename:
config[namespace.group] = load_json(namespace.subjects_config_filename)[namespace.group]
filenames = generate_filenames(config, namespace.group, machine_config,
else:
load_subject_ids(config, namespace.group)
filenames = generate_filenames(config, namespace.group, directory,
skip_targets=(not namespace.eval))

else:
@@ -93,8 +101,6 @@ def run_inference(namespace):
if not os.path.exists(namespace.output_directory):
os.makedirs(namespace.output_directory)

load_subject_ids(config)

if "evaluation_metric" in config and config["evaluation_metric"] is not None:
criterion_name = config['evaluation_metric']
else:
@@ -177,9 +183,9 @@ def run_inference(namespace):
window=config["window"],
criterion_name=criterion_name,
package=config['package'],
n_gpus=machine_config['n_gpus'],
n_gpus=system_config['n_gpus'],
batch_size=config['validation_batch_size'],
n_workers=machine_config["n_workers"],
n_workers=system_config["n_workers"],
model_kwargs=model_kwargs,
sequence_kwargs=sequence_kwargs,
sequence=sequence,
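The new block fixes the precedence for the data directory: the --directory_template argument wins, then the machine/system configuration, then the main config, and finally an empty string so templated paths are used as given. A condensed sketch of that precedence; resolve_directory is an illustrative helper name, not a function in the repository:

    import argparse

    def resolve_directory(namespace, system_config, config):
        # Same order as the diff above: CLI template, machine/system config,
        # main config, then an empty string.
        if namespace.directory_template is not None:
            return namespace.directory_template
        if system_config.get("directory"):
            return system_config["directory"]
        if "directory" in config:
            return config["directory"]
        return ""

    ns = argparse.Namespace(directory_template=None)
    print(resolve_directory(ns, {"directory": "/data/machine"}, {"directory": "/data/config"}))
    # -> /data/machine
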
14 changes: 8 additions & 6 deletions unet3d/scripts/train.py
@@ -150,16 +150,18 @@ def main():
else:
groups = ("training", "validation")

for name in groups:
key = name + "_filenames"
if key not in config:
config[key] = generate_filenames(config, name, system_config,
raise_if_not_exists=namespace.debug)
if "directory" in system_config:
directory = system_config.pop("directory")
elif "directory" in config:
directory = config["directory"]
else:
directory = "."
directory = ""

for name in groups:
key = name + "_filenames"
if key not in config:
config[key] = generate_filenames(config, name, directory,
raise_if_not_exists=namespace.debug)
if "sequence" in config:
sequence_class = load_sequence(config["sequence"])
elif "_wb_" in os.path.basename(namespace.config_filename):
10 changes: 5 additions & 5 deletions unet3d/train/pytorch.py
@@ -23,7 +23,7 @@ def run_pytorch_training(config, model_filename, training_log_filename, verbose=
n_workers=1, max_queue_size=5, model_name='resnet_34', n_gpus=1, regularized=False,
sequence_class=WholeBrainCIFTI2DenseScalarDataset, directory=None, test_input=1,
metric_to_monitor="loss", model_metrics=(), bias=None, pin_memory=False, amp=False,
**unused_args):
prefetch_factor=1, **unused_args):
"""
:param test_input: integer with the number of inputs from the generator to write to file. 0, False, or None will
write no inputs to file.
@@ -110,7 +110,8 @@ def run_pytorch_training(config, model_filename, training_log_filename, verbose=
shuffle=True,
num_workers=n_workers,
collate_fn=collate_fn,
pin_memory=pin_memory)
pin_memory=pin_memory,
prefetch_factor=prefetch_factor)

if test_input:
for index in range(test_input):
@@ -148,7 +149,8 @@ def run_pytorch_training(config, model_filename, training_log_filename, verbose=
shuffle=False,
num_workers=n_workers,
collate_fn=collate_fn,
pin_memory=pin_memory)
pin_memory=pin_memory,
prefetch_factor=prefetch_factor)

train(model=model, optimizer=optimizer, criterion=criterion, n_epochs=config["n_epochs"], verbose=bool(verbose),
training_loader=training_loader, validation_loader=validation_loader, model_filename=model_filename,
@@ -205,7 +207,6 @@ def train(model, optimizer, criterion, n_epochs, training_loader, validation_loa
scaler = None

for epoch in range(start_epoch, n_epochs):
print("save_last_n_models", save_last_n_models)
# early stopping
if (training_log and early_stopping_patience
and np.asarray(training_log)[:, training_log_header.index(metric_to_monitor)].argmin()
@@ -214,7 +215,6 @@ def train(model, optimizer, criterion, n_epochs, training_loader, validation_loa
break

# train the model
print("n gpus:", n_gpus)
loss = epoch_training(training_loader, model, criterion, optimizer=optimizer, epoch=epoch, n_gpus=n_gpus,
regularized=regularized, vae=vae, scaler=scaler)
try:
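Both the training and validation loaders now forward a prefetch_factor argument to torch.utils.data.DataLoader. prefetch_factor sets how many batches each worker loads ahead of time and only takes effect when num_workers > 0. A minimal sketch with a toy dataset (shapes and values are illustrative):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(16, 1, 8, 8, 8), torch.randn(16, 1, 8, 8, 8))
    loader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=2,       # prefetch_factor needs at least one worker
                        pin_memory=True,     # page-locked buffers for faster GPU copies
                        prefetch_factor=1)   # batches preloaded per worker (PyTorch default is 2)

    for x, y in loader:
        pass  # one pass over the toy data
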
2 changes: 1 addition & 1 deletion unet3d/train/pytorch_training_utils.py
@@ -116,7 +116,7 @@ def epoch_validatation(val_loader, model, criterion, n_gpus, print_freq=1, regul
progress = ProgressMeter(
len(val_loader),
[batch_time, losses],
prefix='Test: ')
prefix='Validation: ')

# switch to evaluate mode
model.eval()
6 changes: 3 additions & 3 deletions unet3d/utils/filenames.py
@@ -144,19 +144,19 @@ def generate_filenames_from_multisource_templates(subject_ids, feature_templates
return filenames


def generate_filenames(config, name, system_config, skip_targets=False, raise_if_not_exists=False):
def generate_filenames(config, name, directory="", skip_targets=False, raise_if_not_exists=False):
if name not in config:
load_subject_ids(config, name)
if "generate_filenames" not in config or config["generate_filenames"] == "classic":
return generate_hcp_filenames(in_config('directory', system_config, ""),
return generate_hcp_filenames(directory,
config['surface_basename_template']
if "surface_basename_template" in config else None,
config['target_basenames'],
config['feature_basenames'],
config[name],
config['hemispheres'] if 'hemispheres' in config else None)
elif config["generate_filenames"] == "paired":
return generate_paired_filenames(in_config('directory', system_config, ""),
return generate_paired_filenames(directory,
config[name],
name,
raise_if_not_exists=raise_if_not_exists,
8 changes: 6 additions & 2 deletions unet3d/utils/sequences.py
@@ -616,16 +616,20 @@ def __init__(self, *args, target_interpolation="nearest", target_index=2, labels
def resample_input(self, input_filenames):
input_image, target_image = self.resample_image(input_filenames)
target_data = get_nibabel_data(target_image)
if self.labels is None:
self.labels = np.asarray(np.unique(target_data), dtype=int)
assert len(target_data.shape) == 4
if target_data.shape[3] == 1:
if self.labels is None:
self.labels = np.asarray(np.unique(target_data)[1:], dtype=int)
target_data = np.moveaxis(
compile_one_hot_encoding(np.moveaxis(target_data, 3, 0),
n_labels=len(self.labels),
labels=self.labels,
return_4d=True), 0, 3)
else:
if self.labels is None:
self.labels = np.asarray([np.unique(target_data[:, :, :, channel])[1:]
for channel in np.arange(target_data.shape[self.channel_axis])],
dtype=int)
_target_data = list()
for channel, labels in zip(range(target_data.shape[self.channel_axis]), self.labels):
if type(labels) != list:
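Label inference now depends on the target's channel count: a single-channel label map takes its non-background unique values as the label set, while a multi-channel target gets one label list per channel. A small NumPy sketch of the per-channel case (compile_one_hot_encoding itself is the repository's helper and is not reproduced here):

    import numpy as np

    # toy 4D target: (x, y, z, channels) with two label channels
    target_data = np.zeros((4, 4, 4, 2), dtype=int)
    target_data[0, 0, 0, 0] = 1
    target_data[1, 1, 1, 1] = 7

    if target_data.shape[3] == 1:
        # single label map: keep every value except the first (background)
        labels = np.asarray(np.unique(target_data)[1:], dtype=int)
    else:
        # one label list per channel, again dropping the smallest (background) value
        labels = [np.asarray(np.unique(target_data[..., channel])[1:], dtype=int)
                  for channel in range(target_data.shape[3])]

    print(labels)  # [array([1]), array([7])]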
