generated from caikit/caikit-template
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add text to image task and output types
Add SDXL stub Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
- Loading branch information
1 parent
77cfede
commit 3d32fbb
Showing
6 changed files
with
142 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Copyright The Caikit Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Data structures for text to image results.""" | ||
|
||
|
||
# Standard | ||
from typing import List | ||
|
||
# Third Party | ||
from py_to_proto.dataclass_to_proto import Annotated, FieldNumber | ||
|
||
# First Party | ||
from caikit.core import DataObjectBase, dataobject | ||
from caikit.interfaces.common.data_model import ProducerId | ||
from caikit.interfaces.vision import data_model as caikit_dm | ||
import alog | ||
|
||
# Module-level logger bound to the "DATAM" alog channel. | ||
log = alog.use_channel("DATAM") | ||
|
||
|
||
@dataobject(package="caikit_data_model.caikit_computer_vision")
class TextToImageResult(DataObjectBase):
    """Output data object for the text to image task.

    Holds the image returned by a text to image module together with the
    producer identifier field declared alongside it.
    """

    # TODO: Align on the output format
    # Image returned by the module (proto field 1).
    output: Annotated[caikit_dm.Image, FieldNumber(1)]
    # Presumably identifies the module that produced the image (proto
    # field 2) -- confirm against the modules that populate it.
    producer_id: Annotated[ProducerId, FieldNumber(2)]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright The Caikit Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Local | ||
from .sdxl import SDXLStub |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Copyright The Caikit Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Module for text to image via SDXL. | ||
""" | ||
# Standard | ||
from typing import Union, get_args | ||
import os | ||
|
||
# Third Party | ||
import numpy as np | ||
|
||
# First Party | ||
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module | ||
from caikit.interfaces.vision import data_model as caikit_dm | ||
import alog | ||
|
||
# Local | ||
from ...data_model import TextToImageResult | ||
from ...data_model.tasks import TextToImageTask | ||
|
||
# Module-level logger bound to the "SDXL" alog channel. | ||
log = alog.use_channel("SDXL") | ||
|
||
|
||
@module(
    id="28aa938b-1a33-11a0-11a3-bb9c3b1cbb11",
    name="Stub module for Text to Image",
    version="0.1.0",
    task=TextToImageTask,
)
class SDXLStub(ModuleBase):
    """Stub text to image module.

    No model is loaded or executed. ``run`` fabricates a solid-color RGB
    image of the requested size so that the text to image task and its
    result type can be exercised end to end.
    """

    def __init__(
        self,
        model_name,
    ) -> None:
        """Record the model name; this stub loads nothing.

        Args:
            model_name: Name stored on the instance and written to the
                module config by ``save``.
        """
        log.debug("STUB - initializing text to image instance")
        super().__init__()
        self.model_name = model_name

    @classmethod
    def load(cls, model_path: Union[str, "ModuleConfig"]) -> "SDXLStub":
        """Rebuild an instance from a saved module config.

        Args:
            model_path: Location (or config) handed to ``ModuleConfig.load``.

        Returns:
            A stub created from the ``model_name`` stored in the config.
        """
        config = ModuleConfig.load(model_path)
        return cls.bootstrap(config.model_name)

    @classmethod
    def bootstrap(cls, model_name: str) -> "SDXLStub":
        """Create a stub instance directly from a model name.

        Args:
            model_name: Name to store on the new instance.

        Returns:
            The newly constructed stub.
        """
        return cls(model_name)

    def save(self, model_path: str):
        """Write the module config, including ``model_name``, to ``model_path``.

        Args:
            model_path: Destination passed to ``ModuleSaver``.
        """
        saver = ModuleSaver(
            self,
            model_path=model_path,
        )
        with saver:
            saver.update_config({"model_name": self.model_name})

    def run(self, inputs: str, height: int, width: int) -> TextToImageResult:
        """Generates an image matching the provided height and width.

        Args:
            inputs: Prompt text. The stub ignores it; it is accepted only to
                satisfy the task signature.
            height: Number of pixel rows in the output; must be positive.
            width: Number of pixel columns in the output; must be positive.

        Returns:
            A result wrapping a ``height x width x 3`` uint8 image in which
            every pixel is (0, 100, 200), tagged with this module's
            producer id.

        Raises:
            ValueError: If ``height`` or ``width`` is not positive.
        """
        log.debug("STUB - running text to image inference")
        # Fail early with a clear message instead of building an empty image
        # (zero) or surfacing numpy's generic negative-dimension error.
        if height <= 0 or width <= 0:
            raise ValueError(
                f"height and width must be positive; got {height}x{width}"
            )
        # The per-channel fill values broadcast across the last axis, giving
        # the same array as stacking three constant single-channel planes.
        img = np.full((height, width, 3), (0, 100, 200), dtype=np.uint8)
        return TextToImageResult(
            output=caikit_dm.Image(img),
            # PRODUCER_ID is expected to be set on the class by @module.
            producer_id=self.PRODUCER_ID,
        )