Skip to content

Commit

Permalink
Merge branch 'dev' into add-trt-support-for-bundles
Browse files Browse the repository at this point in the history
  • Loading branch information
yiheng-wang-nv authored Sep 6, 2023
2 parents 33f9a04 + 6b55743 commit 396b47a
Show file tree
Hide file tree
Showing 17 changed files with 1,366 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ ci:
autoupdate_schedule: quarterly
# submodules: true

exclude: '.*\.ipynb$'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
Expand Down
2 changes: 2 additions & 0 deletions ci/run_premerge_cpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ remove_pipenv() {
echo "removing pip environment"
pipenv --rm
rm Pipfile Pipfile.lock
pipenv --clear
df -h
}

verify_bundle() {
Expand Down
108 changes: 108 additions & 0 deletions ci/unit_tests/test_endoscopic_tool_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest

import numpy as np
from monai.bundle import ConfigWorkflow
from monai.data import PILWriter
from parameterized import parameterized
from utils import check_workflow

TEST_CASE_1 = [ # train, evaluate
{
"bundle_root": "models/endoscopic_tool_segmentation",
"train#trainer#max_epochs": 2,
"train#dataloader#num_workers": 1,
"validate#dataloader#num_workers": 1,
"train#deterministic_transforms#3#spatial_size": [32, 32],
}
]

TEST_CASE_2 = [ # inference
{
"bundle_root": "models/endoscopic_tool_segmentation",
"handlers#0#_disabled_": True,
"preprocessing#transforms#2#spatial_size": [32, 32],
}
]


class TestEndoscopicSeg(unittest.TestCase):
def setUp(self):
self.dataset_dir = tempfile.mkdtemp()
dataset_size = 10
writer = PILWriter(np.uint8)
shape = (736, 480)
for mode in ["train", "val", "test"]:
for sub_folder in ["inbody", "outbody"]:
sample_dir = os.path.join(self.dataset_dir, f"{mode}/{sub_folder}")
os.makedirs(sample_dir)
for s in range(dataset_size):
image = np.random.randint(low=0, high=5, size=(3, *shape)).astype(np.int8)
image_filename = os.path.join(sample_dir, f"{sub_folder}_{s}.jpg")
writer.set_data_array(image, channel_dim=0)
writer.write(image_filename, verbose=True)
if mode != "test":
label = np.random.randint(low=0, high=5, size=shape).astype(np.int8)
label_filename = os.path.join(sample_dir, f"{sub_folder}_{s}_seg.jpg")
writer.set_data_array(label, channel_dim=None)
writer.write(label_filename, verbose=True)

def tearDown(self):
shutil.rmtree(self.dataset_dir)

@parameterized.expand([TEST_CASE_1])
def test_train_eval_config(self, override):
override["dataset_dir"] = self.dataset_dir
bundle_root = override["bundle_root"]
train_file = os.path.join(bundle_root, "configs/train.json")
eval_file = os.path.join(bundle_root, "configs/evaluate.json")

trainer = ConfigWorkflow(
workflow="train",
config_file=train_file,
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
**override,
)
check_workflow(trainer, check_properties=True)

validator = ConfigWorkflow(
# override train.json, thus set the workflow to "train" rather than "eval"
workflow="train",
config_file=[train_file, eval_file],
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
**override,
)
check_workflow(validator, check_properties=True)

@parameterized.expand([TEST_CASE_2])
def test_infer_config(self, override):
override["dataset_dir"] = self.dataset_dir
bundle_root = override["bundle_root"]

inferrer = ConfigWorkflow(
workflow="infer",
config_file=os.path.join(bundle_root, "configs/inference.json"),
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
**override,
)
check_workflow(inferrer, check_properties=True)


if __name__ == "__main__":
unittest.main()
119 changes: 119 additions & 0 deletions ci/unit_tests/test_endoscopic_tool_segmentation_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest

import numpy as np
import torch
from monai.data import PILWriter
from parameterized import parameterized
from utils import export_config_and_run_mgpu_cmd

TEST_CASE_1 = [
{
"bundle_root": "models/endoscopic_tool_segmentation",
"train#trainer#max_epochs": 1,
"train#dataloader#num_workers": 1,
"validate#dataloader#num_workers": 1,
"train#deterministic_transforms#3#spatial_size": [32, 32],
}
]

TEST_CASE_2 = [
{
"bundle_root": "models/endoscopic_tool_segmentation",
"validate#dataloader#num_workers": 4,
"train#deterministic_transforms#3#spatial_size": [32, 32],
}
]


def test_order(test_name1, test_name2):
def get_order(name):
if "train" in name:
return 1
if "eval" in name:
return 2
if "infer" in name:
return 3
return 4

return get_order(test_name1) - get_order(test_name2)


class TestEndoscopicSegMGPU(unittest.TestCase):
def setUp(self):
self.dataset_dir = tempfile.mkdtemp()
dataset_size = 10
writer = PILWriter(np.uint8)
shape = (3, 256, 256)
for mode in ["train", "val"]:
for sub_folder in ["inbody", "outbody"]:
sample_dir = os.path.join(self.dataset_dir, f"{mode}/{sub_folder}")
os.makedirs(sample_dir)
for s in range(dataset_size):
image = np.random.randint(low=0, high=5, size=(3, *shape)).astype(np.int8)
image_filename = os.path.join(sample_dir, f"{sub_folder}_{s}.jpg")
writer.set_data_array(image, channel_dim=0)
writer.write(image_filename, verbose=True)
label = np.random.randint(low=0, high=5, size=shape).astype(np.int8)
label_filename = os.path.join(sample_dir, f"{sub_folder}_{s}_seg.jpg")
writer.set_data_array(label, channel_dim=None)
writer.write(label_filename, verbose=True)

def tearDown(self):
shutil.rmtree(self.dataset_dir)

@parameterized.expand([TEST_CASE_1])
def test_train_mgpu_config(self, override):
override["dataset_dir"] = self.dataset_dir
bundle_root = override["bundle_root"]
train_file = os.path.join(bundle_root, "configs/train.json")
mgpu_train_file = os.path.join(bundle_root, "configs/multi_gpu_train.json")
output_path = os.path.join(bundle_root, "configs/train_override.json")
n_gpu = torch.cuda.device_count()
export_config_and_run_mgpu_cmd(
config_file=[train_file, mgpu_train_file],
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
override_dict=override,
output_path=output_path,
ngpu=n_gpu,
check_config=True,
)

@parameterized.expand([TEST_CASE_2])
def test_evaluate_mgpu_config(self, override):
override["dataset_dir"] = self.dataset_dir
bundle_root = override["bundle_root"]
train_file = os.path.join(bundle_root, "configs/train.json")
evaluate_file = os.path.join(bundle_root, "configs/evaluate.json")
mgpu_evaluate_file = os.path.join(bundle_root, "configs/multi_gpu_evaluate.json")
output_path = os.path.join(bundle_root, "configs/evaluate_override.json")
n_gpu = torch.cuda.device_count()
export_config_and_run_mgpu_cmd(
config_file=[train_file, evaluate_file, mgpu_evaluate_file],
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
override_dict=override,
output_path=output_path,
ngpu=n_gpu,
check_config=True,
)


if __name__ == "__main__":
loader = unittest.TestLoader()
loader.sortTestMethodsUsing = test_order
unittest.main(testLoader=loader)
126 changes: 126 additions & 0 deletions ci/unit_tests/test_lung_nodule_ct_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import sys
import tempfile
import unittest

import nibabel as nib
import numpy as np
from monai.bundle import ConfigWorkflow
from monai.data import create_test_image_3d
from monai.utils import set_determinism
from parameterized import parameterized
from utils import check_workflow

set_determinism(123)

TEST_CASE_1 = [ # train, evaluate
{
"bundle_root": "models/lung_nodule_ct_detection",
"epochs": 3,
"batch_size": 1,
"val_interval": 2,
"train#dataloader#num_workers": 1,
"validate#dataloader#num_workers": 1,
}
]

TEST_CASE_2 = [{"bundle_root": "models/lung_nodule_ct_detection"}] # inference


def test_order(test_name1, test_name2):
def get_order(name):
if "train" in name:
return 1
if "eval" in name:
return 2
if "infer" in name:
return 3
return 4

return get_order(test_name1) - get_order(test_name2)


class TestLungNoduleDetection(unittest.TestCase):
def setUp(self):
self.dataset_dir = tempfile.mkdtemp()
dataset_size = 3
train_patch_size = (192, 192, 80)
dataset_json = {}

img, _ = create_test_image_3d(train_patch_size[0], train_patch_size[1], train_patch_size[2], 2)
image_filename = os.path.join(self.dataset_dir, "image.nii.gz")
nib.save(nib.Nifti1Image(img, np.eye(4)), image_filename)
label = [0, 0]
box = [[108, 119, 131, 142, 26, 37], [132, 147, 149, 164, 25, 40]]
data = {"box": box, "image": image_filename, "label": label}
dataset_json["training"] = [data for _ in range(dataset_size)]
dataset_json["validation"] = [data for _ in range(dataset_size)]

self.ds_file = os.path.join(self.dataset_dir, "dataset.json")
with open(self.ds_file, "w") as fp:
json.dump(dataset_json, fp, indent=2)

def tearDown(self):
shutil.rmtree(self.dataset_dir)

@parameterized.expand([TEST_CASE_1])
def test_train_eval_config(self, override):
override["dataset_dir"] = self.dataset_dir
override["data_list_file_path"] = self.ds_file
bundle_root = override["bundle_root"]
train_file = os.path.join(bundle_root, "configs/train.json")
eval_file = os.path.join(bundle_root, "configs/evaluate.json")

sys.path.append(bundle_root)
trainer = ConfigWorkflow(
workflow="train",
config_file=train_file,
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
**override,
)
check_workflow(trainer, check_properties=False)

validator = ConfigWorkflow(
# override train.json, thus set the workflow to "train" rather than "eval"
workflow="train",
config_file=[train_file, eval_file],
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
**override,
)
check_workflow(validator, check_properties=False)

@parameterized.expand([TEST_CASE_2])
def test_infer_config(self, override):
override["dataset_dir"] = self.dataset_dir
override["data_list_file_path"] = self.ds_file
bundle_root = override["bundle_root"]

inferrer = ConfigWorkflow(
workflow="infer",
config_file=os.path.join(bundle_root, "configs/inference.json"),
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
**override,
)
check_workflow(inferrer, check_properties=True)


if __name__ == "__main__":
loader = unittest.TestLoader()
loader.sortTestMethodsUsing = test_order
unittest.main(testLoader=loader)
Loading

0 comments on commit 396b47a

Please sign in to comment.