diff --git a/.gitignore b/.gitignore
index c077a88..521e4e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -172,6 +172,9 @@ demo/train_output/
 # Demo runtime files and folders that may be autocreated
 examples/runtime/generated/*_pb2.py
 examples/runtime/generated/*_pb2_grpc.py
+examples/runtime/generated/ccv
+examples/runtime/generated/caikit_data_model
+examples/runtime/modules.json
 examples/runtime/models
 examples/runtime/protos
 examples/runtime/train_data
diff --git a/examples/runtime/README.md b/examples/runtime/README.md
index b574ce7..eb1ae93 100644
--- a/examples/runtime/README.md
+++ b/examples/runtime/README.md
@@ -60,7 +60,11 @@ class ObjectDetectionResult(DataObjectBase):
 To export the `.proto` files, run `python3 dump_protos.py`. This will delete your existing `protos` directory and recreate it. Inspecting the contents of the `protos` folder; you should see the following:
 
 - A file named `openapi.json`
-- Lots of `.proto` files, including an `objectdetectionresult.proto`, which contains the definition for the message type corresponding to the `ObjectDetectionResult` data model class, shown below
+- Lots of `.proto` files, including one of the following:
+  - More recent versions of Caikit: `caikit_data_model.caikit_computer_vision.objectdetectionresult.proto`
+  - Older versions of Caikit: `objectdetectionresult.proto`
+
+Regardless of its name, this file contains the message definition for the `ObjectDetectionResult` data model class, shown below:
 
 ```protobuf
 ...
@@ -99,7 +103,10 @@ The declaration of `.run()` for the transformer-based module is shown below:
 ) -> ObjectDetectionResult:
 ```
 
-Where `image_pil_backend.PIL_SOURCE_TYPES` is a union of types that can be resolved into a PIL image, one of which is `bytes`. As such, the `.run()` declaration is compatible with the task declaration; because of the way the task is defined, the runtime expects `inputs` to be of type `bytes`, with `threshold` as an optional float parameter. This is exactly what is defined in the `objectdetectiontaskrequest.proto` file, shown below.
+Where `image_pil_backend.PIL_SOURCE_TYPES` is a union of types that can be resolved into a PIL image, one of which is `bytes`. As such, the `.run()` declaration is compatible with the task declaration; because of the way the task is defined, the runtime expects `inputs` to be of type `bytes`, with `threshold` as an optional float parameter. We can find the message type defining exactly this in the task request definition proto file, which is one of the following:
+
+- More recent versions of Caikit: `ccv.objectdetectiontaskrequest.proto`; here, `ccv` is the `service_generation` package defined in our runtime config
+- Older versions of Caikit: `objectdetectiontaskrequest.proto`
 
 ```protobuf
 ...
@@ -137,7 +144,7 @@ Notice that here, `threshold` was changed to a `str` which causes a type collisi
 Cannot generate task rpc for : Conflicting value types for arg threshold: != 
 ```
 
-and you won't have a `objectdetectiontaskrequest.proto` in the dumped protos.
+and you won't have a `ccv.objectdetectiontaskrequest.proto`/`objectdetectiontaskrequest.proto` in the dumped protos.
 
 #### Train Messages
 
@@ -157,13 +164,19 @@ Currently, the stub example for `.train` on the object detector class is defined as follows:
 ```
 Where `ObjectDetectionTrainSet` is a data model object.
 The train message file name is built with the following name, in all lowercase letters:
-`{{task_name}}task{{impl_class_name}}trainrequest.proto`
+`{{service_gen_name}}.{{task_name}}task{{impl_class_name}}trainrequest.proto`
 
 where:
+- `{{service_gen_name}}` is the `service_generation` package name from your runtime config, e.g., `ccv`. As before, this is only used in more recent versions of Caikit; in older versions, the leading `{{service_gen_name}}.` is omitted.
 - `{{task_name}}` is the name of the task, in this case `objectdetection`
 - `{{impl_class_name}}` is the name of the implementing module class being considered, in this cases `transformersobjectdetector`
 
-As a result we get a file named: `objectdetectiontasktransformersobjectdetectortrainrequest.proto`, which contains the message definition to trigger a train request.
+As a result we get a file named one of the following:
+
+- More recent versions of Caikit: `ccv.objectdetectiontasktransformersobjectdetectortrainrequest.proto`
+- Older versions of Caikit: `objectdetectiontasktransformersobjectdetectortrainrequest.proto`
+
+which contains the message definition to trigger a train request.
 
 ```protobuf
 message ObjectDetectionTaskTransformersObjectDetectorTrainRequest {
diff --git a/examples/runtime/dump_protos.py b/examples/runtime/dump_protos.py
index 10784c5..d4b66e9 100644
--- a/examples/runtime/dump_protos.py
+++ b/examples/runtime/dump_protos.py
@@ -16,6 +16,7 @@
 """
 
 # Standard
+from inspect import signature
 from pathlib import Path
 from shutil import rmtree
 import json
@@ -40,8 +41,19 @@ def export_protos():
         rmtree(PROTO_EXPORT_DIR)
     # Configure caikit runtime
    caikit.config.configure(config_dict=RUNTIME_CONFIG)
-    # Dump proto files
-    dump_grpc_services(output_dir=PROTO_EXPORT_DIR)
+    # Export gRPC services first...
+    # grpc service dumper kwargs depend on the version of caikit we are using
+    grpc_service_dumper_kwargs = {
+        "output_dir": PROTO_EXPORT_DIR,
+        "write_modules_file": True,
+    }
+    # Only keep things in the signature, e.g., old versions don't take write_modules_file
+    expected_grpc_params = signature(dump_grpc_services).parameters
+    grpc_service_dumper_kwargs = {
+        k: v for k, v in grpc_service_dumper_kwargs.items() if k in expected_grpc_params
+    }
+    dump_grpc_services(**grpc_service_dumper_kwargs)
+    # Then export HTTP services...
     dump_http_services(output_dir=PROTO_EXPORT_DIR)
 
 
diff --git a/examples/runtime/run_train_and_inference.py b/examples/runtime/run_train_and_inference.py
index b5bd52e..817f9b3 100644
--- a/examples/runtime/run_train_and_inference.py
+++ b/examples/runtime/run_train_and_inference.py
@@ -30,45 +30,103 @@
 )
 
 # pylint: disable=no-name-in-module,import-error
-from generated import (
-    computervisionservice_pb2_grpc,
-    computervisiontrainingservice_pb2_grpc,
-    objectdetectiontaskrequest_pb2,
-)
-from generated import (
-    objectdetectiontasktransformersobjectdetectortrainrequest_pb2 as odt_request_pb2,
-)
-from generated import objectdetectiontrainset_pb2
+try:
+    # Third Party
+    from generated import (
+        computervisionservice_pb2_grpc,
+        computervisiontrainingservice_pb2_grpc,
+    )
+except ImportError:
+    raise ImportError("Failed to import cv service; did you compile your protos?")
+
+# The location of these imported message types depends on the version of Caikit
+# that we are using.
+try:
+    # Third Party
+    from generated.caikit_data_model.caikit_computer_vision import (
+        objectdetectiontrainset_pb2,
+    )
+    from generated.ccv import objectdetectiontaskrequest_pb2
+    from generated.ccv import (
+        objectdetectiontasktransformersobjectdetectortrainparameters_pb2 as odt_params_pb2,
+    )
+    from generated.ccv import (
+        objectdetectiontasktransformersobjectdetectortrainrequest_pb2 as odt_request_pb2,
+    )
+
+    IS_LEGACY = False
+except ModuleNotFoundError:
+    # older versions of Caikit / py-to-proto create a flat proto structure
+    # Third Party
+    from generated import objectdetectiontaskrequest_pb2
+    from generated import (
+        objectdetectiontasktransformersobjectdetectortrainrequest_pb2 as odt_request_pb2,
+    )
+    from generated import objectdetectiontrainset_pb2
+
+    IS_LEGACY = True
+
+# Third Party
 import grpc
 import numpy as np
 
 # First Party
 from caikit.interfaces.vision import data_model as caikit_dm
 
+
+### Build the training request
+# Training params; the only thing that changes between newer and older versions of Caikit is
+# that newer versions pass all of these in under a parameters key (as a proto type), while
+# older versions pass them in directly.
+def get_train_request():
+    train_param_dict = {
+        "model_path": os.path.join(MODELS_DIR, DEMO_MODEL_ID),
+        "train_data": objectdetectiontrainset_pb2.ObjectDetectionTrainSet(
+            img_dir_path=TRAINING_IMG_DIR,
+            labels_file=TRAINING_LABELS_FILE,
+        ),
+        "num_epochs": 10,
+        "learning_rate": 0.3,
+    }
+    if not IS_LEGACY:
+        train_param_dict = {
+            "parameters": odt_params_pb2.ObjectDetectionTaskTransformersObjectDetectorTrainParameters(
+                **train_param_dict
+            )
+        }
+    return odt_request_pb2.ObjectDetectionTaskTransformersObjectDetectorTrainRequest(
+        model_name="new_model", **train_param_dict
+    )
+
+
+### Build the inference request
+def get_inference_request():
+    # For inference, just pick a random training image
+    random_img_name = np.random.choice(os.listdir(TRAINING_IMG_DIR))
+    with open(os.path.join(TRAINING_IMG_DIR, random_img_name), "rb") as f:
+        im_bytes = f.read()
+
+    return objectdetectiontaskrequest_pb2.ObjectDetectionTaskRequest(
+        inputs=im_bytes,
+        threshold=0,
+    )
+
+
 if __name__ == "__main__":
-    model_id = "new_model"
     # Setup the client
     port = 8085
     channel = grpc.insecure_channel(f"localhost:{port}")
 
     # send train request
-    request = odt_request_pb2.ObjectDetectionTaskTransformersObjectDetectorTrainRequest(
-        model_name=model_id,
-        model_path=os.path.join(MODELS_DIR, DEMO_MODEL_ID),
-        train_data=objectdetectiontrainset_pb2.ObjectDetectionTrainSet(
-            img_dir_path=TRAINING_IMG_DIR,
-            labels_file=TRAINING_LABELS_FILE,
-        ),
-        num_epochs=10,
-        learning_rate=0.3,
-    )
     training_stub = (
         computervisiontrainingservice_pb2_grpc.ComputerVisionTrainingServiceStub(
             channel=channel
         )
     )
-    response = training_stub.ObjectDetectionTaskTransformersObjectDetectorTrain(request)
+    response = training_stub.ObjectDetectionTaskTransformersObjectDetectorTrain(
+        get_train_request()
+    )
     print("*" * 30)
     print("RESPONSE from TRAIN gRPC\n")
     print(response)
@@ -76,16 +134,6 @@
 
     sleep(5)
 
-    # Then make sure we can hit the new model with an inference request...
-    # TODO: transformers does not seem happy is this is a png
-    random_img_name = np.random.choice(os.listdir(TRAINING_IMG_DIR))
-
-    with open(os.path.join(TRAINING_IMG_DIR, random_img_name), "rb") as f:
-        im_bytes = f.read()
-    request = objectdetectiontaskrequest_pb2.ObjectDetectionTaskRequest(
-        inputs=im_bytes,
-        threshold=0.0,
-    )
     inference_stub = computervisionservice_pb2_grpc.ComputerVisionServiceStub(
         channel=channel
     )
@@ -93,7 +141,7 @@
     # like kserve/model mesh. But it might be more helpful to show how to manually load the model
     # and hit it here, just for reference.
     response = inference_stub.ObjectDetectionTaskPredict(
-        request, metadata=[("mm-model-id", DEMO_MODEL_ID)], timeout=1
+        get_inference_request(), metadata=[("mm-model-id", DEMO_MODEL_ID)], timeout=1
     )
     print("*" * 30)
     print("RESPONSE from INFERENCE gRPC\n")
diff --git a/pyproject.toml b/pyproject.toml
index 4ef157a..250d541 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,8 +14,7 @@ classifiers=[
     "License :: OSI Approved :: Apache Software License"
 ]
 dependencies = [
-    "caikit[runtime-grpc,runtime-http,interfaces-vision]==0.16.0",
-    "py-to-proto==0.4.1",
+    "caikit[runtime-grpc,runtime-http,interfaces-vision]>=0.16.0,<0.27.0",
     "transformers>=4.27.1,<5",
     "torch>2.0,<3",
     "timm>=0.9.5,<1"
diff --git a/runtime_config.yaml b/runtime_config.yaml
index 99ada58..2cfa699 100644
--- a/runtime_config.yaml
+++ b/runtime_config.yaml
@@ -14,6 +14,10 @@ runtime:
   training:
     save_with_id: False
     output_dir: models
+  # This should be set to something that is NOT in your site packages, otherwise it'll cause
+  # conflicts leading to import issues. For now, we set ccv for caikit computer vision.
+  service_generation:
+    package: ccv
 
 log:
   formatter: pretty # optional: log formatter is set to json by default
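For reference, the package-scoped layout assumed by the new imports in `run_train_and_inference.py` (`generated.ccv.*`, `generated.caikit_data_model.caikit_computer_vision.*`) and by the new `.gitignore` entries comes from how protoc maps dotted proto file names onto Python packages: `ccv.objectdetectiontaskrequest.proto` compiles to `generated/ccv/objectdetectiontaskrequest_pb2.py`, and the `caikit_data_model.caikit_computer_vision.*.proto` files land under `generated/caikit_data_model/caikit_computer_vision/`. Below is a rough sketch of a compile step that produces this layout; the `compile_protos.py` name and the `protos/` → `generated/` paths are illustrative assumptions (they mirror the paths used elsewhere in this diff), it assumes `grpcio-tools` is installed, and any compile helper the repo already ships should be preferred.

```python
# compile_protos.py -- illustrative sketch, not part of this change.
# Compiles everything dumped by dump_protos.py into the generated/ package so that
# imports like generated.ccv.objectdetectiontaskrequest_pb2 resolve.
import subprocess
import sys
from pathlib import Path

PROTO_DIR = Path("protos")    # where dump_protos.py writes the .proto files (assumed)
GEN_DIR = Path("generated")   # where the *_pb2 / *_pb2_grpc modules should land (assumed)


def compile_protos():
    GEN_DIR.mkdir(exist_ok=True)
    proto_files = sorted(str(p) for p in PROTO_DIR.glob("*.proto"))
    # Running grpc_tools.protoc as a module also adds the well-known-type includes;
    # protoc turns dotted file names (ccv.foo.proto) into package dirs (ccv/foo_pb2.py).
    subprocess.run(
        [
            sys.executable,
            "-m",
            "grpc_tools.protoc",
            f"--proto_path={PROTO_DIR}",
            f"--python_out={GEN_DIR}",
            f"--grpc_python_out={GEN_DIR}",
            *proto_files,
        ],
        check=True,
    )


if __name__ == "__main__":
    compile_protos()
```

Depending on how the demo scripts are invoked, the `generated/` directory may also need to be on `sys.path` so that the cross-imports inside the generated modules resolve.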