How do I optimize a Python BLS model orchestrating ONNX models? #7388

Open
JamesBowerXanda opened this issue Jun 27, 2024 · 1 comment

Description
I am using the SageMaker Triton Inference Server containers to run a multi-model endpoint. One of the models is an MT5 model. I am trying to optimise for latency and think I am losing time to data transfer: on an equivalent instance type in a notebook with onnxruntime, the generation pipeline takes about 0.5 seconds, but when I send a request through the Triton Inference Server endpoint (with no other models loaded) the execution time is around 2.5 seconds.

The model is split into an encoder_model.onnx, decoder_model.onnx and decoder_with_past_model.onnx.

What is the best way to optimise this?

Happy to restructure if there is a better way of doing it, but I am running multiple models on the SageMaker GPU instance.
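
For reference, this is a rough sketch of the notebook baseline being compared against (a hedged illustration, not the exact notebook code: the onnx-model directory matches the optimum-cli export under "To Reproduce", and the prompt and provider are assumptions):

import time
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

# Load the exported MT5 model with the CUDA execution provider.
tokenizer = AutoTokenizer.from_pretrained("onnx-model")
model = ORTModelForSeq2SeqLM.from_pretrained("onnx-model", provider="CUDAExecutionProvider")

# Time a single generation, mirroring the ~0.5 s measurement in the notebook.
inputs = tokenizer(["Some text to translate."], return_tensors="pt").to("cuda")
start = time.perf_counter()
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
print(f"Generation time: {time.perf_counter() - start:.2f}s")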

Triton Information
23.08

Are you using the Triton container or did you build it yourself?
SageMaker container, as mentioned here.

To Reproduce
Take a T5 or MT5 model and use optimum to export the constituent ONNX models.

optimum-cli export onnx --model google/mt5-small onnx-model --device cuda --optimize O4

Take the encoder_model.onnx, decoder_model.onnx and decoder_with_past_model.onnx files and add them to a Triton Inference Server model repository as ONNX models running on the GPU (an assumed layout is sketched below). The config.pbtxt files are at the bottom.
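
The exact repository layout is not shown in the issue; based on the model names referenced in model.py and the config.pbtxt files below, an assumed structure looks like this:

model_repository/
├── bls/
│   ├── config.pbtxt
│   └── 1/
│       ├── model.py
│       └── tokenizer/            # MT5 tokenizer files
├── encoder/
│   ├── config.pbtxt
│   └── 1/model.onnx              # encoder_model.onnx renamed
├── decoder_no_past/
│   ├── config.pbtxt
│   └── 1/model.onnx              # decoder_model.onnx renamed
└── decoder_with_past/
    ├── config.pbtxt
    └── 1/model.onnx              # decoder_with_past_model.onnx renamed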

Create a Python BLS model with the model.py file:

### model.py

import triton_python_backend_utils as pb_utils
import json
import os
from transformers import MT5Tokenizer
import numpy as np
import time
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack
from typing import List, Dict, Tuple, ClassVar
from dataclasses import make_dataclass

class TritonPythonModel:

    def initialize(self, args):
        self.logger = pb_utils.Logger

        self.translated_text_output_type = pb_utils.triton_string_to_numpy(
            pb_utils.get_output_config_by_name(json.loads(args["model_config"]), "translated_text")["data_type"]
        )

        self.logger.log_info("MT5 - Loading Encoder Model")
        self.encoder_name = "encoder"
        if not pb_utils.is_model_ready(model_name = self.encoder_name, model_version = "1"):
            pb_utils.load_model(model_name = self.encoder_name)

        self.logger.log_info("MT5 - Loading Decoder No Past Model")
        self.decoder_no_past_name = "decoder_no_past"
        if not pb_utils.is_model_ready(model_name = self.decoder_no_past_name, model_version = "1"):
            pb_utils.load_model(model_name = self.decoder_no_past_name)

        self.logger.log_info("MT5 - Loading Decoder With Past Model")
        self.decoder_with_past_name = "decoder_with_past"
        if not pb_utils.is_model_ready(model_name = self.decoder_with_past_name, model_version = "1"):
            pb_utils.load_model(model_name = self.decoder_with_past_name)


        self.logger.log_info("Loading Tokenizer")
        tokenizer_dir = os.path.join(args["model_repository"], "1", "tokenizer")
        self.generator = ONNXGenerationModel.from_pretrained(tokenizer_dir, self.encoder_name, self.decoder_no_past_name, self.decoder_with_past_name)

    def execute(self, requests):
        responses = []
        for request in requests:
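            # Each request carries a batch of strings; run the full
            # encoder/decoder generation pipeline on them and return
            # the translated strings as a single output tensor.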
            
            texts = pb_utils.get_input_tensor_by_name(request, "text").as_numpy().squeeze(1).astype(str).tolist()
            self.logger.log_info(f"Texts: {texts}")
            counter = time.perf_counter()
            texts = self.generator.generate(texts)
            self.logger.log_info(f"Generation Time: {time.perf_counter() - counter}")
            texts_np = np.array(texts).astype(self.translated_text_output_type)

            output_tensors = [
                pb_utils.Tensor("translated_text", texts_np)
            ]

            response = pb_utils.InferenceResponse(
                output_tensors=output_tensors
            )
            responses.append(response)

        return responses

    

class ONNXGenerationModel:

    def __init__(self, tokenizer_dir, encoder_name, decoder_no_past_name, decoder_with_past_name, logger=None):
        self.name = "MT5"
        self.logger = logger
        self.encoder_name = encoder_name
        self.decoder_no_past_name = decoder_no_past_name
        self.decoder_with_past_name = decoder_with_past_name
        

        self.decoder_no_past_outputs = [
            "logits"
        ]
        decoder_layers = 8
        encoder_decoder = [
            "encoder", "decoder"
        ]
        key_value = [
            "key", "value"
        ]
        for i in range(decoder_layers):
            for layer in encoder_decoder:
                for kv in key_value:
                    self.decoder_no_past_outputs.append(f"present.{i}.{layer}.{kv}")

        self.decoder_with_past_outputs = [
            "logits"
        ]
        for i in range(decoder_layers):
            for kv in key_value:
                self.decoder_with_past_outputs.append(f"present.{i}.decoder.{kv}")
                    

        self.tokenizer = MT5Tokenizer.from_pretrained(tokenizer_dir)

        self.use_cache = True

        self.onnx_type_map = {
            "tensor(int64)": np.int64,
            "tensor(float)": np.float32,
            "tensor(float16)": np.float16,
        }

        self.initial_dims = {
            "past_decoder_sequence_length": 1,
            "encoder_sequence_length_out": 1,
            "decoder_sequence_length": 1
        }

    def generate(self, texts: List[str]):
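        # Greedy decoding loop: tokenize the batch, run the encoder once,
        # then repeatedly call the decoder (no-past on the first step,
        # with-past afterwards) until every sequence has emitted EOS.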
        start_counter = time.perf_counter() # Start Counter Start
        batch_size = len(texts)
        still_running = [True] * batch_size
        final_outputs = [None] * batch_size
        decoder_times = []
        admin_times = []
        generated_sequences = np.zeros((batch_size, 1), dtype=np.int64)

        running_params = {}

        tokenizer_outputs = self._run_tokenizer(texts)
        running_params = self._update_running_params(running_params, tokenizer_outputs)
        start_time = time.perf_counter() - start_counter # Start Counter End

        encoder_counter = time.perf_counter() # Encoder Counter Start
        self.log_contiguity(running_params)
        encoder_outputs = self._run_encoder(running_params)
        running_params["encoder_attention_mask"] = running_params.pop("attention_mask")
        running_params["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int64).to("cuda")
        running_params = self._update_running_params(running_params, encoder_outputs)
        encoder_time = time.perf_counter() - encoder_counter # Encoder Counter End

        while any(still_running):
            decoder_counter = time.perf_counter() # Decoder Counter Start
            self.log_contiguity(running_params)
            decoder_outputs = self._run_decoder(running_params)
            decoder_times.append(time.perf_counter() - decoder_counter) # Decoder Counter End
            
            admin_counter = time.perf_counter() # Admin Counter Start
            running_params = self._update_running_params(running_params, decoder_outputs)

            next_tokens = running_params["input_ids"]
            next_tokens_np = np.expand_dims(next_tokens.cpu().numpy()[:, -1], axis=-1)
            generated_sequences = np.concatenate([generated_sequences, next_tokens_np], axis=-1)
            self.log(f"Generated Sequences: {generated_sequences}")
            finished_sequence_indexes = np.where(next_tokens_np[:,0] == self.tokenizer.eos_token_id)[0]
            sequence_to_keep_indexes = np.where(next_tokens_np[:,0] != self.tokenizer.eos_token_id)[0].tolist()

            if len(finished_sequence_indexes) > 0:
                finished_sequences = {}
                for i in finished_sequence_indexes:
                    global_index = self._get_global_index(i, still_running)
                    finished_sequences[global_index] = generated_sequences[i]
                running_params = {
                    name: value[sequence_to_keep_indexes].contiguous() for name, value in running_params.items()
                }
                generated_sequences = np.delete(generated_sequences, finished_sequence_indexes, axis=0)
                for index, sequence in finished_sequences.items():
                    final_outputs[index] = sequence.tolist()
                    still_running[index] = False
            admin_times.append(time.perf_counter() - admin_counter) # Admin Counter End

        decoder_tokenizer_counter = time.perf_counter() # Decoder Tokenizer Counter Start
        final_outputs = self.tokenizer.batch_decode(final_outputs, skip_special_tokens=True)
        decoder_tokenizer_time = time.perf_counter() - decoder_tokenizer_counter # Decoder Tokenizer Counter End
        self.logger.log_info(f"MT5 - Final Outputs: {final_outputs}")
        self.log(f"Encoder Time: {encoder_time}: Decoder Time: {sum(decoder_times)}: Admin Time: {sum(admin_times)}: Start Time: {start_time}, Decoder Tokenizer Time: {decoder_tokenizer_time}")

        return final_outputs
    
    def _run_tokenizer(self, text: List[str]) -> Dict[str, np.ndarray]:
        inputs = self.tokenizer(text, padding=True, return_tensors="pt").to("cuda")
        return dict(inputs)
    
    def _update_running_params(self, running_parameters: Dict[str,np.ndarray], update_params: Dict[str,np.ndarray]) -> Dict[str,np.ndarray]:
        
        for name, value in update_params.items():
            running_parameters[name] = value

        return running_parameters
    
    def _run_encoder(self, encoder_inputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
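        # Wrap the GPU tensors with DLPack so the BLS request can hand them
        # to the encoder model without an extra host copy.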
        input_ids = encoder_inputs["input_ids"]
        attention_mask = encoder_inputs["attention_mask"]

        input_ids_tensor = pb_utils.Tensor.from_dlpack("input_ids", to_dlpack(input_ids))
        attention_mask_tensor = pb_utils.Tensor.from_dlpack("attention_mask", to_dlpack(attention_mask))

        inference_request = pb_utils.InferenceRequest(
            inputs=[input_ids_tensor, attention_mask_tensor],
            model_name=self.encoder_name,
            model_version=1,
            requested_output_names=["last_hidden_state"]
        )


        response = inference_request.exec()

        if response.has_error():
            raise pb_utils.TritonModelException(
                response.error().message()
            )
        
        last_hidden_state_tensor = pb_utils.get_output_tensor_by_name(response, "last_hidden_state")

        return {
            "encoder_hidden_states": torch.from_dlpack(last_hidden_state_tensor.to_dlpack()),
        }
    
    def _run_decoder(self, running_parameters: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
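        # Use the with-past decoder only once cached key/values exist and
        # input_ids has been reduced to a single newly generated token;
        # otherwise fall back to the no-past decoder.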
        if "past_key_values.0.encoder.key" in running_parameters and running_parameters.get("input_ids").shape[1] == 1: 
            with_past = True
        else:
            with_past = False
        input_tensors = []
        for name, value in running_parameters.items():
            self.log(f"Converting {name} to DLPack")
            input_tensors.append(pb_utils.Tensor.from_dlpack(name, to_dlpack(value)))
        input_names = [tensor.name() for tensor in input_tensors]
        self.log(f"Input Tensors: {input_names}")
        inference_request = pb_utils.InferenceRequest(
            inputs=input_tensors,
            model_name=self.decoder_with_past_name if with_past else self.decoder_no_past_name,
            model_version=1,
            requested_output_names=self.decoder_with_past_outputs if with_past else self.decoder_no_past_outputs
        )

        response = inference_request.exec()

        if response.has_error():
            raise pb_utils.TritonModelException(
                "James - " + response.error().message()
            )
        
        decoder_outputs = {
            tensor.name(): torch.from_dlpack(tensor.to_dlpack()) for tensor in response.output_tensors()
        }

        logits = decoder_outputs.pop("logits")[:, -1, :]
        next_tokens = self._get_next_tokens(logits)

        if decoder_outputs["present.0.decoder.key"].shape[2] == 1:
            self.log("REPEAT NO PAST MODEL")
            return {
                "input_ids": torch.concatenate([running_parameters["input_ids"], next_tokens], dim=-1).contiguous(),
            }
        else:
            params = {}
            params["input_ids"] = next_tokens
            for name, value in decoder_outputs.items():
                params[name.replace("present", "past_key_values")] = value

            return params
    
    def _get_next_tokens(self, logits: np.ndarray) -> np.ndarray:
        return torch.unsqueeze(logits.argmax(dim=-1), dim=-1)
    
    def _get_global_index(self, local_index: int, still_running: List[bool]) -> int:
        true_indices = [index for index, value in enumerate(still_running) if value]
        return true_indices[local_index]
    
    def log(self, message: str):
        text = f"{self.name} - {message}"
        self.logger.log_info(text)

    def log_contiguity(self, running_params: Dict[str, np.ndarray]):
        message = "Contiguity Check\n\n"
        for name, value in running_params.items():
            message += f"{name} - Contiguous: {value.is_contiguous()}    Shape: {value.shape}    Stride: {value.stride()}\n"
        self.log(message)


    @classmethod
    def from_pretrained(cls, tokenizer_dir, encoder_name, decoder_no_past_name, decoder_with_past_name):
        return cls(tokenizer_dir, encoder_name, decoder_no_past_name, decoder_with_past_name, logger=pb_utils.Logger)

Run the inference with the model. Below are the relevant config.pbtxt files.

name: "bls"
backend: "python"
max_batch_size: 1
input [
    {
        name: "text"
        data_type: TYPE_STRING
        dims: [ 1 ]
    }
]
output [
	{
		name: "translated_text"
		data_type: TYPE_STRING
		dims: [ 1 ]
	}
]
instance_group {
    count: 1
    kind: KIND_CPU
}
dynamic_batching { }
name: "encoder"
backend: "onnxruntime"
max_batch_size: 16
input [
    {
        name: "input_ids"
        data_type: TYPE_INT64
        dims: [ -1 ]
    },
    {
        name: "attention_mask"
        data_type: TYPE_INT64
        dims: [ -1 ]
    }
]
output [
    {
        name: "last_hidden_state"
        data_type: TYPE_FP32
        dims: [ -1, 512 ]
    }
]
instance_group {
    count: 1
    kind: KIND_GPU
}
dynamic_batching { }
parameters {
  key: "ONNXRUNTIME_LOG_LEVEL"
  value: { string_value: "WARNING" }
}
name: "decoder_no_past"
backend: "onnxruntime"
max_batch_size: 1
input [
	{
		name: "encoder_attention_mask"
		data_type: TYPE_INT64
		dims: [-1]
	},
	{
		name: "encoder_hidden_states"
		data_type: TYPE_FP32
		dims: [-1, 512]
	},
	{
		name: "input_ids"
		data_type: TYPE_INT64
		dims: [-1]
	}
]
output [
	{
		name: "logits"
		data_type: TYPE_FP32
		dims: [-1, 250112]
	},
	{
		name: "present.0.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.0.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.0.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.0.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.1.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.1.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.1.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.1.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.2.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.2.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.2.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.2.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.3.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.3.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.3.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.3.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.4.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.4.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.4.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.4.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.5.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.5.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.5.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.5.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.6.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.6.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.6.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.6.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.7.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.7.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.7.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.7.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	}
]
instance_group {
	kind: KIND_GPU
	count: 1
}
parameters {
  key: "ONNXRUNTIME_LOG_LEVEL"
  value: { string_value: "WARNING" }
}
name: "decoder_with_past"
backend: "onnxruntime"
max_batch_size: 1
input [
	{
		name: "encoder_attention_mask"
		data_type: TYPE_INT64
		dims: [ -1 ]
	},
	{
		name: "encoder_hidden_states"
		data_type: TYPE_FP32
		dims: [-1, 512]
	},
	{
		name: "input_ids"
		data_type: TYPE_INT64
		dims: [1]
	},
	{
		name: "past_key_values.0.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.0.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.0.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.0.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.1.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.1.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.1.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.1.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.2.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.2.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.2.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.2.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.3.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.3.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.3.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.3.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.4.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.4.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.4.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.4.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.5.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.5.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.5.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.5.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.6.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.6.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.6.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.6.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.7.encoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.7.encoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.7.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "past_key_values.7.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	}
]
output [
	{
		name: "logits"
		data_type: TYPE_FP32
		dims: [-1, 250112]
	},
	{
		name: "present.0.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.0.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.1.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.1.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.2.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.2.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.3.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.3.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.4.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.4.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.5.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.5.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.6.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.6.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.7.decoder.key"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	},
	{
		name: "present.7.decoder.value"
		data_type: TYPE_FP32
		dims: [6, -1, 64]
	}
]
instance_group {
	kind: KIND_GPU
	count: 1
}
parameters {
  key: "ONNXRUNTIME_LOG_LEVEL"
  value: { string_value: "WARNING" }
}
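
For completeness, a hedged sketch of the kind of client request used to run the inference against the bls model (the localhost URL and prompt are assumptions; a SageMaker multi-model endpoint would be invoked through the SageMaker runtime instead, but the tensor shapes match the bls config above):

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# "text" is TYPE_STRING with dims [1] and max_batch_size 1, so shape is [1, 1].
text = np.array([["Some text to translate."]], dtype=object)
text_input = httpclient.InferInput("text", text.shape, "BYTES")
text_input.set_data_from_numpy(text)

result = client.infer(
    model_name="bls",
    inputs=[text_input],
    outputs=[httpclient.InferRequestedOutput("translated_text")],
)
print(result.as_numpy("translated_text"))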

Expected behavior
I expected the pipeline for a generation request to take approximately the same amount of time as the notebook baseline, rather than five times longer.


geraldstanje commented Jun 28, 2024

Hi @JamesBowerXanda, there is also a newer image available: 007439368137.dkr.ecr.us-east-2.amazonaws.com/sagemaker-tritonserver:24.03-py3 (see https://github.com/aws/deep-learning-containers/blob/master/available_images.md#nvidia-triton-inference-containers-sm-support-only). Did you try it?
