diff --git a/caikit_computer_vision/data_model/__init__.py b/caikit_computer_vision/data_model/__init__.py index 29a0d6f..fc1f138 100644 --- a/caikit_computer_vision/data_model/__init__.py +++ b/caikit_computer_vision/data_model/__init__.py @@ -14,6 +14,7 @@ # Local from . import image_classification, image_segmentation, object_detection, tasks +from .flat_image import * from .image_classification import * from .image_segmentation import * from .object_detection import * diff --git a/caikit_computer_vision/data_model/tasks.py b/caikit_computer_vision/data_model/tasks.py index bbe8f73..bcc2a29 100644 --- a/caikit_computer_vision/data_model/tasks.py +++ b/caikit_computer_vision/data_model/tasks.py @@ -22,6 +22,7 @@ from caikit.core import TaskBase, task # Local +from .flat_image import FlatImage from .image_classification import ImageClassificationResult from .image_segmentation import ImageSegmentationResult from .object_detection import ObjectDetectionResult @@ -29,7 +30,7 @@ # TODO - add support for image DM primitives @task( - required_parameters={"inputs": Union[bytes, str]}, + required_parameters={"inputs": Union[bytes, str, FlatImage]}, output_type=ObjectDetectionResult, ) class ObjectDetectionTask(TaskBase): @@ -40,7 +41,7 @@ class ObjectDetectionTask(TaskBase): @task( - required_parameters={"inputs": Union[bytes, str]}, + required_parameters={"inputs": Union[bytes, str, FlatImage]}, output_type=ImageClassificationResult, ) class ImageClassificationTask(TaskBase): @@ -51,7 +52,7 @@ class ImageClassificationTask(TaskBase): @task( - required_parameters={"inputs": Union[bytes, str]}, + required_parameters={"inputs": Union[bytes, str, FlatImage]}, output_type=ImageSegmentationResult, ) class ImageSegmentationTask(TaskBase):