Skip to content

Commit

Permalink
Update vista3d to use MONAI components (Project-MONAI#627)
Browse files Browse the repository at this point in the history
Fixes # .

### Description
A few sentences describing the changes proposed in this pull request.

### Status
**Ready/Work in progress/Hold**

### Please ensure all the checkboxes:
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Codeformat tests passed locally by running `./runtests.sh
--codeformat`.
- [ ] In-line docstrings updated.
- [ ] Update `version` and `changelog` in `metadata.json` if changing an
existing bundle.
- [ ] Please ensure the naming rules in config files meet our
requirements (please refer to: `CONTRIBUTING.md`).
- [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy`
are correct in `metadata.json`.
- [ ] Descriptions should be consistent with the content, such as
`eval_metrics` of the provided weights and TorchScript modules.
- [ ] Files larger than 25MB are excluded and replaced by providing
download links in `large_file.yml`.
- [ ] Avoid using path that contains personal information within config
files (such as use `/home/your_name/` for `"bundle_root"`).

---------

Signed-off-by: Yiheng Wang <vennw@nvidia.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
yiheng-wang-nv and KumoLiu authored Aug 26, 2024
1 parent a87ec66 commit db0f350
Show file tree
Hide file tree
Showing 22 changed files with 33 additions and 2,475 deletions.
2 changes: 2 additions & 0 deletions ci/verify_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def check_properties(**kwargs):
"""
app_properties_path = kwargs.get("properties_path", "")
kwargs.pop("properties_path", None)
print(kwargs)

workflow = create_workflow(**kwargs)
if app_properties_path is not None and os.path.isfile(app_properties_path):
shutil.copy(app_properties_path, "ci/bundle_properties.py")
Expand Down
4 changes: 2 additions & 2 deletions models/vista3d/configs/evaluate.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
]
},
{
"_target_": "scripts.monai_trans_utils.RelabelD",
"_target_": "monai.apps.vista3d.transforms.Relabeld",
"keys": "label",
"label_mappings": "@label_mappings",
"dtype": "$torch.uint8"
Expand All @@ -113,7 +113,7 @@
"sigmoid": true
},
{
"_target_": "scripts.monai_trans_utils.VistaPostTransform",
"_target_": "monai.apps.vista3d.transforms.VistaPostTransformd",
"keys": "pred"
},
{
Expand Down
16 changes: 11 additions & 5 deletions models/vista3d/configs/inference.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"$import glob",
"$import os",
"$import scripts",
"$import numpy as np"
"$import numpy as np",
"$import json"
],
"bundle_root": "./",
"image_key": "image",
Expand All @@ -14,6 +15,11 @@
"separate_folder": true,
"input_dict": "${'image': '/data/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label_prompt': [3]}",
"everything_labels": "$list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))",
"metadata_path": "$@bundle_root + '/configs/metadata.json'",
"metadata_file": "$open(@metadata_path,'r', encoding='utf8')",
"metadata": "$json.load(@metadata_file)",
"close_metadata_file": "$metadata_file.close()",
"labels_dict": "$@metadata['network_data_format']['outputs']['pred']['channel_def']",
"subclass": {
"2": [
14,
Expand Down Expand Up @@ -43,7 +49,7 @@
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"use_cfp": false,
"use_point_window": true,
"network_def": "$scripts.vista3d.vista_model_registry['vista3d_segresnet_d'](in_channels=@input_channels, image_size=@patch_size)",
"network_def": "$monai.networks.nets.vista3d132(in_channels=@input_channels)",
"network": "$@network_def.to(@device)",
"preprocessing_transforms": [
{
Expand Down Expand Up @@ -75,10 +81,10 @@
"source_key": "@image_key"
},
{
"_target_": "scripts.monai_trans_utils.VistaPreTransform",
"_target_": "monai.apps.vista3d.transforms.VistaPreTransformd",
"keys": "@image_key",
"subclass": "@subclass",
"bundle_root": "@bundle_root"
"labels_dict": "@labels_dict"
},
{
"_target_": "ScaleIntensityRanged",
Expand Down Expand Up @@ -134,7 +140,7 @@
"_disabled_": true
},
{
"_target_": "scripts.monai_trans_utils.VistaPostTransform",
"_target_": "monai.apps.vista3d.transforms.VistaPostTransformd",
"keys": "pred"
},
{
Expand Down
15 changes: 8 additions & 7 deletions models/vista3d/configs/metadata.json
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
"version": "0.4.1",
"version": "0.4.2",
"changelog": {
"0.4.2": "use MONAI components for network and utils",
"0.4.1": "initial OSS version"
},
"monai_version": "1.3.1",
"pytorch_version": "2.2.2",
"monai_version": "1.4.0",
"pytorch_version": "2.4.0",
"numpy_version": "1.24.4",
"matplotlib": "3.8.3",
"einops": "0.7.0",
"required_packages_version": {
"scikit-image": "0.22.0",
"matplotlib": "3.9.1",
"einops": "0.7.0",
"scikit-image": "0.23.2",
"nibabel": "5.2.1",
"pytorch-ignite": "0.4.11",
"cucim": "23.08.00"
"cucim-cu12": "24.6.0"
},
"supported_apps": {
"vista3d-nim": ""
Expand Down
2 changes: 1 addition & 1 deletion models/vista3d/configs/train.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
128
],
"patch_size_valid": "$@patch_size",
"network_def": "$scripts.vista3d.vista_model_registry['vista3d_segresnet_d'](in_channels=@input_channels, image_size=@patch_size)",
"network_def": "$monai.networks.nets.vista3d132(in_channels=@input_channels)",
"network": "$@network_def.to(@device)",
"loss": {
"_target_": "DiceCELoss",
Expand Down
4 changes: 2 additions & 2 deletions models/vista3d/configs/train_continual.json
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"warn": false
},
{
"_target_": "scripts.monai_trans_utils.RelabelD",
"_target_": "monai.apps.vista3d.transforms.Relabeld",
"keys": "label",
"label_mappings": "@label_mappings",
"dtype": "$torch.uint8"
Expand All @@ -101,7 +101,7 @@
},
"validate#preprocessing#transforms": "$@train#deterministic_transforms + [@valid_remap]",
"valid_remap": {
"_target_": "scripts.monai_trans_utils.RelabelD",
"_target_": "monai.apps.vista3d.transforms.Relabeld",
"keys": "label",
"label_mappings": "${'default': [[c, i] for i, c in enumerate(@val_label_set)]}",
"dtype": "$torch.uint8"
Expand Down
2 changes: 1 addition & 1 deletion models/vista3d/large_files.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
large_files:
- path: "models/model.pt"
url: "https://drive.google.com/file/d/1eLIxQwnxGsjggxiVjdcAyNvJ5DYtqmdc/view?usp=sharing"
url: "https://drive.google.com/file/d/1Sbe6GjlgH-GIcXolZzUiwgqR4DBYNLQ3/view?usp=drive_link"
1 change: 0 additions & 1 deletion models/vista3d/scripts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@
# from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
# from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer

from . import vista3d
from .early_stop_score_function import score_function
4 changes: 2 additions & 2 deletions models/vista3d/scripts/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
label_set = np.arange(output_classes).tolist()
label_prompt = torch.tensor(label_set).to(engine.state.device).unsqueeze(-1)
# point prompt is generated withing vista3d,provide empty points
points = torch.zeros(label_prompt.shape[0], 1, 3)
point_labels = -1 + torch.zeros(label_prompt.shape[0], 1)
points = torch.zeros(label_prompt.shape[0], 1, 3).to(inputs.device)
point_labels = -1 + torch.zeros(label_prompt.shape[0], 1).to(inputs.device)
if engine.hyper_kwargs["drop_point_prob"] > 0.99:
# automatic only validation
points = None
Expand Down
6 changes: 2 additions & 4 deletions models/vista3d/scripts/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@
from typing import List, Union

import torch
from monai.inferers.inferer import Inferer
from monai.apps.vista3d.inferer import point_based_window_inferer
from monai.inferers import Inferer, sliding_window_inference
from torch import Tensor

from .monai_utils import sliding_window_inference
from .utils import point_based_window_inferer


class Vista3dInferer(Inferer):
"""
Expand Down
Loading

0 comments on commit db0f350

Please sign in to comment.