Skip to content

Commit

Permalink
Add caption, rename output type
Browse files Browse the repository at this point in the history
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
  • Loading branch information
alex-jw-brooks committed Jun 26, 2024
1 parent 5ea003b commit a329759
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
4 changes: 2 additions & 2 deletions caikit_computer_vision/data_model/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .image_classification import ImageClassificationResult
from .image_segmentation import ImageSegmentationResult
from .object_detection import ObjectDetectionResult
-from .text_to_image import TextToImageResult
+from .text_to_image import CaptionedImage


# TODO - add support for image DM primitives
Expand Down Expand Up @@ -66,7 +66,7 @@ class ImageSegmentationTask(TaskBase):

@task(
required_parameters={"inputs": str},
-output_type=TextToImageResult,
+output_type=CaptionedImage,
)
class TextToImageTask(TaskBase):
"""The text to image task is responsible for taking an input text prompt, along with
Expand Down
9 changes: 3 additions & 6 deletions caikit_computer_vision/data_model/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@
"""Data structures for text to image."""


-# Standard
-from typing import List
-
# Third Party
from py_to_proto.dataclass_to_proto import Annotated, FieldNumber

Expand All @@ -30,7 +27,7 @@


@dataobject(package="caikit_data_model.caikit_computer_vision")
-class TextToImageResult(DataObjectBase):
-# TODO: Align on the output format
+class CaptionedImage(DataObjectBase):
output: Annotated[caikit_dm.Image, FieldNumber(1)]
-producer_id: Annotated[ProducerId, FieldNumber(2)]
+caption: Annotated[str, FieldNumber(2)]
+producer_id: Annotated[ProducerId, FieldNumber(3)]
7 changes: 4 additions & 3 deletions caikit_computer_vision/modules/text_to_image/tti_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import alog

# Local
-from ...data_model import TextToImageResult
+from ...data_model import CaptionedImage
from ...data_model.tasks import TextToImageTask

log = alog.use_channel("TTI_STUB")
Expand Down Expand Up @@ -64,13 +64,14 @@ def save(self, model_path: str):
with saver:
saver.update_config({"model_name": self.model_name})

-def run(self, inputs: str, height: int, width: int) -> TextToImageResult:
+def run(self, inputs: str, height: int, width: int) -> CaptionedImage:
"""Generates an image matching the provided height and width."""
log.debug("STUB - running text to image inference")
r_channel = np.full((height, width), 0, dtype=np.uint8)
g_channel = np.full((height, width), 100, dtype=np.uint8)
b_channel = np.full((height, width), 200, dtype=np.uint8)
img = np.stack((r_channel, g_channel, b_channel), axis=2)
-return TextToImageResult(
+return CaptionedImage(
output=caikit_dm.Image(img),
+caption=inputs,
)
2 changes: 2 additions & 0 deletions tests/modules/text_to_image/test_tti_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def test_tti_stub():

# Make sure we can run a fake inference on it
pred = model.run("This is a prompt", height=500, width=550)
+assert isinstance(pred, caikit_computer_vision.data_model.CaptionedImage)
+assert pred.caption == "This is a prompt"
pil_img = pred.output.as_pil()
assert pil_img.width == 550
assert pil_img.height == 500
Expand Down

0 comments on commit a329759

Please sign in to comment.