Skip to content

Commit

Permalink
Add modelAPI related stuff (openvinotoolkit#1219)
Browse files Browse the repository at this point in the history
* Add modelAPI related stuff

* Address PR comments + fix tests

* Add comment + modify export

* Update src/anomalib/deploy/export.py

Co-authored-by: Samet Akcay <samet.akcay@intel.com>

---------

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
3 people committed Aug 21, 2023
1 parent cdedf05 commit 0b5d969
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 4 deletions.
77 changes: 75 additions & 2 deletions src/anomalib/deploy/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from enum import Enum
from pathlib import Path
from typing import Any
from warnings import warn

import numpy as np
import torch
from openvino.runtime import Core, serialize
from torch import Tensor
from torch.types import Number

Expand Down Expand Up @@ -113,7 +115,7 @@ def export(
# Export model to onnx and convert to OpenVINO IR if export mode is set to OpenVINO.
onnx_path = export_to_onnx(model, input_size, export_path)
if export_mode == ExportMode.OPENVINO:
export_to_openvino(export_path, onnx_path)
export_to_openvino(export_path, onnx_path, metadata, input_size)

else:
raise ValueError(f"Unknown export mode {export_mode}")
Expand Down Expand Up @@ -156,12 +158,83 @@ def export_to_onnx(model: AnomalyModule, input_size: tuple[int, int], export_pat
return onnx_path


def export_to_openvino(export_path: str | Path, onnx_path: Path) -> None:
def export_to_openvino(
export_path: str | Path, onnx_path: Path, metadata: dict[str, Any], input_size: tuple[int, int]
) -> None:
"""Convert onnx model to OpenVINO IR.
Args:
export_path (str | Path): Path to the root folder of the exported model.
onnx_path (Path): Path to the exported onnx model.
metadata (dict[str, Any]): Metadata for the exported model.
input_size (tuple[int, int]): Input size of the model. Used for adding metadata to the IR.
"""
optimize_command = ["mo", "--input_model", str(onnx_path), "--output_dir", str(export_path)]
subprocess.run(optimize_command, check=True) # nosec
_add_metadata_to_ir(str(export_path) + f"/{onnx_path.with_suffix('.xml').name}", metadata, input_size)


def _add_metadata_to_ir(xml_file: str, metadata: dict[str, Any], input_size: tuple[int, int]) -> None:
"""Adds the metadata to the model IR.
Adds the metadata to the model IR. So that it can be used with the new modelAPI.
This is because the metadata.json is not used by the new modelAPI.
# TODO CVS-114640
# TODO: Remove this function when Anomalib is upgraded as the model graph will contain the required ops
Args:
xml_file (str): Path to the xml file.
metadata (dict[str, Any]): Metadata to add to the model.
input_size (tuple[int, int]): Input size of the model.
"""
core = Core()
model = core.read_model(xml_file)

_metadata = {}
for key, value in metadata.items():
if key in ("transform", "min", "max"):
continue
_metadata[("model_info", key)] = value

# Add transforms
if "transform" in metadata:
for transform_dict in metadata["transform"]["transform"]["transforms"]:
transform = transform_dict["__class_fullname__"]
if transform == "Normalize":
_metadata[("model_info", "mean_values")] = _serialize_list([x * 255.0 for x in transform_dict["mean"]])
_metadata[("model_info", "scale_values")] = _serialize_list([x * 255.0 for x in transform_dict["std"]])
elif transform == "Resize":
_metadata[("model_info", "orig_height")] = transform_dict["height"]
_metadata[("model_info", "orig_width")] = transform_dict["width"]
else:
warn(f"Transform {transform} is not supported currently")

# Since we only need the diff of max and min, we fuse the min and max into one op
if "min" in metadata and "max" in metadata:
_metadata[("model_info", "normalization_scale")] = metadata["max"] - metadata["min"]

_metadata[("model_info", "reverse_input_channels")] = True
_metadata[("model_info", "model_type")] = "AnomalyDetection"
_metadata[("model_info", "labels")] = ["Normal", "Anomaly"]
_metadata[("model_info", "image_shape")] = _serialize_list(input_size)

for k, data in _metadata.items():
model.set_rt_info(data, list(k))

tmp_xml_path = Path(xml_file).parent / "tmp.xml"
serialize(model, str(tmp_xml_path))
tmp_xml_path.rename(xml_file)
# since we create new openvino IR files, we don't need the bin file. So we delete it.
tmp_xml_path.with_suffix(".bin").unlink()


def _serialize_list(arr: list[int] | list[float] | tuple[int, int]) -> str:
"""Serializes the list to a string.
Args:
arr (list[int] | list[float] | tuple[int, int]): List to serialize.
Returns:
str: Serialized list.
"""
return " ".join(map(str, arr))
2 changes: 1 addition & 1 deletion tests/pre_merge/tools/test_openvino_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_openvino_inference(
"--input",
get_dummy_inference_image,
"--output",
project_path + "/output",
project_path + "/output.png",
]
)
infer(arguments)
2 changes: 1 addition & 1 deletion tests/pre_merge/tools/test_torch_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_torch_inference(
"--input",
get_dummy_inference_image,
"--output",
project_path + "/output",
project_path + "/output.png",
]
)
infer(arguments)

0 comments on commit 0b5d969

Please sign in to comment.