
GPU Scaling issue : multi-gpu inference #7385

Open
lionsheep24 opened this issue Jun 27, 2024 · 0 comments

Comments


lionsheep24 commented Jun 27, 2024

Description

  • I deployed a Hugging Face model on Triton Inference Server and ran into a performance issue: when I increased the number of GPUs from 2 to 3, TPS decreased and latency increased. The metrics were collected with Locust over HTTP (a minimal locustfile sketch is included at the end of this issue). Here are the details of my setup and observations:

Server setting & observations:

  • Each GPU has a single model instance.
  • GPU utilization was very low during the test (sampled as sketched below).
  • When I add more GPUs (2 -> 3), latency increases for every request.
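
For reference, a minimal sketch of how per-GPU utilization can be sampled while the load test runs, assuming the pynvml (nvidia-ml-py) package is installed; nvidia-smi reports the same counters:

import time

import pynvml

# Sample GPU utilization once per second while the load test runs.
pynvml.nvmlInit()
handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(pynvml.nvmlDeviceGetCount())]

for _ in range(60):
    utilization = [pynvml.nvmlDeviceGetUtilizationRates(h).gpu for h in handles]
    print("GPU utilization (%):", utilization)
    time.sleep(1)

pynvml.nvmlShutdown()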

What I expected:

  • Based on my understanding, increasing the number of GPUs should increase the number of model instances (i.e., workers), which should scale throughput roughly linearly. In my case, however, performance actually dropped. When a model is deployed across multiple GPUs, how is the architecture structured? My understanding is that Triton Server queues incoming requests and distributes them evenly across the model instances running on the different GPUs. Is this correct? (A small client-side check of the deployed instance configuration is sketched below.)

Any insights or suggestions on what might be causing this issue and how to resolve it would be greatly appreciated.
Thank you!
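
For context, here is a minimal sketch, assuming the default HTTP port 8000, the model name pipeline_model from the config below, and tritonclient[http] installed, that asks the server for the model configuration it actually loaded so the effective instance_group placement can be inspected:

import json

import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

print("server ready:", client.is_server_ready())
print("model ready:", client.is_model_ready("pipeline_model"))

# The HTTP client returns the model configuration as a dict that mirrors
# config.pbtxt; the instance_group entries show how many instances were
# created and which GPUs they were placed on.
config = client.get_model_config("pipeline_model")
print(json.dumps(config.get("instance_group", []), indent=2))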

Triton Information
What version of Triton are you using?
-> nvcr.io/nvidia/tritonserver:23.10-py3

To Reproduce
1. model.py

import io
import json
from datetime import datetime


import numpy as np
import torch

from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline

# triton_python_backend_utils is available in every Triton Python model. You
# need to use this module to create inference requests and responses. It also
# contains some utility functions for extracting information from model_config
# and converting Triton input/output types to numpy types.
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.
        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """

        # model_config is passed in as a JSON string and must be parsed here.
        self.model_name = args["model_name"]
        self.model_version = args["model_version"]
        self.model_config = json.loads(args["model_config"])
        
        self.logger = pb_utils.Logger
        self.device_id = args["model_instance_device_id"]
        self.device = torch.device(f"cuda:{self.device_id}" if torch.cuda.is_available() else "cpu")
        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(self.model_config, "pipeline_output")

        # Convert Triton types to numpy types
        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])

        self.init_model(self.model_config["parameters"])

    def init_model(self, parameters):
        for key, value in parameters.items():
            parameters[key] = value["string_value"]
        model_dir = parameters["model_dir"]
        self.processor = AutoProcessor.from_pretrained(model_dir, language="ko", task="transcribe")
        self.batch_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            use_flash_attention_2=True,
        )
        self.batch_model.to(self.device)
        self.pipeline = pipeline(
            "automatic-speech-recognition",
            model=self.batch_model,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            return_timestamps=False,
            torch_dtype=torch.float16,
            device=self.device,
            generate_kwargs={"language": "ko", "num_beams": 1, "do_sample": False},
        )
       

    def execute(self, requests):
        """`execute` MUST be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference request is made
        for this model. Depending on the batching configuration (e.g. Dynamic
        Batching) used, `requests` may contain multiple requests. Every
        Python model must create one pb_utils.InferenceResponse for every
        pb_utils.InferenceRequest in `requests`. If there is an error, you can
        set the error argument when creating a pb_utils.InferenceResponse
        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest
        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """

        audio_arrays = []
        responses = []

        # Every Python backend must iterate over every one of the requests
        # and create a pb_utils.InferenceResponse for each of them.
        self.logger.log_info(
            f"[{self.model_name} v{self.model_version}] Starting execute function with {len(requests)} requests."
        )
        for request in requests:
            # Get INPUT0
            in_0 = pb_utils.get_input_tensor_by_name(request, "pipeline_input")
            audio_array: np.ndarray = in_0.as_numpy()
            self.logger.log_info(f"[{self.model_name} v{self.model_version}] AUDIO ARRAY SHAPE : {audio_array[0].shape}")
            audio_arrays.append(audio_array[0])

        batch_audio_array = [np.array(audio, dtype=np.float16) for audio in audio_arrays]
        self.logger.log_info(
            f"[{self.model_name} v{self.model_version}] BATCH AUDIO ARRAY SHAPE: {np.array(batch_audio_array).shape}"
        )
        start = datetime.now()
        batch_transcripts = self.pipeline(batch_audio_array)  # list of {"text": ...} dicts
        self.logger.log_info(
            f"[{self.model_name} v{self.model_version}] BATCH TRANSCRIPTS : {batch_transcripts}"
        )
        end = datetime.now()
        self.logger.log_info(f"[{self.model_name} v{self.model_version}] LATENCY: {(end-start).total_seconds()} AT DEVICE : {self.device}")

        # Create InferenceResponse. You can set an error here in case
        # there was a problem with handling this inference request.
        # Below is an example of how you can set errors in inference
        # response:
        #
        # pb_utils.InferenceResponse(
        # output_tensors=..., TritonError("An error occurred"))
        # You should return a list of pb_utils.InferenceResponse. Length
        # of this list must match the length of `requests` list.

        for i, request in enumerate(requests):
            transcripts: str = batch_transcripts[i]["text"].rstrip("\n")
            transcripts_np = np.array([transcripts], dtype=np.object_)
            out_tensor_0 = pb_utils.Tensor("pipeline_output", transcripts_np)
            inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor_0])
            responses.append(inference_response)

        return responses

    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is OPTIONAL. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print("Cleaning up...")
2. config.pbtxt

Model Configuration:

name: "pipeline_model"
backend: "python"
max_batch_size: 256

parameters: [
  {
    key: "model_dir"
    value: {
      string_value: "/workspace/models/pipeline_model/1/pipeline_model"
    }
  }
]

input [
  {
    name: "pipeline_input"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]

output [
  {
    name: "pipeline_output"
    data_type: TYPE_STRING
    dims: [ -1 ]
  }
]

instance_group: [
  {
    kind: KIND_GPU,
    count: 1,
    gpus: [ 0 ]
  },
  {
    kind: KIND_GPU,
    count: 1,
    gpus: [ 1 ]
  },
  {
    kind: KIND_GPU,
    count: 1,
    gpus: [ 2 ]
  }
]

dynamic_batching {
  preferred_batch_size: [ 2, 4 ]
  max_queue_delay_microseconds: 200000
}
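
For reference, a minimal single-request client sketch, assuming the default HTTP port 8000, a 16 kHz mono clip, and the tensor names from the config above:

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Dummy 1-second clip; with max_batch_size > 0 the request shape carries an
# explicit batch dimension, hence [1, samples].
audio = np.random.randn(1, 16000).astype(np.float32)

inp = httpclient.InferInput("pipeline_input", list(audio.shape), "FP32")
inp.set_data_from_numpy(audio)
out = httpclient.InferRequestedOutput("pipeline_output")

result = client.infer("pipeline_model", inputs=[inp], outputs=[out])
print(result.as_numpy("pipeline_output"))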

Expected behavior
TPS should increase (roughly linearly) with the number of GPUs.
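
TPS and latency were measured with Locust over HTTP. A minimal locustfile sketch along those lines, assuming the KServe v2 HTTP inference endpoint on port 8000 and one-second 16 kHz clips (the real test data differs):

import json
import random

from locust import HttpUser, task, between

SAMPLES = 16000  # assumed: one second of 16 kHz audio


class TritonUser(HttpUser):
    wait_time = between(0.1, 0.5)

    @task
    def infer(self):
        # KServe v2 HTTP inference request against the Python backend model.
        payload = {
            "inputs": [
                {
                    "name": "pipeline_input",
                    "shape": [1, SAMPLES],
                    "datatype": "FP32",
                    "data": [random.random() for _ in range(SAMPLES)],
                }
            ],
            "outputs": [{"name": "pipeline_output"}],
        }
        self.client.post(
            "/v2/models/pipeline_model/infer",
            data=json.dumps(payload),
            headers={"Content-Type": "application/json"},
        )

Run with, e.g., locust -f locustfile.py --host http://localhost:8000 and ramp users until the target concurrency is reached.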
