diff --git a/caikit_computer_vision/data_model/__init__.py b/caikit_computer_vision/data_model/__init__.py
index fc1f138..05d88eb 100644
--- a/caikit_computer_vision/data_model/__init__.py
+++ b/caikit_computer_vision/data_model/__init__.py
@@ -18,3 +18,4 @@
 from .image_classification import *
 from .image_segmentation import *
 from .object_detection import *
+from .text_to_image import *
diff --git a/caikit_computer_vision/data_model/tasks.py b/caikit_computer_vision/data_model/tasks.py
index bcc2a29..b6fdc17 100644
--- a/caikit_computer_vision/data_model/tasks.py
+++ b/caikit_computer_vision/data_model/tasks.py
@@ -26,6 +26,7 @@
 from .image_classification import ImageClassificationResult
 from .image_segmentation import ImageSegmentationResult
 from .object_detection import ObjectDetectionResult
+from .text_to_image import TextToImageResult
 
 # TODO - add support for image DM primitives
 
@@ -61,3 +62,16 @@ class ImageSegmentationTask(TaskBase):
     Note that at the moment, this task encapsulates all segmentation types,
     I.e., instance, object, semantic, etc...
     """
+
+
+@task(
+    required_parameters={"inputs": str},
+    output_type=TextToImageResult,
+)
+class TextToImageTask(TaskBase):
+    """Generate an image from an input text prompt.
+
+    The prompt (`inputs`) is the only required parameter; modules implementing
+    this task may take other image generation parameters, e.g., image height
+    and width. The task output is a TextToImageResult.
+    """
diff --git a/caikit_computer_vision/data_model/text_to_image.py b/caikit_computer_vision/data_model/text_to_image.py
new file mode 100644
index 0000000..3e65bac
--- /dev/null
+++ b/caikit_computer_vision/data_model/text_to_image.py
@@ -0,0 +1,43 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Data structures for text to image generation."""
+
+
+# Standard
+from typing import List
+
+# Third Party
+from py_to_proto.dataclass_to_proto import Annotated, FieldNumber
+
+# First Party
+from caikit.core import DataObjectBase, dataobject
+from caikit.interfaces.common.data_model import ProducerId
+from caikit.interfaces.vision import data_model as caikit_dm
+import alog
+
+log = alog.use_channel("DATAM")
+
+
+@dataobject(package="caikit_data_model.caikit_computer_vision")
+class TextToImageResult(DataObjectBase):
+    """Output of the text to image task.
+
+    Attributes:
+        output: The generated image, as a caikit vision `Image`.
+        producer_id: `ProducerId` attached to the result.
+    """
+
+    # TODO: Align on the output format
+    output: Annotated[caikit_dm.Image, FieldNumber(1)]
+    producer_id: Annotated[ProducerId, FieldNumber(2)]
diff --git a/caikit_computer_vision/modules/__init__.py b/caikit_computer_vision/modules/__init__.py
index 2d3a94c..b4dbe7a 100644
--- a/caikit_computer_vision/modules/__init__.py
+++ b/caikit_computer_vision/modules/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Local
-from . import object_detection, segmentation
+from . import object_detection, segmentation, text_to_image
diff --git a/caikit_computer_vision/modules/text_to_image/__init__.py b/caikit_computer_vision/modules/text_to_image/__init__.py
new file mode 100644
index 0000000..1c5180f
--- /dev/null
+++ b/caikit_computer_vision/modules/text_to_image/__init__.py
@@ -0,0 +1,16 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Local
+from .sdxl import SDXLStub
diff --git a/caikit_computer_vision/modules/text_to_image/sdxl.py b/caikit_computer_vision/modules/text_to_image/sdxl.py
new file mode 100644
index 0000000..b746eb8
--- /dev/null
+++ b/caikit_computer_vision/modules/text_to_image/sdxl.py
@@ -0,0 +1,108 @@
+# Copyright The Caikit Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module for text to image via SDXL.
+"""
+# Standard
+from typing import Union, get_args
+import os
+
+# Third Party
+import numpy as np
+
+# First Party
+from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
+from caikit.interfaces.vision import data_model as caikit_dm
+import alog
+
+# Local
+from ...data_model import TextToImageResult
+from ...data_model.tasks import TextToImageTask
+
+log = alog.use_channel("SDXL")
+
+
+@module(
+    id="28aa938b-1a33-11a0-11a3-bb9c3b1cbb11",
+    name="Stub module for Text to Image",
+    version="0.1.0",
+    task=TextToImageTask,
+)
+class SDXLStub(ModuleBase):
+    """Stub text to image module that returns a solid-color image.
+
+    No model is loaded or executed; the module only stores a model name and
+    produces a placeholder image of the requested size.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+    ) -> None:
+        """Initialize the stub.
+
+        Args:
+            model_name: Name stored on the instance and written to the module
+                config by `save`.
+        """
+        log.debug("STUB - initializing text to image instance")
+        super().__init__()
+        self.model_name = model_name
+
+    @classmethod
+    def load(cls, model_path: Union[str, "ModuleConfig"]) -> "SDXLStub":
+        """Load a stub instance from a saved module config.
+
+        Args:
+            model_path: Path (or module config) passed to `ModuleConfig.load`.
+
+        Returns:
+            An SDXLStub built from the config's `model_name`.
+        """
+        config = ModuleConfig.load(model_path)
+        return cls.bootstrap(config.model_name)
+
+    @classmethod
+    def bootstrap(cls, model_name: str) -> "SDXLStub":
+        """Create a stub instance directly from a model name."""
+        return cls(model_name)
+
+    def save(self, model_path: str) -> None:
+        """Save this module to `model_path`, recording `model_name` in its config."""
+        saver = ModuleSaver(
+            self,
+            model_path=model_path,
+        )
+        with saver:
+            saver.update_config({"model_name": self.model_name})
+
+    def run(self, inputs: str, height: int, width: int) -> TextToImageResult:
+        """Generate a placeholder image with the provided height and width.
+
+        Args:
+            inputs: Text prompt; unused by this stub.
+            height: Height of the generated image in pixels.
+            width: Width of the generated image in pixels.
+
+        Returns:
+            TextToImageResult whose `output` wraps a (height, width, 3) uint8
+            array filled with the RGB color (0, 100, 200).
+        """
+        log.debug("STUB - running text to image inference")
+        # Broadcasting the color over the last axis fills every pixel at once.
+        img = np.full((height, width, 3), (0, 100, 200), dtype=np.uint8)
+        # NOTE(review): producer_id is left unset on the result - confirm
+        # whether it should be populated here.
+        return TextToImageResult(
+            output=caikit_dm.Image(img),
+        )