diff --git a/models/vista3d/configs/inference.json b/models/vista3d/configs/inference.json index 92f2c0dd..3e62fa81 100644 --- a/models/vista3d/configs/inference.json +++ b/models/vista3d/configs/inference.json @@ -14,8 +14,10 @@ "output_ext": ".nii.gz", "output_dtype": "$np.float32", "output_postfix": "trans", - "separate_folder": true, - "input_dict": "${'image': '/data/Task09_Spleen/imagesTr/spleen_10.nii.gz', 'label_prompt': [3]}", + "separate_folder": false, + "save_image": true, + "gpu_load_image": false, + "input_dict": "${'image': '/workspace/Task03_Liver/imagesTr_decompress/liver_100.nii', 'label_prompt': [1]}", "everything_labels": "$list(set([i+1 for i in range(132)]) - set([2,16,18,20,21,23,24,25,26,27,128,129,130,131,132]))", "metadata_path": "$@bundle_root + '/configs/metadata.json'", "metadata": "$json.loads(pathlib.Path(@metadata_path).read_text())", @@ -54,18 +56,15 @@ { "_target_": "LoadImaged", "keys": "@image_key", - "image_only": true + "reader": "NibabelReader", + "image_only": true, + "gpu_load": "@gpu_load_image", + "device": "@device" }, { "_target_": "EnsureChannelFirstd", "keys": "@image_key" }, - { - "_target_": "EnsureTyped", - "keys": "@image_key", - "device": "@device", - "track_meta": true - }, { "_target_": "Spacingd", "keys": "@image_key", @@ -157,6 +156,7 @@ { "_target_": "SaveImaged", "keys": "pred", + "_disabled_": "$not @save_image", "resample": false, "output_dir": "@output_dir", "output_ext": "@output_ext", diff --git a/models/vista3d/scripts/evaluator.py b/models/vista3d/scripts/evaluator.py index f20261b3..3b75b391 100644 --- a/models/vista3d/scripts/evaluator.py +++ b/models/vista3d/scripts/evaluator.py @@ -22,6 +22,7 @@ from monai.utils import ForwardMode, IgniteInfo, RankFilter, min_version, optional_import from monai.utils.enums import CommonKeys as Keys from torch.utils.data import DataLoader +from .warmup import warm_up rearrange, _ = optional_import("einops", name="rearrange") @@ -133,6 +134,7 @@ def __init__( self.inferer = SimpleInferer() if inferer is None else inferer self.hyper_kwargs = hyper_kwargs self.logger.addFilter(RankFilter()) + warm_up() def transform_points(self, point, affine): """transform point to the coordinates of the transformed image diff --git a/models/vista3d/scripts/warmup.py b/models/vista3d/scripts/warmup.py new file mode 100644 index 00000000..c0bbbdcb --- /dev/null +++ b/models/vista3d/scripts/warmup.py @@ -0,0 +1,16 @@ +import tempfile +import cupy as cp +import kvikio + +def warm_up(): + a = cp.arange(100) + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + tmp_file_name = tmp_file.name + f = kvikio.CuFile(tmp_file_name, "w") + # Write whole array to file + f.write(a) + f.close() + + b = cp.empty_like(a) + f = kvikio.CuFile(tmp_file_name, "r") + f.read(b)