Skip to content

Commit

Permalink
scripts/vsmlrt.py: add prefer_nhwc flag to the ort_cuda backend
Browse files — browse the repository at this point in the history
  • Loading branch information
WolframRhodium committed Apr 20, 2024
1 parent 187249d commit 0abb2a3
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions scripts/vsmlrt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "3.20.6"
__version__ = "3.20.7"

__all__ = [
"Backend", "BackendV2",
Expand Down Expand Up @@ -87,6 +87,7 @@ class ORT_CUDA:
fp16: bool = False
use_cuda_graph: bool = False # preview, not supported by all models
fp16_blacklist_ops: typing.Optional[typing.Sequence[str]] = None
prefer_nhwc: bool = False

# internal backend attributes
supports_onnx_serialization: bool = True
Expand Down Expand Up @@ -2032,6 +2033,17 @@ def _inference(
fp16_blacklist_ops=backend.fp16_blacklist_ops
)
elif isinstance(backend, Backend.ORT_CUDA):
kwargs = dict()

version_list = core.ort.Version().get("onnxruntime_version", b"0.0.0").split(b'.')
if len(version_list) != 3:
version = (0, 0, 0)
else:
version = tuple(map(int, version_list))

if version >= (1, 18, 0):
kwargs["prefer_nhwc"] = backend.prefer_nhwc

clip = core.ort.Model(
clips, network_path,
overlap=overlap, tilesize=tilesize,
Expand All @@ -2043,7 +2055,8 @@ def _inference(
fp16=backend.fp16,
path_is_serialization=path_is_serialization,
use_cuda_graph=backend.use_cuda_graph,
fp16_blacklist_ops=backend.fp16_blacklist_ops
fp16_blacklist_ops=backend.fp16_blacklist_ops,
**kwargs
)
elif isinstance(backend, Backend.OV_CPU):
version = tuple(map(int, core.ov.Version().get("openvino_version", b"0.0.0").split(b'-')[0].split(b'.')))
Expand Down

0 comments on commit 0abb2a3

Please sign in to comment.