Update demo #55

Merged
merged 7 commits into from
Jan 17, 2024
3 changes: 3 additions & 0 deletions .gitignore
@@ -172,6 +172,9 @@ demo/train_output/
# Demo runtime files and folders that may be autocreated
examples/runtime/generated/*_pb2.py
examples/runtime/generated/*_pb2_grpc.py
examples/runtime/generated/ccv
examples/runtime/generated/caikit_data_model
examples/runtime/modules.json
examples/runtime/models
examples/runtime/protos
examples/runtime/train_data
23 changes: 18 additions & 5 deletions examples/runtime/README.md
@@ -60,7 +60,11 @@ class ObjectDetectionResult(DataObjectBase):
To export the `.proto` files, run `python3 dump_protos.py`. This will delete your existing `protos` directory and recreate it. Inspecting the contents of the `protos` folder, you should see the following:

- A file named `openapi.json`
- Lots of `.proto` files, including an `objectdetectionresult.proto`, which contains the definition for the message type corresponding to the `ObjectDetectionResult` data model class, shown below
- Lots of `.proto` files, including one of the following:
- More recent versions of Caikit: `caikit_data_model.caikit_computer_vision.objectdetectionresult.proto`
- Older versions of Caikit: `objectdetectionresult.proto`

Regardless of its name, this file contains the message definition corresponding to the `ObjectDetectionResult` data model class, shown below:

```protobuf
...
@@ -99,7 +103,10 @@ The declaration of `.run()` for the transformer-based module is shown below:
) -> ObjectDetectionResult:
```

Where `image_pil_backend.PIL_SOURCE_TYPES` is a union of types that can be resolved into a PIL image, one of which is `bytes`. As such, the `.run()` declaration is compatible with the task declaration; because of the way the task is defined, the runtime expects `inputs` to be of type `bytes`, with `threshold` as an optional float parameter. This is exactly what is defined in the `objectdetectiontaskrequest.proto` file, shown below.
Where `image_pil_backend.PIL_SOURCE_TYPES` is a union of types that can be resolved into a PIL image, one of which is `bytes`. As such, the `.run()` declaration is compatible with the task declaration; because of the way the task is defined, the runtime expects `inputs` to be of type `bytes`, with `threshold` as an optional float parameter. We can find the message type defining exactly this in the task request definition proto file, which is one of the following:

- More recent versions of Caikit: `ccv.objectdetectiontaskrequest.proto`; here, `ccv` is the `service_generation` package name defined in our runtime config
- Older versions of Caikit: `objectdetectiontaskrequest.proto`

```protobuf
...
@@ -137,7 +144,7 @@ Notice that here, `threshold` was changed to a `str` which causes a type collisi
Cannot generate task rpc for <class 'caikit_computer_vision.data_model.tasks.ObjectDetectionTask'>: Conflicting value types for arg threshold: <class 'str'> != <class 'float'>
```

and you won't have a `objectdetectiontaskrequest.proto` in the dumped protos.
and you won't have a `ccv.objectdetectiontaskrequest.proto`/`objectdetectiontaskrequest.proto` in the dumped protos.
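A quick way to confirm whether the task request proto was generated is to look for either file name in the dumped directory. The sketch below assumes the `protos` export directory and the `ccv` package name used in this example; `find_task_request_proto` is a hypothetical helper and not part of Caikit.

```python
from pathlib import Path
from typing import Optional

# Candidate file names from the text above: newer Caikit versions prefix the
# service generation package (`ccv` in this example), older versions do not.
CANDIDATE_NAMES = (
    "ccv.objectdetectiontaskrequest.proto",
    "objectdetectiontaskrequest.proto",
)


def find_task_request_proto(proto_dir: str = "protos") -> Optional[Path]:
    """Return the first candidate request proto found in proto_dir, or None."""
    for name in CANDIDATE_NAMES:
        candidate = Path(proto_dir) / name
        if candidate.is_file():
            return candidate
    return None
```

If this returns `None` after running `dump_protos.py`, the task RPC was most likely not generated, for example because of a type collision like the one above.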

#### Train Messages

@@ -157,13 +164,19 @@ Currently, the stub example for `.train` on the object detector class is defined
```

Where `ObjectDetectionTrainSet` is a data model object. The train message file name is built from the following template, in all lowercase letters:
`{{task_name}}task{{impl_class_name}}trainrequest.proto`
`{{service_gen_name}}.{{task_name}}task{{impl_class_name}}trainrequest.proto`
where:

- `{{service_gen_name}}` is the name of your service generation key from your runtime config, e.g., `ccv`. As before, this is only used for more recent versions of Caikit; in older versions, the leading `{{service_gen_name}}.` is omitted.
- `{{task_name}}` is the name of the task, in this case `objectdetection`
- `{{impl_class_name}}` is the name of the implementing module class being considered, in this case `transformersobjectdetector`

As a result we get a file named: `objectdetectiontasktransformersobjectdetectortrainrequest.proto`, which contains the message definition to trigger a train request.
As a result we get a file named one of the following:

- More recent versions of Caikit: `ccv.objectdetectiontasktransformersobjectdetectortrainrequest.proto`
- Older versions of Caikit: `objectdetectiontasktransformersobjectdetectortrainrequest.proto`

which contains the message definition to trigger a train request.
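
The naming rule can be sketched as a small function. This is only an illustration of the template described here; `train_request_proto_name` is a hypothetical helper, not something Caikit exposes, and passing `service_gen_name=None` stands in for the older, unprefixed naming.

```python
from typing import Optional


def train_request_proto_name(
    task_name: str,
    impl_class_name: str,
    service_gen_name: Optional[str] = None,
) -> str:
    """Build the train request proto file name, in all lowercase letters."""
    base = f"{task_name}task{impl_class_name}trainrequest.proto".lower()
    if service_gen_name:
        # Newer Caikit versions prefix the service generation package name
        return f"{service_gen_name.lower()}.{base}"
    # Older Caikit versions use a flat, unprefixed name
    return base


newer_name = train_request_proto_name(
    "ObjectDetection", "TransformersObjectDetector", service_gen_name="ccv"
)
older_name = train_request_proto_name("ObjectDetection", "TransformersObjectDetector")
```

Here `newer_name` and `older_name` match the two file names listed above.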

```protobuf
message ObjectDetectionTaskTransformersObjectDetectorTrainRequest {
16 changes: 14 additions & 2 deletions examples/runtime/dump_protos.py
@@ -16,6 +16,7 @@
"""

# Standard
from inspect import signature
from pathlib import Path
from shutil import rmtree
import json
@@ -40,8 +41,19 @@ def export_protos():
        rmtree(PROTO_EXPORT_DIR)
    # Configure caikit runtime
    caikit.config.configure(config_dict=RUNTIME_CONFIG)
    # Dump proto files
    dump_grpc_services(output_dir=PROTO_EXPORT_DIR)
    # Export gRPC services first...
    # grpc service dumper kwargs depend on the version of caikit we are using
    grpc_service_dumper_kwargs = {
        "output_dir": PROTO_EXPORT_DIR,
        "write_modules_file": True,
    }
    # Only keep things in the signature, e.g., old versions don't take write_modules_file
    expected_grpc_params = signature(dump_grpc_services).parameters
    grpc_service_dumper_kwargs = {
        k: v for k, v in grpc_service_dumper_kwargs.items() if k in expected_grpc_params
    }
    dump_grpc_services(**grpc_service_dumper_kwargs)
    # Then export HTTP services...
    dump_http_services(output_dir=PROTO_EXPORT_DIR)
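
The signature-based filtering in this hunk can be tried in isolation. The sketch below is a standalone illustration, not Caikit code: `old_dumper` and `new_dumper` are hypothetical stand-ins for an older and a newer `dump_grpc_services` signature, and `filter_supported_kwargs` is a hypothetical helper name.

```python
from inspect import signature


def filter_supported_kwargs(func, kwargs):
    """Keep only the kwargs whose names appear in func's signature."""
    accepted = signature(func).parameters
    return {k: v for k, v in kwargs.items() if k in accepted}


# Hypothetical stand-in for an older dumper that lacks write_modules_file
def old_dumper(output_dir):
    return {"output_dir": output_dir}


# Hypothetical stand-in for a newer dumper that accepts write_modules_file
def new_dumper(output_dir, write_modules_file=False):
    return {"output_dir": output_dir, "write_modules_file": write_modules_file}


wanted = {"output_dir": "protos", "write_modules_file": True}
old_result = old_dumper(**filter_supported_kwargs(old_dumper, wanted))
new_result = new_dumper(**filter_supported_kwargs(new_dumper, wanted))
```

One caveat of this approach: if the target function takes `**kwargs`, a name-based filter like this drops keys the function could still have accepted.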


112 changes: 80 additions & 32 deletions examples/runtime/run_train_and_inference.py
@@ -30,70 +30,118 @@
)

# pylint: disable=no-name-in-module,import-error
from generated import (
    computervisionservice_pb2_grpc,
    computervisiontrainingservice_pb2_grpc,
    objectdetectiontaskrequest_pb2,
)
from generated import (
    objectdetectiontasktransformersobjectdetectortrainrequest_pb2 as odt_request_pb2,
)
from generated import objectdetectiontrainset_pb2
try:
    # Third Party
    from generated import (
        computervisionservice_pb2_grpc,
        computervisiontrainingservice_pb2_grpc,
    )
except ImportError:
    raise ImportError("Failed to import cv service; did you compile your protos?")

# The location of these imported message types depends on the version of Caikit
# that we are using.
try:
    # Third Party
    from generated.caikit_data_model.caikit_computer_vision import (
        objectdetectiontrainset_pb2,
    )
    from generated.ccv import objectdetectiontaskrequest_pb2
    from generated.ccv import (
        objectdetectiontasktransformersobjectdetectortrainparameters_pb2 as odt_params_pb2,
    )
    from generated.ccv import (
        objectdetectiontasktransformersobjectdetectortrainrequest_pb2 as odt_request_pb2,
    )

    IS_LEGACY = False
except ModuleNotFoundError:
    # older versions of Caikit / py to proto create a flat proto structure
    # Third Party
    from generated import objectdetectiontaskrequest_pb2
    from generated import (
        objectdetectiontasktransformersobjectdetectortrainrequest_pb2 as odt_request_pb2,
    )
    from generated import objectdetectiontrainset_pb2

    IS_LEGACY = True

# Third Party
import grpc
import numpy as np

# First Party
from caikit.interfaces.vision import data_model as caikit_dm


### build the training request
# Training params; the only thing that changes between newer/older versions of caikit is that
# newer caikit versions pass all of these in under a parameters key and proto type, while old
# versions just pass them in directly.
def get_train_request():
    train_param_dict = {
        "model_path": os.path.join(MODELS_DIR, DEMO_MODEL_ID),
        "train_data": objectdetectiontrainset_pb2.ObjectDetectionTrainSet(
            img_dir_path=TRAINING_IMG_DIR,
            labels_file=TRAINING_LABELS_FILE,
        ),
        "num_epochs": 10,
        "learning_rate": 0.3,
    }
    if not IS_LEGACY:
        train_param_dict = {
            "parameters": odt_params_pb2.ObjectDetectionTaskTransformersObjectDetectorTrainParameters(
                **train_param_dict
            )
        }
    return odt_request_pb2.ObjectDetectionTaskTransformersObjectDetectorTrainRequest(
        model_name="new_model", **train_param_dict
    )


### Build the inference request
def get_inference_request():
    # For inference, just pick a random training image
    random_img_name = np.random.choice(os.listdir(TRAINING_IMG_DIR))
    with open(os.path.join(TRAINING_IMG_DIR, random_img_name), "rb") as f:
        im_bytes = f.read()

    return objectdetectiontaskrequest_pb2.ObjectDetectionTaskRequest(
        inputs=im_bytes,
        threshold=0,
    )


if __name__ == "__main__":
    model_id = "new_model"

    # Setup the client
    port = 8085
    channel = grpc.insecure_channel(f"localhost:{port}")

    # send train request
    request = odt_request_pb2.ObjectDetectionTaskTransformersObjectDetectorTrainRequest(
        model_name=model_id,
        model_path=os.path.join(MODELS_DIR, DEMO_MODEL_ID),
        train_data=objectdetectiontrainset_pb2.ObjectDetectionTrainSet(
            img_dir_path=TRAINING_IMG_DIR,
            labels_file=TRAINING_LABELS_FILE,
        ),
        num_epochs=10,
        learning_rate=0.3,
    )
    training_stub = (
        computervisiontrainingservice_pb2_grpc.ComputerVisionTrainingServiceStub(
            channel=channel
        )
    )
    response = training_stub.ObjectDetectionTaskTransformersObjectDetectorTrain(request)
    response = training_stub.ObjectDetectionTaskTransformersObjectDetectorTrain(
        get_train_request()
    )
    print("*" * 30)
    print("RESPONSE from TRAIN gRPC\n")
    print(response)
    print("*" * 30)

    sleep(5)

    # Then make sure we can hit the new model with an inference request...
    # TODO: transformers does not seem happy if this is a png
    random_img_name = np.random.choice(os.listdir(TRAINING_IMG_DIR))

    with open(os.path.join(TRAINING_IMG_DIR, random_img_name), "rb") as f:
        im_bytes = f.read()
    request = objectdetectiontaskrequest_pb2.ObjectDetectionTaskRequest(
        inputs=im_bytes,
        threshold=0.0,
    )
    inference_stub = computervisionservice_pb2_grpc.ComputerVisionServiceStub(
        channel=channel
    )
    # NOTE: This just hits the old model, since normally the loading would be handled by something
    # like kserve/model mesh. But it might be more helpful to show how to manually load the model
    # and hit it here, just for reference.
    response = inference_stub.ObjectDetectionTaskPredict(
        request, metadata=[("mm-model-id", DEMO_MODEL_ID)], timeout=1
        get_inference_request(), metadata=[("mm-model-id", DEMO_MODEL_ID)], timeout=1
    )
    print("*" * 30)
    print("RESPONSE from INFERENCE gRPC\n")
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -14,8 +14,7 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http,interfaces-vision]==0.16.0",
"py-to-proto==0.4.1",
"caikit[runtime-grpc,runtime-http,interfaces-vision]>=0.16.0,<0.27.0",
"transformers>=4.27.1,<5",
"torch>2.0,<3",
"timm>=0.9.5,<1"
4 changes: 4 additions & 0 deletions runtime_config.yaml
@@ -14,6 +14,10 @@ runtime:
  training:
    save_with_id: False
    output_dir: models
  # This should be set to something that is NOT in your site packages, otherwise it'll cause
  # conflicts leading to import issues. For now, we set ccv for caikit computer vision.
  service_generation:
    package: ccv

log:
  formatter: pretty # optional: log formatter is set to json by default