diff --git a/caikit_computer_vision/data_model/__init__.py b/caikit_computer_vision/data_model/__init__.py index ac4effe..29a0d6f 100644 --- a/caikit_computer_vision/data_model/__init__.py +++ b/caikit_computer_vision/data_model/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. # Local -from . import image_classification, object_detection, tasks +from . import image_classification, image_segmentation, object_detection, tasks from .image_classification import * +from .image_segmentation import * from .object_detection import * diff --git a/caikit_computer_vision/data_model/image_segmentation.py b/caikit_computer_vision/data_model/image_segmentation.py new file mode 100644 index 0000000..1b76136 --- /dev/null +++ b/caikit_computer_vision/data_model/image_segmentation.py @@ -0,0 +1,46 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data structures for segmentation in images.""" + + +# Standard +from typing import List + +# Third Party +from py_to_proto.dataclass_to_proto import Annotated, FieldNumber + +# First Party +from caikit.core import DataObjectBase, dataobject +from caikit.interfaces.common.data_model import ProducerId +from caikit.interfaces.vision import data_model as caikit_dm +import alog + +log = alog.use_channel("DATAM") + + +@dataobject(package="caikit_data_model.caikit_computer_vision") +class ObjectSegment(DataObjectBase): + score: Annotated[float, FieldNumber(1)] + label: Annotated[str, FieldNumber(2)] + # TODO: We should be able to specify subtype, i.e., PIL image mode. + # This mask should be image mode (L), 8 bit grayscale image treated + # as a binary image, where 0 is background, and 255 is part of the + # object to align with HF task definitions. + mask: Annotated[caikit_dm.Image, FieldNumber(3)] + + +@dataobject(package="caikit_data_model.caikit_computer_vision") +class ImageSegmentationResult(DataObjectBase): + object_segments: Annotated[List[ObjectSegment], FieldNumber(1)] + producer_id: Annotated[ProducerId, FieldNumber(2)] diff --git a/caikit_computer_vision/data_model/object_detection.py b/caikit_computer_vision/data_model/object_detection.py index ae9accb..763f493 100644 --- a/caikit_computer_vision/data_model/object_detection.py +++ b/caikit_computer_vision/data_model/object_detection.py @@ -24,8 +24,23 @@ from caikit.interfaces.common.data_model import ProducerId import alog +# Local +from .image_segmentation import ObjectSegment + log = alog.use_channel("DATAM") +# Image coordinates - TODO: Probably should standardize what we use for these... +@dataobject(package="caikit_data_model.caikit_computer_vision") +class Point2f(DataObjectBase): + x: Annotated[float, FieldNumber(1)] + y: Annotated[float, FieldNumber(2)] + + +@dataobject(package="caikit_data_model.caikit_computer_vision") +class Point2d(DataObjectBase): + x: Annotated[int, FieldNumber(1)] + y: Annotated[int, FieldNumber(2)] + @dataobject(package="caikit_data_model.caikit_computer_vision") class BoundingBox(DataObjectBase): @@ -35,11 +50,35 @@ class BoundingBox(DataObjectBase): ymax: Annotated[int, FieldNumber(4)] +@dataobject(package="caikit_data_model.caikit_computer_vision") +class AnomalyRegion(DataObjectBase): + score: Annotated[float, FieldNumber(1)] + # Bounding box and primary focus area of the detected anomaly; + # note that these coordinates are relative to the detected object. + box: Annotated[BoundingBox, FieldNumber(2)] + anomaly_hotspot: Annotated[Point2d, FieldNumber(3)] + + +@dataobject(package="caikit_data_model.caikit_computer_vision") +class Anomaly(DataObjectBase): + score: Annotated[float, FieldNumber(1)] + anomaly_threshold: Annotated[float, FieldNumber(2)] + detail_data: Annotated[str, FieldNumber(3)] + regions: Annotated[List[AnomalyRegion], FieldNumber(4)] + + @dataobject(package="caikit_data_model.caikit_computer_vision") class DetectedObject(DataObjectBase): score: Annotated[float, FieldNumber(1)] label: Annotated[str, FieldNumber(2)] box: Annotated[BoundingBox, FieldNumber(3)] + ### Optional segmentation information + # list of pixel coordinates representing the segmentation mask of the object. + object_segments: Annotated[List[Point2f], FieldNumber(4)] + # Optional run-length encoding of the object being described. + rle: Annotated[str, FieldNumber(5)] + ### Optional anomaly detection information + anomaly: Annotated[Anomaly, FieldNumber(6)] @dataobject(package="caikit_data_model.caikit_computer_vision") diff --git a/caikit_computer_vision/data_model/tasks.py b/caikit_computer_vision/data_model/tasks.py index 32fe4d8..cdef0a6 100644 --- a/caikit_computer_vision/data_model/tasks.py +++ b/caikit_computer_vision/data_model/tasks.py @@ -20,6 +20,7 @@ # Local from .image_classification import ImageClassificationResult +from .image_segmentation import ImageSegmentationResult from .object_detection import ObjectDetectionResult @@ -44,3 +45,15 @@ class ImageClassificationTask(TaskBase): and producing an iterable of objects containing class names and typically confidence scores. """ + + +@task( + required_parameters={"inputs": bytes}, + output_type=ImageSegmentationResult, +) +class ImageSegmentationTask(TaskBase): + """The image classification task is responsible for taking an input image + and producing a pixel mask with optional class names and confidence scores. + Note that at the moment, this task encapsulates all segmentation types, + I.e., instance, object, semantic, etc... + """ diff --git a/tests/data_model/test_tasks.py b/tests/data_model/test_tasks.py index e641249..db758d0 100644 --- a/tests/data_model/test_tasks.py +++ b/tests/data_model/test_tasks.py @@ -64,6 +64,7 @@ class InvalidType: ( tasks.ObjectDetectionTask, tasks.ImageClassificationTask, + tasks.ImageSegmentationTask, ), ) def test_tasks(