Commit
Update core checks
yinggeh committed Jul 23, 2024
1 parent 86a5573 commit b336ecc
Showing 3 changed files with 67 additions and 33 deletions.
54 changes: 36 additions & 18 deletions src/infer_request.cc
@@ -906,6 +906,7 @@ Status
InferenceRequest::Normalize()
{
const inference::ModelConfig& model_config = model_raw_->Config();
const std::string& model_name = ModelName();

// Fill metadata for raw input
if (!raw_input_name_.empty()) {
@@ -918,7 +919,7 @@ InferenceRequest::Normalize()
std::to_string(original_inputs_.size()) +
") to be deduced but got " +
std::to_string(model_config.input_size()) + " inputs in '" +
ModelName() + "' model configuration");
model_name + "' model configuration");
}
auto it = original_inputs_.begin();
if (raw_input_name_ != it->first) {
@@ -1036,7 +1037,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' has no shape but model requires batch dimension for '" +
ModelName() + "'");
model_name + "'");
}

if (batch_size_ == 0) {
@@ -1045,7 +1046,7 @@
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' batch size does not match other inputs for '" + ModelName() +
"' batch size does not match other inputs for '" + model_name +
"'");
}

@@ -1061,7 +1062,7 @@
Status::Code::INVALID_ARG,
LogRequest() + "inference request batch-size must be <= " +
std::to_string(model_config.max_batch_size()) + " for '" +
ModelName() + "'");
model_name + "'");
}

// Verify that each input shape is valid for the model, make
@@ -1070,17 +1071,17 @@
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));

auto& input_id = pr.first;
auto& input_name = pr.first;
auto& input = pr.second;
auto shape = input.MutableShape();

if (input.DType() != input_config->data_type()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference input '" + input_id + "' data-type is '" +
LogRequest() + "inference input '" + input_name + "' data-type is '" +
std::string(
triton::common::DataTypeToProtocolString(input.DType())) +
"', but model '" + ModelName() + "' expects '" +
"', but model '" + model_name + "' expects '" +
std::string(triton::common::DataTypeToProtocolString(
input_config->data_type())) +
"'");
@@ -1100,7 +1101,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() +
"All input dimensions should be specified for input '" +
input_id + "' for model '" + ModelName() + "', got " +
input_name + "' for model '" + model_name + "', got " +
triton::common::DimsListToString(input.OriginalShape()));
} else if (
(config_dims[i] != triton::common::WILDCARD_DIM) &&
@@ -1129,8 +1130,8 @@ InferenceRequest::Normalize()
}
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected shape for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
LogRequest() + "unexpected shape for input '" + input_name +
"' for model '" + model_name + "'. Expected " +
triton::common::DimsListToString(full_dims) + ", got " +
triton::common::DimsListToString(input.OriginalShape()) + ". " +
implicit_batch_note);
Expand Down Expand Up @@ -1192,8 +1193,8 @@ InferenceRequest::Normalize()
// (prepend 4 bytes to specify string length), so need to add all the
// first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(
ValidateBytesInputs(input_id, input, &input_memory_type));
RETURN_IF_ERROR(ValidateBytesInputs(
input_name, input, model_name, &input_memory_type));
// FIXME: Temporarily skips byte size checks for GPU tensors. See
// DLIS-6820.
skip_byte_size_check |=
@@ -1209,7 +1210,7 @@ InferenceRequest::Normalize()
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" +
input_id + "' for model '" + ModelName() + "'. Expected " +
input_name + "' for model '" + model_name + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
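
For context on the TYPE_STRING byte-size reasoning in the comment above: each string element is serialized as a 4-byte length indicator followed by the element's bytes, so the expected byte size is the sum of (4 + element length) over all elements. Below is a minimal client-side packing sketch, not Triton code; the function name is hypothetical and it assumes native-endian uint32 indicators, which matches the reinterpret_cast used by the validation code.

#include <cstdint>
#include <string>
#include <vector>

// Packs string elements into the flat layout the TYPE_STRING checks expect:
// a 4-byte length indicator followed by the element's bytes, once per element.
std::vector<char>
PackStringTensor(const std::vector<std::string>& elements)
{
  std::vector<char> buffer;
  for (const auto& e : elements) {
    const uint32_t len = static_cast<uint32_t>(e.size());
    const char* len_bytes = reinterpret_cast<const char*>(&len);
    buffer.insert(buffer.end(), len_bytes, len_bytes + sizeof(uint32_t));
    buffer.insert(buffer.end(), e.begin(), e.end());
  }
  return buffer;  // total byte size = sum of (4 + element length)
}
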
Expand Down Expand Up @@ -1283,7 +1284,8 @@ InferenceRequest::ValidateRequestInputs()

Status
InferenceRequest::ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& input_name, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const
{
const auto& input_dims = input.ShapeWithBatchDim();
@@ -1322,13 +1324,28 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() +
"element byte size indicator exceeds the end of the buffer.");
"incomplete string length indicator for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(sizeof(uint32_t)) + " bytes but only " +
std::to_string(remaining_buffer_size) +
" bytes available. Please make sure the string length "
"indicator is in one buffer.");
}

// Start the next element and reset the remaining element size.
remaining_element_size = *(reinterpret_cast<const uint32_t*>(buffer));
element_checked++;

// Early stop
if (element_checked > element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected number of string elements " +
std::to_string(element_checked) + " for inference input '" +
input_name + "' for model '" + model_name + "', expecting " +
std::to_string(element_count));
}

// Advance pointer and remainder by the indicator size.
buffer += kElementSizeIndicator;
remaining_buffer_size -= kElementSizeIndicator;
@@ -1354,16 +1371,17 @@ InferenceRequest::ValidateBytesInputs(
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(buffer_count) +
" buffers for inference input '" + input_id + "', got " +
std::to_string(buffer_next_idx));
" buffers for inference input '" + input_name + "' for model '" +
model_name + "', got " + std::to_string(buffer_next_idx));
}

// Validate the number of processed elements exactly match expectations.
if (element_checked != element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(element_count) +
" string elements for inference input '" + input_id + "', got " +
" string elements for inference input '" + input_name +
"' for model '" + model_name + "', got " +
std::to_string(element_checked));
}

1 change: 1 addition & 0 deletions src/infer_request.h
@@ -749,6 +749,7 @@ class InferenceRequest {

Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
const std::string& model_name,
TRITONSERVER_MemoryType* buffer_memory_type) const;

// Helpers for pending request metrics
45 changes: 30 additions & 15 deletions src/test/input_byte_size_test.cc
@@ -258,10 +258,11 @@ char InputByteSizeTest::input_data_string_

TEST_F(InputByteSizeTest, ValidInputByteSize)
{
const char* model_name = "savedmodel_zero_1_float32";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -312,10 +313,11 @@ TEST_F(InputByteSizeTest, ValidInputByteSize)

TEST_F(InputByteSizeTest, InputByteSizeMismatch)
{
const char* model_name = "savedmodel_zero_1_float32";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -353,8 +355,8 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"input byte size mismatch for input 'INPUT0' for model 'simple'. "
"Expected 64, got 68");
"input byte size mismatch for input 'INPUT0' for model '" +
std::string{model_name} + "'. Expected 64, got 68");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -364,10 +366,11 @@ TEST_F(InputByteSizeTest, InputByteSizeMismatch)

TEST_F(InputByteSizeTest, ValidStringInputByteSize)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -424,10 +427,11 @@ TEST_F(InputByteSizeTest, ValidStringInputByteSize)

TEST_F(InputByteSizeTest, StringCountMismatch)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -457,7 +461,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"expected 3 string elements for inference input 'INPUT0', got 2");
"expected 3 string elements for inference input 'INPUT0' for model '" +
std::string{model_name} + "', got 2");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -467,7 +472,8 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
Expand Down Expand Up @@ -495,7 +501,9 @@ TEST_F(InputByteSizeTest, StringCountMismatch)
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace */),
"expect error with inference request",
"expected 1 string elements for inference input 'INPUT0', got 2");
"unexpected number of string elements 2 for inference input 'INPUT0' for "
"model '" +
std::string{model_name} + "', expecting 1");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
@@ -505,10 +513,11 @@ TEST_F(InputByteSizeTest, StringCountMismatch)

TEST_F(InputByteSizeTest, StringSizeMisalign)
{
const char* model_name = "savedmodel_zero_1_object";
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, model_name, -1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
@@ -542,9 +551,13 @@

// Run inference
FAIL_TEST_IF_SUCCESS(
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace
*/), "expect error with inference request",
"element byte size indicator exceeds the end of the buffer");
TRITONSERVER_ServerInferAsync(server_, irequest_, nullptr /* trace*/),
"expect error with inference request",
"incomplete string length indicator for inference input 'INPUT0' for "
"model '" +
std::string{model_name} +
"', expecting 4 bytes but only 2 bytes available. Please make sure "
"the string length indicator is in one buffer.");

// Need to manually delete request, otherwise server will not shut down.
FAIL_TEST_IF_ERR(
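
For the StringSizeMisalign expectation above, a hypothetical buffer that triggers the new "incomplete string length indicator" error can be sketched as one complete element followed by a length indicator cut off after 2 bytes. This is illustrative only and not the exact test fixture; the helper name is made up.

#include <cstdint>
#include <vector>

// Hypothetical illustration: a packed string buffer whose trailing 4-byte
// length indicator is truncated to 2 bytes, which the server now rejects
// with "incomplete string length indicator ... expecting 4 bytes but only
// 2 bytes available".
std::vector<char>
MakeTruncatedStringBuffer()
{
  std::vector<char> buffer;
  const uint32_t len = 5;  // complete element: "hello"
  const char* p = reinterpret_cast<const char*>(&len);
  buffer.insert(buffer.end(), p, p + sizeof(uint32_t));
  buffer.insert(buffer.end(), {'h', 'e', 'l', 'l', 'o'});
  buffer.push_back(0x03);  // only the first 2 bytes of the next
  buffer.push_back(0x00);  // length indicator fit in the buffer
  return buffer;
}
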
@@ -573,7 +586,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
Expand Down Expand Up @@ -629,7 +643,8 @@ TEST_F(InputByteSizeTest, StringCountMismatchGPU)
// Create an inference request
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestNew(
&irequest_, server_, "simple_identity", -1 /* model_version */),
&irequest_, server_, "savedmodel_zero_1_object",
-1 /* model_version */),
"creating inference request");
FAIL_TEST_IF_ERR(
TRITONSERVER_InferenceRequestSetReleaseCallback(
