Triton inference is slower than tensorRT #7394

Open
namogg opened this issue Jun 30, 2024 · 0 comments

namogg commented Jun 30, 2024

Description
I'm using a simple client inference class based on the client example. Standalone TensorRT inference with batch size 10 takes 150 ms, while Triton with the TensorRT backend takes 1100 ms. This is my client:

import os
import queue
import sys
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient
import tritonclient.grpc.model_config_pb2 as mc
import tritonclient.http as httpclient
from PIL import Image
from tritonclient.utils import InferenceServerException, triton_to_np_dtype


class UserData:
    def __init__(self):
        self._completed_requests = queue.Queue()


# Callback for gRPC streaming/async inference: queue (result, error)
# so run_inference() can collect the responses after sending.
def completion_callback(user_data, result, error):
    user_data._completed_requests.put((result, error))


class InferenceClient:
    def __init__(self, model_name, image_filename, url="localhost:8000", protocol="HTTP",
                 model_version="", batch_size=1, classes=1, scaling="NONE", async_set=False,
                 streaming=False, verbose=False):
        self.model_name = model_name
        self.image_filename = image_filename
        self.url = url
        self.protocol = protocol
        self.model_version = model_version
        self.batch_size = batch_size
        self.classes = classes
        self.scaling = scaling
        self.async_set = async_set
        self.streaming = streaming
        self.verbose = verbose
        self.user_data = UserData()

        if self.streaming and self.protocol.lower() != "grpc":
            raise Exception("Streaming is only allowed with gRPC protocol")

        try:
            if self.protocol.lower() == "grpc":
                self.client = grpcclient.InferenceServerClient(
                    url=self.url, verbose=self.verbose
                )
            else:
                concurrency = 20 if self.async_set else 1
                self.client = httpclient.InferenceServerClient(
                    url=self.url, verbose=self.verbose, concurrency=concurrency
                )
        except Exception as e:
            print("Client creation failed: " + str(e))
            sys.exit(1)

        self.model_metadata = self.get_model_metadata()
        self.model_config = self.get_model_config()

        if self.protocol.lower() == "grpc":
            self.model_config = self.model_config.config
        else:
            self.model_metadata, self.model_config = self.convert_http_metadata_config(
                self.model_metadata, self.model_config
            )

        (
            self.max_batch_size,
            self.input_name,
            self.output_name,
            self.c,
            self.h,
            self.w,
            self.format,
            self.dtype,
        ) = self.parse_model(self.model_metadata, self.model_config)

    def get_model_metadata(self):
        try:
            return self.client.get_model_metadata(
                model_name=self.model_name, model_version=self.model_version
            )
        except InferenceServerException as e:
            print("Failed to retrieve the metadata: " + str(e))
            sys.exit(1)

    def get_model_config(self):
        try:
            return self.client.get_model_config(
                model_name=self.model_name, model_version=self.model_version
            )
        except InferenceServerException as e:
            print("Failed to retrieve the config: " + str(e))
            sys.exit(1)

    def parse_model(self, model_metadata, model_config):

        input_metadata = model_metadata.inputs[0]
        input_config = model_config.input[0]
        output_metadatas = model_metadata.outputs
        output_names = []
        for output_metadata in output_metadatas:
            if output_metadata.datatype != "FP32":
                raise Exception(
                    "Expecting output datatype to be FP32, model '"
                    + model_metadata.name
                    + "' output type is "
                    + output_metadata.datatype
                )
            output_names.append(output_metadata.name)

        input_batch_dim = model_config.max_batch_size > 0
        expected_input_dims = 3 + (1 if input_batch_dim else 0)
        if len(input_metadata.shape) != expected_input_dims:
            raise Exception(
                "Expecting input to have {} dimensions, model '{}' input has {}".format(
                    expected_input_dims, model_metadata.name, len(input_metadata.shape)
                )
            )

        if isinstance(input_config.format, str):
            FORMAT_ENUM_TO_INT = dict(mc.ModelInput.Format.items())
            input_config.format = FORMAT_ENUM_TO_INT[input_config.format]
        # This model is configured as FORMAT_NHWC; fall back to NHWC only when
        # the server reports no format, instead of overriding it unconditionally.
        if input_config.format == mc.ModelInput.FORMAT_NONE:
            input_config.format = mc.ModelInput.FORMAT_NHWC
        if input_config.format not in (mc.ModelInput.FORMAT_NCHW, mc.ModelInput.FORMAT_NHWC):
            raise Exception(
                "Unexpected input format "
                + mc.ModelInput.Format.Name(input_config.format)
                + ", expecting "
                + mc.ModelInput.Format.Name(mc.ModelInput.FORMAT_NCHW)
                + " or "
                + mc.ModelInput.Format.Name(mc.ModelInput.FORMAT_NHWC)
            )

        if input_config.format == mc.ModelInput.FORMAT_NHWC:
            h = input_metadata.shape[1 if input_batch_dim else 0]
            w = input_metadata.shape[2 if input_batch_dim else 1]
            c = input_metadata.shape[3 if input_batch_dim else 2]
        else:
            c = input_metadata.shape[1 if input_batch_dim else 0]
            h = input_metadata.shape[2 if input_batch_dim else 1]
            w = input_metadata.shape[3 if input_batch_dim else 2]

        return (
            model_config.max_batch_size,
            input_metadata.name,
            output_names,
            c,
            h,
            w,
            input_config.format,
            input_metadata.datatype,
        )

    def preprocess(self, img):
        if self.c == 1:
            sample_img = img.convert("L")
        else:
            sample_img = img.convert("RGB")

        resized_img = sample_img.resize((self.w, self.h), Image.BILINEAR)
        resized = np.array(resized_img)
        if resized.ndim == 2:
            resized = resized[:, :, np.newaxis]

        npdtype = triton_to_np_dtype(self.dtype)
        typed = resized.astype(npdtype)

        if self.scaling == "INCEPTION":
            scaled = (typed / 127.5) - 1
        elif self.scaling == "VGG":
            if self.c == 1:
                scaled = typed - np.asarray((128,), dtype=npdtype)
            else:
                scaled = typed - np.asarray((123, 117, 104), dtype=npdtype)
        else:
            scaled = typed

        if self.format == mc.ModelInput.FORMAT_NCHW:
            ordered = np.transpose(scaled, (2, 0, 1))
        else:
            ordered = scaled

        return ordered

    def postprocess(self, results):
        # self.output_name is a list of output names and every output is FP32
        # (enforced in parse_model), so return the raw arrays keyed by name
        # instead of parsing classification strings.
        outputs = {}
        for name in self.output_name:
            output_array = results.as_numpy(name)
            if self.max_batch_size > 0 and len(output_array) != self.batch_size:
                raise Exception("Expected {} results, got {}".format(self.batch_size, len(output_array)))
            outputs[name] = output_array
        return outputs

    def request_generator(self, batched_image_data):
        protocol = self.protocol.lower()
        client = grpcclient if protocol == "grpc" else httpclient

        inputs = [client.InferInput(self.input_name, batched_image_data.shape, self.dtype)]
        inputs[0].set_data_from_numpy(batched_image_data)

        # binary_data is an HTTP-only option; the gRPC InferRequestedOutput
        # takes only the output name (and an optional class_count).
        if protocol == "grpc":
            outputs = [client.InferRequestedOutput(name) for name in self.output_name]
        else:
            outputs = [client.InferRequestedOutput(name, binary_data=True) for name in self.output_name]

        yield inputs, outputs, self.model_name, self.model_version

    def convert_http_metadata_config(self, _metadata, _config):
        try:
            from attrdict import AttrDict
        except ImportError:
            import collections
            import collections.abc

            for type_name in collections.abc.__all__:
                setattr(collections, type_name, getattr(collections.abc, type_name))
            from attrdict import AttrDict

        return AttrDict(_metadata), AttrDict(_config)

    def run_inference(self):
        filenames = []
        if os.path.isdir(self.image_filename):
            filenames = [
                os.path.join(self.image_filename, f)
                for f in os.listdir(self.image_filename)
                if os.path.isfile(os.path.join(self.image_filename, f))
            ]
        else:
            filenames = [self.image_filename]

        filenames.sort()

        image_data = []
        for filename in filenames:
            img = Image.open(filename)
            image_data.append(self.preprocess(img))

        requests = []
        responses = []
        request_ids = []
        image_idx = 0
        last_request = False

        async_requests = []
        sent_count = 0

        if self.streaming:
            self.client.start_stream(partial(completion_callback, self.user_data))

        while not last_request:
            input_filenames = []
            repeated_image_data = []

            for idx in range(self.batch_size):
                input_filenames.append(filenames[image_idx])
                repeated_image_data.append(image_data[image_idx])
                image_idx = (image_idx + 1) % len(filenames)
                if image_idx == 0:
                    last_request = True

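            batched_image_data = np.stack(repeated_image_data, axis=0)
            for inputs, outputs, model_name, model_version in self.request_generator(
                batched_image_data
            ):
                if self.streaming:
                    # The callback was registered once in start_stream();
                    # async_stream_infer() itself takes no callback.
                    self.client.async_stream_infer(
                        self.model_name,
                        inputs,
                        model_version=self.model_version,
                        outputs=outputs,
                        request_id=str(sent_count),
                    )
                elif self.async_set:
                    if self.protocol.lower() == "grpc":
                        # gRPC async_infer() requires a callback; results are
                        # queued in self.user_data and collected below.
                        self.client.async_infer(
                            self.model_name,
                            inputs,
                            partial(completion_callback, self.user_data),
                            model_version=self.model_version,
                            outputs=outputs,
                            request_id=str(sent_count),
                        )
                    else:
                        # HTTP async_infer() returns a handle with get_result().
                        async_requests.append(
                            self.client.async_infer(
                                self.model_name,
                                inputs,
                                model_version=self.model_version,
                                outputs=outputs,
                                request_id=str(sent_count),
                            )
                        )
                else:
                    responses.append(
                        self.client.infer(
                            self.model_name,
                            inputs,
                            model_version=self.model_version,
                            outputs=outputs,
                            request_id=str(sent_count),
                        )
                    )
                    request_ids.append(sent_count)

                sent_count += 1

        if self.streaming:
            self.client.stop_stream()

        if self.protocol.lower() == "grpc" and (self.streaming or self.async_set):
            # Drain the results that completion_callback() queued.
            for _ in range(sent_count):
                result, error = self.user_data._completed_requests.get()
                if error is not None:
                    raise error
                responses.append(result)
        elif self.async_set:
            for async_request in async_requests:
                responses.append(async_request.get_result())
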
        return responses
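The 150 ms vs 1100 ms comparison is only like-for-like if both numbers cover the same work. If the 1100 ms was measured around `run_inference()`, it also includes opening, decoding, resizing and converting every image before any request is sent. A minimal sketch for timing stages separately, assuming you wrap the relevant calls yourself (`stage_timer` is a hypothetical helper, not part of `tritonclient`):

```python
import time
from contextlib import contextmanager


@contextmanager
def stage_timer(timings, name):
    """Add the wall-clock milliseconds spent inside the `with` block to timings[name]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        timings[name] = timings.get(name, 0.0) + elapsed_ms


# Usage with stand-in workloads; replace the bodies with the real calls.
timings = {}
with stage_timer(timings, "preprocess"):
    data = [x * 2 for x in range(100_000)]  # stand-in for decode/resize/astype
with stage_timer(timings, "infer"):
    checksum = sum(data)  # stand-in for self.client.infer(...)
for name, ms in timings.items():
    print(f"{name:10s} {ms:8.2f} ms")
```

For example, wrapping the `image_data.append(self.preprocess(img))` loop in `stage_timer(timings, "preprocess")` and the `self.client.infer(...)` call in `stage_timer(timings, "infer")` would show how the total splits between client-side preparation and the request itself. Repeated blocks with the same name accumulate, so per-batch calls inside the `while` loop add up to a total per stage.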

Triton Information
What version of Triton are you using?
2.42

Are you using the Triton container or did you build it yourself?
container
To Reproduce
Model config:

name: "retina_1280_16FP"
platform: "tensorrt_plan"
max_batch_size: 10
input [
  {
    name: "input"
    data_type: TYPE_FP32
    format: FORMAT_NHWC
    dims: [ 704, 1280, 3 ]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]
dynamic_batching {
  preferred_batch_size: [ 4, 8, 10 ]
  max_queue_delay_microseconds: 1
}
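For scale, the config above implies a large input tensor per request: with batch size 10, `TYPE_FP32` (4 bytes per element) and `dims: [704, 1280, 3]`, each request carries 108,134,400 bytes (about 103.1 MiB) of input that the client must serialize and the server must copy in. The helper below only does that arithmetic; its name is illustrative, not part of any Triton API. Comparing the figure with the `compute input` and `(un)marshal` times in the perf_analyzer output may help show where time goes beyond the `compute infer` step.

```python
import numpy as np


def input_bytes(batch_size, dims, dtype):
    """Bytes in one request's input tensor of shape [batch_size, *dims]."""
    elements = batch_size
    for d in dims:
        elements *= d
    return elements * np.dtype(dtype).itemsize


fp32 = input_bytes(10, [704, 1280, 3], np.float32)
print(f"FP32 input per request: {fp32:,} bytes ({fp32 / 2**20:.3f} MiB)")
```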

Perf analyzer:

root@docker-desktop:/workspace# perf_analyzer -m retina_1280_16FP --concurrency-range 1:3 -i grpc --async -b 10
*** Measurement Settings ***
  Batch size: 10
  Service Kind: Triton
  Using "time_windows" mode for stabilization
  Measurement window: 5000 msec
  Latency limit: 0 msec
  Concurrency limit: 3 concurrent requests
  Using asynchronous calls for inference
  Stabilizing using average latency

Request concurrency: 1
  Client:
    Request count: 82
    Throughput: 45.5499 infer/sec
    Avg latency: 218899 usec (standard deviation 19045 usec)
    p50 latency: 215273 usec
    p90 latency: 235609 usec
    p95 latency: 245463 usec
    p99 latency: 290641 usec
    Avg gRPC time: 218885 usec ((un)marshal request/response 6746 usec + response wait 212139 usec)
  Server:
    Inference count: 820
    Execution count: 82
    Successful request count: 82
    Avg request latency: 143637 usec (overhead 6674 usec + queue 170 usec + compute input 18415 usec + compute infer 89239 usec + compute output 29138 usec)

Request concurrency: 2
  Client:
    Request count: 157
    Throughput: 87.1606 infer/sec
    Avg latency: 230597 usec (standard deviation 28602 usec)
    p50 latency: 222884 usec
    p90 latency: 270551 usec
    p95 latency: 297211 usec
    p99 latency: 313338 usec
    Avg gRPC time: 230583 usec ((un)marshal request/response 7123 usec + response wait 223460 usec)
  Server:
    Inference count: 1570
    Execution count: 157
    Successful request count: 157
    Avg request latency: 147526 usec (overhead 6904 usec + queue 358 usec + compute input 11582 usec + compute infer 89941 usec + compute output 38740 usec)

Request concurrency: 3
  Client:
    Request count: 186
    Throughput: 103.187 infer/sec
    Avg latency: 288448 usec (standard deviation 30406 usec)
    p50 latency: 279053 usec
    p90 latency: 327908 usec
    p95 latency: 348555 usec
    p99 latency: 402516 usec
    Avg gRPC time: 288393 usec ((un)marshal request/response 10181 usec + response wait 278212 usec)
  Server:
    Inference count: 1870
    Execution count: 187
    Successful request count: 187
    Avg request latency: 182586 usec (overhead 7013 usec + queue 8149 usec + compute input 2156 usec + compute infer 91174 usec + compute output 74092 usec)

Inferences/Second vs. Client Average Batch Latency
Concurrency: 1, throughput: 45.5499 infer/sec, latency 218899 usec
Concurrency: 2, throughput: 87.1606 infer/sec, latency 230597 usec
Concurrency: 3, throughput: 103.187 infer/sec, latency 288448 usec
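The server-side `Avg request latency` line breaks the total into components that should add up to it, up to rounding. A small sketch that parses one of these lines, copied from the output above, and prints each component's share; `parse_server_latency` is a hypothetical helper, and its regexes assume exactly the line format shown here:

```python
import re

# Matches "<component> <number> usec" pairs inside the parenthesised breakdown.
COMPONENT_RE = re.compile(
    r"(overhead|queue|compute input|compute infer|compute output) (\d+) usec"
)
TOTAL_RE = re.compile(r"Avg request latency: (\d+) usec")


def parse_server_latency(line):
    """Return (total_usec, {component: usec}) for one perf_analyzer server line."""
    total = int(TOTAL_RE.search(line).group(1))
    parts = {name: int(value) for name, value in COMPONENT_RE.findall(line)}
    return total, parts


line = (
    "Avg request latency: 143637 usec (overhead 6674 usec + queue 170 usec + "
    "compute input 18415 usec + compute infer 89239 usec + compute output 29138 usec)"
)
total, parts = parse_server_latency(line)
for name, usec in parts.items():
    print(f"{name:15s} {usec / 1000:7.1f} ms  {100 * usec / total:5.1f}%")
print(f"{'sum':15s} {sum(parts.values()) / 1000:7.1f} ms (reported {total / 1000:.1f} ms)")
```

For the concurrency-1 line, the components sum to 143,636 usec against the reported 143,637 usec. `compute infer` (about 89 ms, roughly 62% of the server latency) is the part that roughly corresponds to model execution, while `compute input` and `compute output` together add about 48 ms. The client-side average latency (218,899 usec) is a further 75 ms or so above the server's request latency.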

Expected behavior
Triton should run at the same speed as standalone TensorRT, and faster with concurrent requests.
