Commit deb93c6: Add tti docs and demo

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
alex-jw-brooks committed Jun 28, 2024
1 parent 96aa2e7 commit deb93c6
Showing 5 changed files with 243 additions and 0 deletions.
133 changes: 133 additions & 0 deletions examples/text_to_image/README.md
# Text To Image (SDXL)
This directory provides guidance for running text to image inference, along with a few useful scripts for getting started.

## Task and Module Overview
The text to image task has only one required parameter, the input text, and produces a `caikit_computer_vision.data_model.CaptionedImage` in response, which wraps the provided input text, as well as the generated image.

Currently there are two modules for text to image:
- `caikit_computer_vision.modules.text_to_image.TTIStub` - A simple stub which produces a blue image of the request height and width at inference. This module is purely used for testing purposes.

- `caikit_computer_vision.modules.text_to_image.SDXL` - A module implementing text to image via SDXL.

This document will help you get started with both at the library & runtime level, ending with a sample gRPC client that can be used to hit models running in a Caikit runtime container.

## Building the Environment
The easiest way to get started is to build a virtual environment in the root directory of this repo. Make sure the root of this project is on the `PYTHONPATH` so that `caikit_computer_vision` is findable.

To install the project:
```bash
python3 -m venv venv
source venv/bin/activate
pip install .
```
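If you prefer to run against the source tree directly rather than installing, putting the repo root on `PYTHONPATH` (as mentioned above) is enough for imports; a minimal sketch, assuming you run from the repo root and have the dependencies installed separately:

```shell
# Make caikit_computer_vision importable in place without `pip install .`.
# Run this from the repo root; dependencies still need to be installed.
export PYTHONPATH="$PWD:$PYTHONPATH"
```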

Note that if you prefer running in Docker, you can build an image as you normally would, and mount things into a running container:
```bash
docker build -t caikit-computer-vision:latest .
```

## Creating the Models
For the remainder of this demo, commands are intended to be run from this directory. First, we will be creating our models & runtime config in a directory named `caikit`, which is convenient for running locally or mounting into a container.

Copy the runtime config from the root of this project into the `caikit` directory.
```bash
mkdir -p caikit/models
cp ../../runtime_config.yaml caikit/runtime_config.yaml
```

Next, create your models.
```bash
python create_tti_models.py
```

This will create two models:
1. The stub model, at `caikit/models/stub_model`
2. The SDXL turbo model, at `caikit/models/sdxl_turbo_model`

Note that the names of these directories will be their model IDs in caikit runtime.
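Putting the setup steps together, the `caikit` directory should now look roughly like this (the exact contents of each model directory depend on the module's save format):

```
caikit/
├── runtime_config.yaml
└── models/
    ├── stub_model/
    └── sdxl_turbo_model/
```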

## Running Local Inference / API Overview
The text to image API is simple.

### Stub Module
For the stub module, we take an input prompt, a height, and a width, and create a blue image of the specified height and width.
```python
run(
inputs: str,
height: int,
width: int
) -> CaptionedImage:
```

Example using the stub model created from above:
```python
>>> import caikit_computer_vision, caikit
>>> stub_model = caikit.load("caikit/models/stub_model")
>>> res = stub_model.run("This is a text", height=512, width=512)
```

The resulting object holds the provided input text under `.caption`:
```python
>>> res.caption
'This is a text'
```
The image bytes are stored as PNG under `.output.image_data`:
```python
>>> res.output.image_data
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02\x00\x00\x00\x02\x00 ...
```
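Those leading bytes are the standard 8-byte PNG file signature; a quick stdlib-only check (independent of caikit) can confirm the serialization format:

```python
# The 8-byte signature that begins every PNG file, per the PNG spec.
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"

def looks_like_png(data: bytes) -> bool:
    """Return True if `data` starts with the PNG file signature."""
    return data.startswith(PNG_SIGNATURE)

# The byte prefix shown above passes the check:
print(looks_like_png(b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"))  # True
```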
Note that the `output` object is a `Caikit` image backed by PIL. If you need a handle to it, you can call `as_pil()` to get a handle to the PIL object as shown below.
```python
>>> pil_im = res.output.as_pil()
>>> type(pil_im)
<class 'PIL.Image.Image'>
```

Grabbing a handle to the PIL image and then `.save()` on the result is the easiest way to save the image to disk.

### SDXL Module
The SDXL module's signature is similar to the stub's, with some additional options.
```python
run(
inputs: str,
height: int = 512,
width: int = 512,
num_steps: int = 1,
guidance_scale: float = 0.0,
negative_prompt: Optional[str] = None,
image_format: str = "png",
) -> CaptionedImage:
```

A full description of these args can be seen with `help(caikit_computer_vision.modules.text_to_image.SDXL.run)`. Notably, for SDXL turbo, guidance scale and negative prompt should be left as the defaults, as they were not used to train this model.

The `image_format` arg follows the same conventions as PIL and controls the format of the serialized bytes. An example for this module similar to the previous one is shown below, where we generate a picture of a puppy in a field, storing the image in jpeg format for serialization purposes.

```python
>>> import caikit_computer_vision, caikit
>>> sdxl_model = caikit.load("caikit/models/sdxl_turbo_model")
>>> res = sdxl_model.run("A golden retriever puppy sitting in a grassy field", height=512, width=512, num_steps=2, image_format="jpeg")
```
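To make the `image_format` conventions concrete, here is a standalone PIL sketch (no caikit involved) showing that the format string controls the serialized byte layout:

```python
import io
from PIL import Image

img = Image.new("RGB", (8, 8), color="blue")

# PIL format names like "png" / "jpeg" select the encoder used to serialize.
png_buf, jpeg_buf = io.BytesIO(), io.BytesIO()
img.save(png_buf, format="PNG")
img.save(jpeg_buf, format="JPEG")

print(png_buf.getvalue()[:8])   # PNG signature: b'\x89PNG\r\n\x1a\n'
print(jpeg_buf.getvalue()[:2])  # JPEG SOI marker: b'\xff\xd8'
```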


## Inference Through Runtime
To write a client, you'll need to export and compile the proto files. To do so, run `python export_protos.py`; this will use the runtime config you copied earlier to create a new directory called `protos`, containing the exported data model / task protos from caikit runtime.

Then to compile them, you can do something like the following; note that you may need to `pip install grpcio-tools` if it's not present in your environment, since it's not a dependency of `caikit_computer_vision`:
```bash
mkdir -p generated
python -m grpc_tools.protoc -I protos --python_out=generated --grpc_python_out=generated protos/*.proto
```

In general, you will want to run Caikit Runtime in a Docker container. The easiest way to do this is to mount the `caikit` directory with your models into the container as shown below.
```bash
docker run -e CONFIG_FILES=/caikit/runtime_config.yaml \
-v $PWD/caikit/:/caikit \
-p 8080:8080 -p 8085:8085 \
caikit-computer-vision:latest python -m caikit.runtime
```

Then, you can hit it with a gRPC client using your compiled protobufs. A full example of inference via gRPC client calling both models can be found in `sample_client.py`.

Running `python sample_client.py` should produce two images.
- `stub_response_image.png` - blue image generated from the stub module
- `turbo_response_image.png` - picture of a golden retriever in a field generated by SDXL turbo
44 changes: 44 additions & 0 deletions examples/text_to_image/create_tti_models.py
"""Creates and exports the stub and SDXL Turbo models as caikit modules."""
# Standard
import os

# First Party
from caikit_computer_vision.modules.text_to_image import TTIStub

SCRIPT_DIR = os.path.dirname(__file__)
MODELS_DIR = os.path.join(SCRIPT_DIR, "caikit", "models")
STUB_MODEL_PATH = os.path.join(MODELS_DIR, "stub_model")
SDXL_TURBO_MODEL_PATH = os.path.join(MODELS_DIR, "sdxl_turbo_model")

os.makedirs(MODELS_DIR, exist_ok=True)

model = TTIStub.bootstrap("foobar")
model.save(STUB_MODEL_PATH)


### Make the model for SDXL turbo
import diffusers
from caikit_computer_vision.modules.text_to_image import SDXL

### Download the model for SDXL turbo...
sdxl_model = SDXL.bootstrap("stabilityai/sdxl-turbo")
sdxl_model.save(SDXL_TURBO_MODEL_PATH)
# There appears to be a bug in the way that sharded safetensors are reloaded into the
# pipeline from diffusers, and there ALSO appears to be a bug where passing the max
# safetensor shard size to diffusers on a pipeline doesn't work as expected.
#
# It is unfortunate that we need this workaround, but we delete
# the sharded u-net and reexport it as one file. By default the
# max shard size is 10GB, and the turbo unet is barely larger than that.
from shutil import rmtree
unet_path = os.path.join(SDXL_TURBO_MODEL_PATH, "sdxl_model", "unet")
try:
diffusers.UNet2DConditionModel.from_pretrained(unet_path)
except RuntimeError:
print("Unable to reload turbo u-net due to sharding issues; reexporting as single file")
rmtree(unet_path)
sdxl_model.pipeline.unet.save_pretrained(unet_path, max_shard_size="12GB")

# Make sure the model can be loaded and that we can get an image out of it
reloaded_model = SDXL.load(SDXL_TURBO_MODEL_PATH)
cap_im = reloaded_model.run("A golden retriever sitting in a grassy field")
print("[DONE]")
30 changes: 30 additions & 0 deletions examples/text_to_image/export_protos.py
# Standard
from inspect import signature
from shutil import rmtree
import os

# First Party
from caikit.runtime.dump_services import dump_grpc_services
import caikit

SCRIPT_DIR = os.path.dirname(__file__)
PROTO_EXPORT_DIR = os.path.join(SCRIPT_DIR, "protos")
RUNTIME_CONFIG_PATH = os.path.join(SCRIPT_DIR, "caikit", "runtime_config.yaml")

if os.path.isdir(PROTO_EXPORT_DIR):
rmtree(PROTO_EXPORT_DIR)
# Configure caikit runtime
caikit.config.configure(config_yml_path=RUNTIME_CONFIG_PATH)

# Export gRPC services...
grpc_service_dumper_kwargs = {
"output_dir": PROTO_EXPORT_DIR,
"write_modules_file": True,
}
# Only keep things in the signature, e.g., old versions don't take write_modules_file
expected_grpc_params = signature(dump_grpc_services).parameters
grpc_service_dumper_kwargs = {
k: v for k, v in grpc_service_dumper_kwargs.items() if k in expected_grpc_params
}
dump_grpc_services(**grpc_service_dumper_kwargs)
# NOTE: If you need an http client for inference, use `dump_http_services` from caikit instead.
Empty file.
36 changes: 36 additions & 0 deletions examples/text_to_image/sample_client.py
import io

from generated import (
computervisionservice_pb2_grpc,
)
from generated.ccv import texttoimagetaskrequest_pb2

from PIL import Image
import grpc


# Setup the client
port = 8085
channel = grpc.insecure_channel(f"localhost:{port}")

inference_stub = computervisionservice_pb2_grpc.ComputerVisionServiceStub(
channel=channel
)

inference_request = texttoimagetaskrequest_pb2.TextToImageTaskRequest(
inputs="A golden retriever sitting in a grassy field",
height=512,
width=512,
)

# Call to stub model...
response = inference_stub.TextToImageTaskPredict(
inference_request, metadata=[("mm-model-id", "stub_model")], timeout=60
)
Image.open(io.BytesIO(response.output.image_data)).save("stub_response_image.png")

# Call to SDXL turbo model...
response = inference_stub.TextToImageTaskPredict(
inference_request, metadata=[("mm-model-id", "sdxl_turbo_model")], timeout=60
)
Image.open(io.BytesIO(response.output.image_data)).save("turbo_response_image.png")
