refactor: Updating image segmentation module inference
Signed-off-by: Yagmur Gizem Cinar <yagmur.cinar@ibm.com>
Yagmur Gizem Cinar committed Feb 26, 2024
1 parent 95bd2c9 commit a35f8cc
Showing 6 changed files with 31 additions and 31 deletions.
4 changes: 3 additions & 1 deletion caikit_computer_vision/data_model/image_segmentation.py
@@ -41,7 +41,9 @@ class ObjectSegment(DataObjectBase):
# This mask should be image mode (L), 8 bit grayscale image treated
# as a binary image, where 0 is background, and 255 is part of the
# object to align with HF task definitions.
-mask: Annotated[Union[caikit_dm.Image, List[float]], FieldNumber(5)]
+# mask or polygon -- one of them is returned
+polygon: Annotated[List[float], FieldNumber(5)]
+mask: Annotated[caikit_dm.Image, FieldNumber(6)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
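For orientation (not part of the commit): a minimal sketch of how the updated data model might be populated. Exactly one of polygon or mask is expected to be set, per the comment above. The import path and sample values are assumptions; the other field names (score, category_id, bbox, area) come from the run() change later in this commit.

    # Hypothetical usage sketch -- import path and values are assumptions.
    from caikit_computer_vision.data_model import BoundingBox, ObjectSegment

    segment = ObjectSegment(
        score=0.91,
        category_id=3,
        bbox=BoundingBox(xmin=10.5, xmax=120.0, ymin=8.0, ymax=95.5),
        # Flat list of x, y vertex coordinates, COCO polygon style.
        polygon=[10.5, 8.0, 120.0, 8.0, 120.0, 95.5, 10.5, 95.5],
        area=9862.5,
    )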
8 changes: 4 additions & 4 deletions caikit_computer_vision/data_model/object_detection.py
@@ -42,10 +42,10 @@ class Point2d(DataObjectBase):

@dataobject(package="caikit_data_model.caikit_computer_vision")
class BoundingBox(DataObjectBase):
-xmin: Annotated[int, FieldNumber(1)]
-xmax: Annotated[int, FieldNumber(2)]
-ymin: Annotated[int, FieldNumber(3)]
-ymax: Annotated[int, FieldNumber(4)]
+xmin: Annotated[float, FieldNumber(1)]
+xmax: Annotated[float, FieldNumber(2)]
+ymin: Annotated[float, FieldNumber(3)]
+ymax: Annotated[float, FieldNumber(4)]


@dataobject(package="caikit_data_model.caikit_computer_vision")
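The switch from int to float matters because COCO-style predictions report boxes with fractional coordinates, which integer fields would truncate on conversion. A hedged sketch (import path assumed, values made up):

    # Hypothetical conversion sketch -- not code from this repository.
    from caikit_computer_vision.data_model import BoundingBox  # path assumed

    x, y, w, h = [12.37, 8.02, 107.6, 87.45]  # COCO bbox: x, y, width, height
    box = BoundingBox(xmin=x, ymin=y, xmax=x + w, ymax=y + h)  # floats preserved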
29 changes: 17 additions & 12 deletions caikit_computer_vision/modules/segmentation/image_processing.py
@@ -50,7 +50,8 @@ def maskrcnn_postprocess(
results (dict): COCO json from an inference task
uuid (str): Unique identifier of image
image_id (int, optional): Image ID to store in the task. Defaults to 1.
+threshold (float): Threshold in range (0,1] to be used for filtering bounding box predictions.
+    Default 0.2
Returns:
dict: Formatted COCO JSON in OCL format
"""
@@ -62,8 +63,12 @@
try:
result_instances = detector_postprocess(instances, image_size[0], image_size[1])
except IndexError as e:
log.error(f"Length of MASKS {len(results['MASKS'])}")
log.error("No instances found, setting empty annotations", exc_info=e)
error("<CCV61312083E>", f"Length of MASKS {len(results['MASKS'])}")
error(
"<CCV61312083E>",
"No instances found, setting empty annotations",
exc_info=e,
)
result_instances = []
coco_instances = instances_to_coco_json(result_instances, image_id)
coco_json = _assemble_images_coco(
@@ -167,11 +172,13 @@ def detectron2_to_coco(
if isinstance(inf_coco, str):
inf_coco = COCO(inf_coco)
else:
-assert isinstance(inf_coco, COCO), "inf_coco must be an instance of class COCO"
+error.type_check("<CCV61312083E>", COCO, inf_coco=inf_coco)

-assert (
-    "categories" and "images" in inf_coco.dataset
-), "inf_coco should contain both 'categories' and 'images' keys"
+error.value_check(
+    "<CCV61312083E>",
+    ("categories" and "images" in inf_coco.dataset),
+    "inf_coco should contain both 'categories' and 'images' keys",
+)

# give ids to annotations
# correct the placement of the score key as in coco
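These hunks replace bare log.error calls and asserts with caikit's error handler, which logs against an error code and raises typed exceptions (type_check raises TypeError, value_check raises ValueError when the condition is falsy). A sketch of the usual setup, with the caveat that the exact import path varies across caikit releases and the channel name here is invented:

    # Assumed setup pattern -- import path and channel name are guesses.
    import alog
    from caikit.core.exceptions import error_handler  # path differs by caikit version

    log = alog.use_channel("IMG_SEG_PROC")
    error = error_handler.get(log)

    # error.type_check("<CODE>", COCO, inf_coco=inf_coco)  # TypeError on mismatch
    # error.value_check("<CODE>", condition, "message")    # ValueError when falsy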
@@ -230,10 +237,10 @@ def convert_seg_poly(inf_coco: Type[COCO]) -> Type[COCO]:
"""converts segmentation from detectron2 format to polygons.
Args:
-inf_coco : coco object with annotations containing detectron2 formatted segmentations
+inf_coco : coco object with annotations dictionary containing detectron2 formatted segmentations
Returns:
-inf_coco : coco dict with annotations containing coco formatted polygons
+inf_coco : coco object with annotations dictionary containing coco formatted polygons
"""

def mask_to_polygons(mask):
@@ -245,14 +252,12 @@ def mask_to_polygons(mask):
mask = np.ascontiguousarray(
mask
) # some versions of cv2 does not support incontiguous arr
-res = cv2.findContours(
+res, hierarchy = cv2.findContours(
mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE
)
-hierarchy = res[-1]
if hierarchy is None: # empty mask
return [], False
has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0
-res = res[-2]
res = [x.flatten() for x in res]
# These coordinates from OpenCV are integers in range [0, W-1 or H-1].
# We add 0.5 to turn them into real-value coordinate space. A better solution
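This hunk adapts to OpenCV 4, where cv2.findContours returns a (contours, hierarchy) pair; OpenCV 3 returned (image, contours, hierarchy), which is what the old res[-2]/res[-1] indexing was compensating for. A small self-contained check of the unpacked signature:

    # Illustrative only: the OpenCV 4 return shape unpacked in this hunk.
    import cv2
    import numpy as np

    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 255  # one square blob

    contours, hierarchy = cv2.findContours(
        mask, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE
    )
    print(len(contours))    # 1
    print(hierarchy.shape)  # (1, n_contours, 4)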
10 changes: 2 additions & 8 deletions caikit_computer_vision/modules/segmentation/models.py
@@ -84,24 +84,18 @@ def load(cls, model_file: Union[str, "ModuleConfig"]) -> "ViTSegmenter":
def save(
self,
model_path: str,
-segmentation_dirname: str = "image_segmentation",
):
"""Save the in-memory model to the given path.
Args:
model_path: str
Path that we want to export the segmentation model to.
-segmentation_dirname: str
-    Subdirectory to which we want to save the segmentation model.
-    Default: image_segmentation
"""
saver = ModuleSaver(
self,
model_path=model_path,
)
-segmentation_rel_path, segmentation_abs_path = saver.add_dir(
-    segmentation_dirname
-)
+segmentation_rel_path, segmentation_abs_path = saver.add_dir(model_path)
with saver:
saver.update_config(
{
@@ -181,7 +175,7 @@ def run(
score=results[idx]["attributes"]["score"],
category_id=results[idx]["category_id"],
bbox=BoundingBox(*results[idx]["bbox"]),
-mask=results[idx]["segmentation"],
+polygon=results[idx]["segmentation"],
area=results[idx]["area"],
)
for idx in range(num_objects)
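With run() now emitting polygon instead of mask, downstream code reads flat COCO-style vertex lists off each segment. A consumption sketch; the import path and the result container's attribute name ("segments") are guesses and may differ from the actual ImageSegmentationResult definition:

    # Hypothetical consumer -- import path and container attribute are assumed.
    import numpy as np
    from caikit_computer_vision.modules.segmentation.models import ViTSegmenter

    model_dir = "path/to/saved/segmenter"  # produced by an earlier model.save(model_dir)
    model = ViTSegmenter.load(model_dir)
    result = model.run(np.ones((800, 800, 3), dtype=np.uint8))
    for seg in result.segments:  # attribute name assumed
        if seg.polygon:  # flat [x0, y0, x1, y1, ...] floats
            print(seg.category_id, seg.score, len(seg.polygon) // 2, "vertices")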
8 changes: 4 additions & 4 deletions pyproject.toml
@@ -18,10 +18,10 @@ dependencies = [
"transformers>=4.27.1,<5",
"torch>2.0,<3",
"timm>=0.9.5,<1",
"opencv-python",
"pycocotools",
"detectron2 @git+https://github.com/facebookresearch/detectron2.git",
"shapely"
"opencv-python>=4.9.0.80,<5",
"pycocotools>=2.0.7,<3",
"detectron2 @git+https://github.com/facebookresearch/detectron2.git@e70b9229d77aa39d85f8fa5266e6ea658e92eed3",
"shapely>=2.0.2,<3",
]

[project.urls]
3 changes: 1 addition & 2 deletions tests/modules/segmentation/test_segmentation_model.py
@@ -37,8 +37,7 @@ def test_save_model_and_bootstrap():
"""Ensure that we can save a model and reload it."""
model = ViTSegmenter.bootstrap(SEGMENTATION_MODEL_CKPT)
with tempfile.TemporaryDirectory() as model_dir:
model.save(model_dir, "segmenter")
del model
model.save(model_dir)
new_model = ViTSegmenter.load(model_dir)
preds = new_model.run(np.ones((800, 800, 3), dtype=np.uint8))
assert isinstance(preds, ImageSegmentationResult)
