diff --git a/caikit_computer_vision/data_model/tasks.py b/caikit_computer_vision/data_model/tasks.py index b6fdc17..91bfe3a 100644 --- a/caikit_computer_vision/data_model/tasks.py +++ b/caikit_computer_vision/data_model/tasks.py @@ -26,7 +26,7 @@ from .image_classification import ImageClassificationResult from .image_segmentation import ImageSegmentationResult from .object_detection import ObjectDetectionResult -from .text_to_image import TextToImageResult +from .text_to_image import CaptionedImage # TODO - add support for image DM primitives @@ -66,7 +66,7 @@ class ImageSegmentationTask(TaskBase): @task( required_parameters={"inputs": str}, - output_type=TextToImageResult, + output_type=CaptionedImage, ) class TextToImageTask(TaskBase): """The text to image task is responsible for taking an input text prompt, along with diff --git a/caikit_computer_vision/data_model/text_to_image.py b/caikit_computer_vision/data_model/text_to_image.py index dcd1c14..544b586 100644 --- a/caikit_computer_vision/data_model/text_to_image.py +++ b/caikit_computer_vision/data_model/text_to_image.py @@ -14,9 +14,6 @@ """Data structures for text to image.""" -# Standard -from typing import List - # Third Party from py_to_proto.dataclass_to_proto import Annotated, FieldNumber @@ -30,7 +27,7 @@ @dataobject(package="caikit_data_model.caikit_computer_vision") -class TextToImageResult(DataObjectBase): - # TODO: Align on the output format +class CaptionedImage(DataObjectBase): output: Annotated[caikit_dm.Image, FieldNumber(1)] - producer_id: Annotated[ProducerId, FieldNumber(2)] + caption: Annotated[str, FieldNumber(2)] + producer_id: Annotated[ProducerId, FieldNumber(3)] diff --git a/caikit_computer_vision/modules/text_to_image/tti_stub.py b/caikit_computer_vision/modules/text_to_image/tti_stub.py index 098f278..cab805b 100644 --- a/caikit_computer_vision/modules/text_to_image/tti_stub.py +++ b/caikit_computer_vision/modules/text_to_image/tti_stub.py @@ -26,7 +26,7 @@ import alog # Local -from ...data_model import TextToImageResult +from ...data_model import CaptionedImage from ...data_model.tasks import TextToImageTask log = alog.use_channel("TTI_STUB") @@ -64,13 +64,14 @@ def save(self, model_path: str): with saver: saver.update_config({"model_name": self.model_name}) - def run(self, inputs: str, height: int, width: int) -> TextToImageResult: + def run(self, inputs: str, height: int, width: int) -> CaptionedImage: """Generates an image matching the provided height and width.""" log.debug("STUB - running text to image inference") r_channel = np.full((height, width), 0, dtype=np.uint8) g_channel = np.full((height, width), 100, dtype=np.uint8) b_channel = np.full((height, width), 200, dtype=np.uint8) img = np.stack((r_channel, g_channel, b_channel), axis=2) - return TextToImageResult( + return CaptionedImage( output=caikit_dm.Image(img), + caption=inputs, ) diff --git a/tests/modules/text_to_image/test_tti_stub.py b/tests/modules/text_to_image/test_tti_stub.py index 0d62e64..4e4f7f5 100644 --- a/tests/modules/text_to_image/test_tti_stub.py +++ b/tests/modules/text_to_image/test_tti_stub.py @@ -29,6 +29,8 @@ def test_tti_stub(): # Make sure we can run a fake inference on it pred = model.run("This is a prompt", height=500, width=550) + assert isinstance(pred, caikit_computer_vision.data_model.CaptionedImage) + assert pred.caption == "This is a prompt" pil_img = pred.output.as_pil() assert pil_img.width == 550 assert pil_img.height == 500