Skip to content

Commit

Permalink
Add text to image task and output types
Browse files Browse the repository at this point in the history
Add SDXL stub

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
  • Loading branch information
alex-jw-brooks committed Jun 26, 2024
1 parent 77cfede commit 3d32fbb
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 1 deletion.
1 change: 1 addition & 0 deletions caikit_computer_vision/data_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from .image_classification import *
from .image_segmentation import *
from .object_detection import *
from .text_to_image import *
12 changes: 12 additions & 0 deletions caikit_computer_vision/data_model/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .image_classification import ImageClassificationResult
from .image_segmentation import ImageSegmentationResult
from .object_detection import ObjectDetectionResult
from .text_to_image import TextToImageResult


# TODO - add support for image DM primitives
Expand Down Expand Up @@ -61,3 +62,14 @@ class ImageSegmentationTask(TaskBase):
Note that at the moment, this task encapsulates all segmentation types,
I.e., instance, object, semantic, etc...
"""


@task(
    output_type=TextToImageResult,
    required_parameters={"inputs": str},
)
class TextToImageTask(TaskBase):
    """Task for generating an image from a text prompt.

    The only required parameter is the prompt itself, ``inputs`` (a string).
    Implementations may additionally take optional image generation
    parameters, e.g., the height and width of the image. The task output
    type is ``TextToImageResult``.
    """
36 changes: 36 additions & 0 deletions caikit_computer_vision/data_model/text_to_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data structures for text to image generation."""


# Standard
from typing import List

# Third Party
from py_to_proto.dataclass_to_proto import Annotated, FieldNumber

# First Party
from caikit.core import DataObjectBase, dataobject
from caikit.interfaces.common.data_model import ProducerId
from caikit.interfaces.vision import data_model as caikit_dm
import alog

log = alog.use_channel("DATAM")


@dataobject(
    package="caikit_data_model.caikit_computer_vision",
)
class TextToImageResult(DataObjectBase):
    """Output data object of the text to image task."""

    # TODO: Align on the output format
    # The image carried by this result (field number 1).
    output: Annotated[caikit_dm.Image, FieldNumber(1)]
    # ProducerId attached to this result (field number 2).
    producer_id: Annotated[ProducerId, FieldNumber(2)]
2 changes: 1 addition & 1 deletion caikit_computer_vision/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Local
from . import object_detection, segmentation
from . import object_detection, segmentation, text_to_image
16 changes: 16 additions & 0 deletions caikit_computer_vision/modules/text_to_image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .sdxl import SDXLStub
76 changes: 76 additions & 0 deletions caikit_computer_vision/modules/text_to_image/sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for text to image via SDXL.
"""
# Standard
from typing import Union, get_args
import os

# Third Party
import numpy as np

# First Party
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
from caikit.interfaces.vision import data_model as caikit_dm
import alog

# Local
from ...data_model import TextToImageResult
from ...data_model.tasks import TextToImageTask

log = alog.use_channel("SDXL")


@module(
    id="28aa938b-1a33-11a0-11a3-bb9c3b1cbb11",
    name="Stub module for Text to Image",
    version="0.1.0",
    task=TextToImageTask,
)
class SDXLStub(ModuleBase):
    """Stub text to image module that returns a solid-color placeholder image.

    No model is loaded or executed by this class; ``model_name`` is only
    stored on the instance so that ``save`` can write it to the module config
    and ``load`` can read it back.
    """

    def __init__(
        self,
        model_name,
    ) -> None:
        """Store the model name on the instance.

        Args:
            model_name: Value kept as ``self.model_name`` and written to the
                module config by ``save``.
        """
        log.debug("STUB - initializing text to image instance")
        super().__init__()
        self.model_name = model_name

    @classmethod
    def load(cls, model_path: Union[str, "ModuleConfig"]) -> "SDXLStub":
        """Build an instance from a saved module config.

        Args:
            model_path: Passed as-is to ``ModuleConfig.load``; the resulting
                config must expose a ``model_name`` attribute.

        Returns:
            A new instance created through ``bootstrap``.
        """
        config = ModuleConfig.load(model_path)
        return cls.bootstrap(config.model_name)

    @classmethod
    def bootstrap(cls, model_name: str) -> "SDXLStub":
        """Build an instance directly from a model name.

        Args:
            model_name: Forwarded to the constructor.

        Returns:
            A new instance holding ``model_name``.
        """
        return cls(model_name)

    def save(self, model_path: str):
        """Write the module config, including ``model_name``, to ``model_path``.

        Args:
            model_path: Location handed to ``ModuleSaver``.
        """
        saver = ModuleSaver(
            self,
            model_path=model_path,
        )
        with saver:
            saver.update_config({"model_name": self.model_name})

    def run(self, inputs: str, height: int, width: int) -> TextToImageResult:
        """Generates an image matching the provided height and width.

        This stub does not read ``inputs``; every pixel of the returned image
        has the same RGB value (0, 100, 200).

        Args:
            inputs: Text prompt (unused by the stub).
            height: Number of rows of the generated image.
            width: Number of columns of the generated image.

        Returns:
            A ``TextToImageResult`` whose ``output`` wraps a uint8 array of
            shape ``(height, width, 3)``. No ``producer_id`` is passed.
        """
        log.debug("STUB - running text to image inference")
        r_channel = np.full((height, width), 0, dtype=np.uint8)
        g_channel = np.full((height, width), 100, dtype=np.uint8)
        b_channel = np.full((height, width), 200, dtype=np.uint8)
        # Stack along a new last axis so the channels end up as (H, W, 3).
        img = np.stack((r_channel, g_channel, b_channel), axis=2)
        return TextToImageResult(
            output=caikit_dm.Image(img),
        )

0 comments on commit 3d32fbb

Please sign in to comment.