-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lazy eval refactor #6257
Lazy eval refactor #6257
Changes from 1 commit
06d5842
2507a0b
56b618e
315e5b7
4e69313
99fa35b
a798e0e
114b8cc
5df4261
84d65f7
7ec442f
8f62776
2492223
0064326
0646602
05eb3f9
788cd74
4bd22f8
1e0510c
c36b69f
8a4f310
3c69776
6b94aa1
bf1c01a
19745f9
b219f13
6af8f24
173c161
2262ba9
aaaea13
088d14e
3e6a637
98a53bf
28a0175
4e9b6c6
ad23bea
81549e2
77eff5b
7c53d26
ee5378f
28c05d0
d96fef0
654cf48
64ec242
f38eb9b
14349ff
51fda45
c5742c0
ef5173e
916b0a9
20c1cfc
c94341f
cc51f2f
06aad01
3ae46e9
95a6a5f
81559fa
af9cad1
a9c8849
aadbae3
1cccf9b
9ae1140
bb2ad0f
5abe294
c2bf4a7
b16a87f
816666e
8318121
0e10c24
f323ff9
4b69ec6
29ce2b7
6774aaa
de5b9b8
1c7d172
bb22c13
2159ce8
2c3172d
bb1214c
5a7820c
503db85
1af35fd
75771a8
d97095a
af6cd62
cd246a8
1403d77
5c4c10a
5755a65
ca0905a
e8fdf5a
70d3517
b361a13
a6ad51e
7ec8429
305b88e
2aa0ada
7b092ca
62431ff
6cdf4c6
2ea89c1
628a7f8
5b34755
d547b81
03386dc
d2fa521
eb641c6
d7d7cdb
29b37bf
3194980
5ff516d
eebb3de
d966856
e39e0a3
cf58b9f
883f8c5
e81f359
db1486b
b151607
24286f1
43f7bc0
b217dbe
ed1d952
e7f3a07
0751c52
6114a67
240059b
23ff89c
5e1b81d
5dd6113
74bed34
76e4b85
1c4fa6c
8dde40d
7fd32e2
10d95ce
0d6a847
1ceeb1a
38bdef6
6e0cfd9
f2fb06a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
Signed-off-by: Ben Murray <ben.murray@gmail.com>
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -24,10 +24,13 @@ | |||||
import monai | ||||||
from monai.apps.utils import get_logger | ||||||
from monai.config import NdarrayOrTensor | ||||||
from monai.data.meta_tensor import MetaTensor | ||||||
from monai.transforms.inverse import InvertibleTransform | ||||||
from monai.transforms.lazy.array import ApplyPending | ||||||
from monai.transforms.lazy.dictionary import ApplyPendingd | ||||||
|
||||||
# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform) | ||||||
from monai.transforms.lazy.functional import apply_pending_transforms | ||||||
from monai.transforms.lazy.executors import apply_pending_transforms | ||||||
from monai.transforms.traits import LazyTrait, ThreadUnsafe | ||||||
from monai.transforms.transform import ( # noqa: F401 | ||||||
LazyTransform, | ||||||
|
@@ -52,6 +55,7 @@ def execute_compose( | |||||
start: int = 0, | ||||||
end: int | None = None, | ||||||
lazy: bool | None = False, | ||||||
lazy_strategy: str = "in_order", | ||||||
overrides: dict | None = None, | ||||||
threading: bool = False, | ||||||
logger_name: str | None = None, | ||||||
|
@@ -75,6 +79,9 @@ def execute_compose( | |||||
lazy: whether to enable lazy evaluation for lazy transforms. If False, transforms will be | ||||||
carried out on a transform by transform basis. If True, all lazy transforms will | ||||||
be executed by accumulating changes and resampling as few times as possible. | ||||||
lazy_strategy: this field controls how execution occurs when processing data lazily. Permitted | ||||||
options are "in_order", "out_of_order". Please see `Compose`_ for more details of what these | ||||||
options mean. In general, you should not need to change this from its default. | ||||||
overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden | ||||||
when executing a pipeline. These each parameter that is compatible with a given transform is then applied | ||||||
to that transform before it is executed. Note that overrides are currently only applied when lazy | ||||||
|
@@ -109,12 +116,138 @@ def execute_compose( | |||||
if threading: | ||||||
_transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform | ||||||
data = apply_transform( | ||||||
_transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, logger_name=logger_name | ||||||
_transform, | ||||||
data, | ||||||
map_items, | ||||||
unpack_items, | ||||||
lazy=lazy, | ||||||
lazy_strategy=lazy_strategy, | ||||||
overrides=overrides, | ||||||
logger_name=logger_name, | ||||||
) | ||||||
data = apply_pending_transforms(data, overrides, logger_name=logger_name) | ||||||
data = apply_pending_transforms(data, None, overrides, logger_name=logger_name) | ||||||
return data | ||||||
|
||||||
|
||||||
class ComposeCompiler: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we've discussed today at the meeting I'd suggest something like |
||||||
""" | ||||||
The ComposeCompiler is an implementation class that is required to parse options for Compose. It should currently | ||||||
be considered an implementation detail that should not be interacted with directly by users of MONAI, although that | ||||||
may change in subsequent releases. Its job is to parse options provided to `Compose.__call__`_ to set execution | ||||||
modes for lazy resampling. | ||||||
|
||||||
See `Compose`_ for a detailed explanation of lazy resampling. | ||||||
""" | ||||||
|
||||||
def __init__(self): | ||||||
# construct the list of options | ||||||
options = {"reorder": {"lazy_last": self.reorder_lazy_last, "lazy_last_nosync": self.reorder_lazy_last_nosync}} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would avoid having nested structures like this, instead:
Suggested change
Also the string keywords here should be in enums and the structure that's expected should be documented in the docstring. |
||||||
self.options = options | ||||||
|
||||||
def __call__(self, transforms, lazy: bool | None, options: dict | None = None): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
""" | ||||||
Get a policy object that controls the way `Compose`_ executes a list of transforms. | ||||||
|
||||||
Args: | ||||||
transforms: a list of transforms to be executed | ||||||
lazy: the current lazy mode (False, None, or True) | ||||||
options: the options that determine the execution policy | ||||||
|
||||||
Returns: | ||||||
a dictionary specifying the execution policy | ||||||
|
||||||
""" | ||||||
if lazy is False or options is None: | ||||||
return ComposeCompiler.generate_policy() | ||||||
|
||||||
if len(options.keys()) > 1: | ||||||
raise ValueError("Only one option can currently be set") | ||||||
|
||||||
for k, v in options.items(): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||||||
if k not in self.options.keys(): | ||||||
raise KeyError( | ||||||
f"'{k}' is not a valid option key. Valid options are " f"{tuple(k for k in self.options.keys())}" | ||||||
) | ||||||
|
||||||
option = self.options[k] | ||||||
if v not in option.keys(): | ||||||
raise KeyError(f"'{v}' is not a valid option value. Value options for " f"'{k}' are {option.keys()}") | ||||||
|
||||||
action = option[v] | ||||||
|
||||||
return action(transforms=transforms, lazy=lazy) | ||||||
|
||||||
@classmethod | ||||||
def reorder_lazy_last(cls, *, transforms: list, lazy: bool | None, **kwargs): | ||||||
subsections = list() | ||||||
subsection_starts = list() | ||||||
# pass 1: split the transform list into subsections | ||||||
i_s = 0 | ||||||
for i_t in range(len(transforms)): | ||||||
if isinstance(transforms[i_t], (ApplyPending, ApplyPendingd)): | ||||||
# this subsection ends and is added to the subsection list | ||||||
if i_s < i_t: | ||||||
subsections.append(transforms[i_s:i_t]) | ||||||
subsection_starts.append(i_s) | ||||||
# add apply pending in its own list | ||||||
subsections.append([transforms[i_t]]) | ||||||
subsection_starts.append(i_t) | ||||||
i_s = i_t + 1 | ||||||
|
||||||
if i_s != len(transforms): | ||||||
subsections.append(transforms[i_s:]) | ||||||
subsection_starts.append(i_s) | ||||||
|
||||||
# pass 2: calculate the permuted indices | ||||||
permuted_indices = list() | ||||||
for sub_start, subsection in zip(subsection_starts, subsections): | ||||||
for i_s, s in enumerate(subsection): | ||||||
if not cls._executing_lazily(s, lazy): | ||||||
permuted_indices.append(i_s + sub_start) | ||||||
for i_s, s in enumerate(subsection): | ||||||
if cls._executing_lazily(s, lazy): | ||||||
permuted_indices.append(i_s + sub_start) | ||||||
|
||||||
# pass 2: sort the subsections | ||||||
reordered = list() | ||||||
for subsection in subsections: | ||||||
# non-lazy, lazy | ||||||
subsection = [t for t in subsection if not cls._executing_lazily(t, lazy)] + [ | ||||||
t for t in subsection if cls._executing_lazily(t, lazy) | ||||||
] | ||||||
reordered.extend(subsection) | ||||||
|
||||||
return ComposeCompiler.generate_policy({"indices": permuted_indices}) | ||||||
|
||||||
@classmethod | ||||||
def reorder_lazy_last_nosync(cls, *, transforms: list, **_): | ||||||
""" | ||||||
'reorder: lazy_last_nosync' is implemented through use of the 'out_of_order' execution | ||||||
policy. See 'Compose'_ for details of this policy. | ||||||
Args: | ||||||
transforms: Not used by this method | ||||||
|
||||||
Returns: | ||||||
|
||||||
""" | ||||||
return cls.generate_policy({"can_invert": False, "lazy_policy": "out_of_order"}) | ||||||
|
||||||
@staticmethod | ||||||
def generate_policy(overrides: dict | None = None): | ||||||
default_policy = {"indices": None, "transforms": None, "can_invert": True, "lazy_policy": "in_order"} | ||||||
if overrides is not None: | ||||||
for k, v in overrides.items(): | ||||||
default_policy[k] = v | ||||||
return default_policy | ||||||
|
||||||
@staticmethod | ||||||
def _executing_lazily(t, lazy_policy): | ||||||
if isinstance(t, LazyTrait): | ||||||
lazy_ = t.lazy if lazy_policy is None else lazy_policy | ||||||
return lazy_ | ||||||
return False | ||||||
|
||||||
|
||||||
class Compose(Randomizable, InvertibleTransform): | ||||||
""" | ||||||
``Compose`` provides the ability to chain a series of callables together in | ||||||
|
@@ -262,6 +395,7 @@ def __init__( | |||||
unpack_items: bool = False, | ||||||
lazy: bool | None = False, | ||||||
overrides: dict | None = None, | ||||||
options: dict | None = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please add documentation for this |
||||||
logger_name: str | None = None, | ||||||
) -> None: | ||||||
if transforms is None: | ||||||
|
@@ -272,6 +406,7 @@ def __init__( | |||||
self.set_random_state(seed=get_seed()) | ||||||
self.lazy = lazy | ||||||
self.overrides = overrides | ||||||
self.options = options | ||||||
self.logger_name = logger_name | ||||||
|
||||||
def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose: | ||||||
|
@@ -350,33 +485,94 @@ def __len__(self): | |||||
return len(self.flatten().transforms) | ||||||
|
||||||
def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None): | ||||||
return execute_compose( | ||||||
lazy_ = self.lazy if lazy is None else lazy | ||||||
policy = ComposeCompiler()(self.transforms, lazy_) | ||||||
lazy_strategy = policy["lazy_policy"] | ||||||
indices = policy["indices"] | ||||||
|
||||||
# permute the transforms if required | ||||||
transforms = self.transforms if indices is None else [self.transforms[i] for i in indices] | ||||||
|
||||||
result = execute_compose( | ||||||
input_, | ||||||
self.transforms, | ||||||
transforms, | ||||||
start=start, | ||||||
end=end, | ||||||
map_items=self.map_items, | ||||||
unpack_items=self.unpack_items, | ||||||
lazy=self.lazy, # type: ignore | ||||||
lazy_strategy=lazy_strategy, | ||||||
overrides=self.overrides, | ||||||
threading=threading, | ||||||
logger_name=self.logger_name, | ||||||
) | ||||||
|
||||||
# if the transforms were permuted, record it in the metadata for inversion | ||||||
if indices is not None: | ||||||
if isinstance(result, monai.data.MetaTensor): | ||||||
self.push_transform(result, extra_info={"applied_order": indices}) | ||||||
elif isinstance(result, Mapping): | ||||||
for key in result: # dictionary not change size during iteration | ||||||
if isinstance(result[key], monai.data.MetaTensor) or self.trace_key(key) in result: | ||||||
self.push_transform(result, key, extra_info={"applied_order": indices}) | ||||||
|
||||||
return result | ||||||
|
||||||
def inverse(self, data): | ||||||
invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] | ||||||
if not invertible_transforms: | ||||||
warnings.warn("inverse has been called but no invertible transforms have been supplied") | ||||||
if self.options is not None: | ||||||
compiler = ComposeCompiler() | ||||||
policy = compiler(self.transforms, self.lazy, self.options) | ||||||
else: | ||||||
policy = ComposeCompiler().generate_policy() | ||||||
|
||||||
indices = policy["indices"] | ||||||
can_invert = policy["can_invert"] | ||||||
if can_invert is False: | ||||||
raise ValueError("'inverse' is not supported with options {self.options}") | ||||||
|
||||||
data_ = deepcopy(data) | ||||||
|
||||||
if indices is not None: | ||||||
applied_order = None | ||||||
if isinstance(data_, monai.data.MetaTensor): | ||||||
applied_order = self.pop_transform(data_)[TraceKeys.EXTRA_INFO]["applied_order"] | ||||||
elif isinstance(data_, Mapping): | ||||||
for key in data_: | ||||||
if isinstance(data_[key], monai.data.MetaTensor) or self.trace_key(key) in data_: | ||||||
applied_order = self.pop_transform(data_, key)[TraceKeys.EXTRA_INFO]["applied_order"] | ||||||
else: | ||||||
raise RuntimeError( | ||||||
f"Inverse only implemented for Mapping (dictionary) or MetaTensor data, got type {type(data)}." | ||||||
) | ||||||
if applied_order is None: | ||||||
# no invertible transforms have been applied | ||||||
return data_ | ||||||
|
||||||
# loop backwards over transforms | ||||||
for o in reversed(applied_order): | ||||||
if isinstance(self.transforms[o], InvertibleTransform): | ||||||
data_ = apply_transform(self.transforms[o].inverse, data_, self.map_items, self.unpack_items) | ||||||
else: | ||||||
invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)] | ||||||
if not invertible_transforms: | ||||||
warnings.warn("inverse has been called but no invertible transforms have been supplied") | ||||||
|
||||||
# loop backwards over transforms | ||||||
for t in reversed(invertible_transforms): | ||||||
if isinstance(t, LazyTrait) and t.lazy: | ||||||
if self.lazy is not False: | ||||||
warnings.warn( | ||||||
f"inversing {t.__class__.__name__} lazily may not implemented" | ||||||
"please set `lazy=False` before calling inverse." | ||||||
f"'lazy' is set to {self.lazy} but lazy execution is not supported when inverting. " | ||||||
f"'lazy' has been overridden to False for the call to inverse" | ||||||
) | ||||||
data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) | ||||||
return data | ||||||
# loop backwards over transforms | ||||||
for t in reversed(invertible_transforms): | ||||||
# if isinstance(t, LazyTrait) and t.lazy: | ||||||
# warnings.warn( | ||||||
# f"inversing {t.__class__.__name__} lazily may not implemented" | ||||||
# "please set `lazy=False` before calling inverse." | ||||||
# ) | ||||||
data_ = apply_transform( | ||||||
t.inverse, data_, self.map_items, self.unpack_items, lazy=False, logger_name=self.logger_name | ||||||
) | ||||||
return data_ | ||||||
|
||||||
|
||||||
class OneOf(Compose): | ||||||
|
@@ -759,8 +955,7 @@ def inverse(self, data): | |||||
|
||||||
# loop backwards over transforms | ||||||
for o in reversed(applied_order): | ||||||
transform = self.transforms[o] | ||||||
if isinstance(transform, InvertibleTransform): | ||||||
if isinstance(self.transforms[o], InvertibleTransform): | ||||||
data = apply_transform(self.transforms[o].inverse, data, self.map_items, self.unpack_items) | ||||||
|
||||||
return data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to specify here we could accept a dict or a nested dict here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'll create an example