Skip to content

Commit

Permalink
Merge branch 'main' into pin_caikit
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Brooks <alex.brooks@ibm.com>
  • Loading branch information
alex-jw-brooks authored Jan 16, 2024
2 parents aff0807 + 2128fa4 commit 764aae5
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 1 deletion.
3 changes: 2 additions & 1 deletion caikit_computer_vision/data_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

# Local
from . import image_classification, object_detection, tasks
from . import image_classification, image_segmentation, object_detection, tasks
from .image_classification import *
from .image_segmentation import *
from .object_detection import *
46 changes: 46 additions & 0 deletions caikit_computer_vision/data_model/image_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data structures for segmentation in images."""


# Standard
from typing import List

# Third Party
from py_to_proto.dataclass_to_proto import Annotated, FieldNumber

# First Party
from caikit.core import DataObjectBase, dataobject
from caikit.interfaces.common.data_model import ProducerId
from caikit.interfaces.vision import data_model as caikit_dm
import alog

log = alog.use_channel("DATAM")


@dataobject(package="caikit_data_model.caikit_computer_vision")
class ObjectSegment(DataObjectBase):
    """One segmented object: a score, a label, and the mask covering it."""

    score: Annotated[float, FieldNumber(1)]
    label: Annotated[str, FieldNumber(2)]
    # TODO: Allow a subtype (i.e., the PIL image mode) to be specified here.
    # The mask is meant to be a mode "L" image (8 bit grayscale) that is
    # treated as binary: 0 marks background and 255 marks pixels belonging
    # to the object, which lines up with the HF task definitions.
    mask: Annotated[caikit_dm.Image, FieldNumber(3)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
class ImageSegmentationResult(DataObjectBase):
    """Result of image segmentation: the object segments and a producer ID."""

    # Every segment carries its own score, label, and mask.
    object_segments: Annotated[List[ObjectSegment], FieldNumber(1)]
    producer_id: Annotated[ProducerId, FieldNumber(2)]
39 changes: 39 additions & 0 deletions caikit_computer_vision/data_model/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,23 @@
from caikit.interfaces.common.data_model import ProducerId
import alog

# Local
from .image_segmentation import ObjectSegment

log = alog.use_channel("DATAM")

# Image coordinates - TODO: Probably should standardize what we use for these...
@dataobject(package="caikit_data_model.caikit_computer_vision")
class Point2f(DataObjectBase):
    """A 2D image coordinate whose x and y components are floats."""

    x: Annotated[float, FieldNumber(1)]
    y: Annotated[float, FieldNumber(2)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
class Point2d(DataObjectBase):
    """A 2D image coordinate whose x and y components are integers."""

    # NOTE(review): some libraries use a "2d" suffix for double precision
    # points; here both components are ints - confirm the name is intended.
    x: Annotated[int, FieldNumber(1)]
    y: Annotated[int, FieldNumber(2)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
class BoundingBox(DataObjectBase):
Expand All @@ -35,11 +50,35 @@ class BoundingBox(DataObjectBase):
ymax: Annotated[int, FieldNumber(4)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
class AnomalyRegion(DataObjectBase):
    """A scored region of a detected anomaly: a box and a hotspot point."""

    score: Annotated[float, FieldNumber(1)]
    # The box bounds the anomaly and the hotspot marks its primary focus
    # area. Both are expressed in coordinates relative to the detected
    # object.
    box: Annotated[BoundingBox, FieldNumber(2)]
    anomaly_hotspot: Annotated[Point2d, FieldNumber(3)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
class Anomaly(DataObjectBase):
    """Anomaly information: a score, a threshold, detail text, and regions."""

    score: Annotated[float, FieldNumber(1)]
    # NOTE(review): presumably the cutoff that `score` is compared against;
    # the comparison itself is not part of this data model - confirm.
    anomaly_threshold: Annotated[float, FieldNumber(2)]
    detail_data: Annotated[str, FieldNumber(3)]
    regions: Annotated[List[AnomalyRegion], FieldNumber(4)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
class DetectedObject(DataObjectBase):
    """A detected object: score, label, and box, plus optional extras.

    The segmentation fields and the anomaly field are optional additions
    to the core detection information.
    """

    score: Annotated[float, FieldNumber(1)]
    label: Annotated[str, FieldNumber(2)]
    box: Annotated[BoundingBox, FieldNumber(3)]

    ### Segmentation information (optional)
    # Pixel coordinates that represent the object's segmentation mask.
    object_segments: Annotated[List[Point2f], FieldNumber(4)]
    # Run-length encoding of the described object (optional).
    rle: Annotated[str, FieldNumber(5)]

    ### Anomaly detection information (optional)
    anomaly: Annotated[Anomaly, FieldNumber(6)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
Expand Down
13 changes: 13 additions & 0 deletions caikit_computer_vision/data_model/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

# Local
from .image_classification import ImageClassificationResult
from .image_segmentation import ImageSegmentationResult
from .object_detection import ObjectDetectionResult


Expand All @@ -44,3 +45,15 @@ class ImageClassificationTask(TaskBase):
and producing an iterable of objects containing class names and typically
confidence scores.
"""


@task(
    required_parameters={"inputs": bytes},
    output_type=ImageSegmentationResult,
)
class ImageSegmentationTask(TaskBase):
    """The image segmentation task is responsible for taking an input image
    and producing pixel masks with optional class names and confidence scores.
    Note that at the moment, this task encapsulates all segmentation types,
    i.e., instance, object, semantic, etc...
    """
1 change: 1 addition & 0 deletions tests/data_model/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class InvalidType:
(
tasks.ObjectDetectionTask,
tasks.ImageClassificationTask,
tasks.ImageSegmentationTask,
),
)
def test_tasks(
Expand Down

0 comments on commit 764aae5

Please sign in to comment.