Commit

ci: Fix shape and reformat free tensor handling in the input byte size check (#7444)
pskiran1 authored Jul 27, 2024
1 parent aca16ba commit 334f81f
Showing 11 changed files with 312 additions and 2 deletions.
34 changes: 34 additions & 0 deletions docs/user_guide/model_configuration.md
@@ -598,6 +598,40 @@ input1: [4, 4, 6] <== shape of this tensor [3]
Currently, only TensorRT supports shape tensors. Read [Shape Tensor I/O](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#shape_tensor_io)
to learn more about shape tensors.

## Non-Linear I/O Formats

For models that process input or output data in non-linear formats, the _is_non_linear_format_io_ property
must be set. The following example model configuration shows how to specify that INPUT0 and INPUT1 use non-linear I/O data formats.

```
name: "mytensorrtmodel"
platform: "tensorrt_plan"
max_batch_size: 8
input [
  {
    name: "INPUT0"
    data_type: TYPE_FP16
    dims: [ 3,224,224 ]
    is_non_linear_format_io: true
  },
  {
    name: "INPUT1"
    data_type: TYPE_FP16
    dims: [ 3,224,224 ]
    is_non_linear_format_io: true
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP16
    dims: [ 1,3 ]
  }
]
```

Currently, only TensorRT supports this property. To learn more about I/O formats, refer to the [I/O Formats documentation](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#reformat-free-network-tensors).
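
For a concrete sense of why byte sizes differ, consider a vectorized format such as CHW2, which packs channels in pairs. The sketch below is illustrative only: `div_up` and `chw2_byte_size` are hypothetical helpers, and it assumes the channel dimension is simply padded up to a multiple of the vector width. For a `[13, 2, 1]` FP16 tensor, the padded buffer is 56 bytes while the linear layout is only 52 bytes.

```
import numpy as np

def div_up(a, b):
    # Round up: how many b-sized groups are needed to hold a elements.
    return (a + b - 1) // b

def chw2_byte_size(shape, dtype=np.float16, vector_width=2):
    # Hypothetical helper: CHW2 pads the channel dimension to a multiple of
    # the vector width, so the buffer the server expects is larger than the
    # linear (unpadded) byte size computed from the tensor shape.
    c, h, w = shape
    padded_c = div_up(c, vector_width) * vector_width
    return padded_c * h * w * np.dtype(dtype).itemsize

linear_bytes = 13 * 2 * 1 * np.dtype(np.float16).itemsize  # 52
padded_bytes = chw2_byte_size((13, 2, 1))                   # 56
print(linear_bytes, padded_bytes)
```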

## Version Policy

Each model can have one or more
72 changes: 72 additions & 0 deletions qa/L0_input_validation/input_validation_test.py
@@ -34,6 +34,7 @@
import infer_util as iu
import numpy as np
import tritonclient.grpc as tritongrpcclient
import tritonclient.utils.shared_memory as shm
from tritonclient.utils import InferenceServerException, np_to_triton_dtype


@@ -211,6 +212,77 @@ def get_input_array(input_size, np_dtype):
err_str,
)

    def test_wrong_input_shape_tensor_size(self):
        def inference_helper(model_name, batch_size=1):
            triton_client = tritongrpcclient.InferenceServerClient("localhost:8001")
            if batch_size > 1:
                dummy_input_data = np.random.rand(batch_size, 32, 32).astype(np.float32)
            else:
                dummy_input_data = np.random.rand(32, 32).astype(np.float32)
            shape_tensor_data = np.asarray([4, 4], dtype=np.int32)

            # Pass an incorrect input byte size for the shape tensor.
            # Use shared memory to bypass the shape check in the client library.
            input_byte_size = (shape_tensor_data.size - 1) * np.dtype(np.int32).itemsize

            input_shm_handle = shm.create_shared_memory_region(
                "INPUT0_SHM",
                "/INPUT0_SHM",
                input_byte_size,
            )
            shm.set_shared_memory_region(
                input_shm_handle,
                [
                    shape_tensor_data,
                ],
            )
            triton_client.register_system_shared_memory(
                "INPUT0_SHM",
                "/INPUT0_SHM",
                input_byte_size,
            )

            inputs = [
                tritongrpcclient.InferInput(
                    "DUMMY_INPUT0",
                    dummy_input_data.shape,
                    np_to_triton_dtype(np.float32),
                ),
                tritongrpcclient.InferInput(
                    "INPUT0",
                    shape_tensor_data.shape,
                    np_to_triton_dtype(np.int32),
                ),
            ]
            inputs[0].set_data_from_numpy(dummy_input_data)
            inputs[1].set_shared_memory("INPUT0_SHM", input_byte_size)

            outputs = [
                tritongrpcclient.InferRequestedOutput("DUMMY_OUTPUT0"),
                tritongrpcclient.InferRequestedOutput("OUTPUT0"),
            ]

            try:
                # Perform inference
                with self.assertRaises(InferenceServerException) as e:
                    triton_client.infer(
                        model_name=model_name, inputs=inputs, outputs=outputs
                    )
                err_str = str(e.exception)
                correct_input_byte_size = (
                    shape_tensor_data.size * np.dtype(np.int32).itemsize
                )
                self.assertIn(
                    f"input byte size mismatch for input 'INPUT0' for model '{model_name}'. Expected {correct_input_byte_size}, got {input_byte_size}",
                    err_str,
                )
            finally:
                shm.destroy_shared_memory_region(input_shm_handle)
                triton_client.unregister_system_shared_memory("INPUT0_SHM")

        inference_helper(model_name="plan_nobatch_zero_1_float32_int32")
        inference_helper(model_name="plan_zero_1_float32_int32", batch_size=8)


if __name__ == "__main__":
    unittest.main()
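
For contrast with the failure path exercised above, here is a minimal sketch of a correctly sized registration for the same shape tensor (the region name `INPUT0_SHM_OK` is made up for illustration; the endpoint and tensor names match the test):

```
import numpy as np
import tritonclient.grpc as tritongrpcclient
import tritonclient.utils.shared_memory as shm
from tritonclient.utils import np_to_triton_dtype

client = tritongrpcclient.InferenceServerClient("localhost:8001")
shape_tensor_data = np.asarray([4, 4], dtype=np.int32)

# Size the region from the array itself so it matches what the server expects.
byte_size = shape_tensor_data.nbytes  # 2 elements * 4 bytes = 8

handle = shm.create_shared_memory_region("INPUT0_SHM_OK", "/INPUT0_SHM_OK", byte_size)
shm.set_shared_memory_region(handle, [shape_tensor_data])
client.register_system_shared_memory("INPUT0_SHM_OK", "/INPUT0_SHM_OK", byte_size)

shape_input = tritongrpcclient.InferInput(
    "INPUT0", shape_tensor_data.shape, np_to_triton_dtype(np.int32)
)
shape_input.set_shared_memory("INPUT0_SHM_OK", byte_size)

# Clean up when done, as the test does in its finally block.
shm.destroy_shared_memory_region(handle)
client.unregister_system_shared_memory("INPUT0_SHM_OK")
```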
2 changes: 2 additions & 0 deletions qa/L0_input_validation/test.sh
@@ -123,6 +123,8 @@ dynamic_batching {
EOL

cp -r $DATADIR/qa_model_repository/graphdef_object_int32_int32 models/.
cp -r $DATADIR/qa_shapetensor_model_repository/plan_nobatch_zero_1_float32_int32 models/.
cp -r $DATADIR/qa_shapetensor_model_repository/plan_zero_1_float32_int32 models/.

SERVER_ARGS="--model-repository=`pwd`/models"
run_server
@@ -0,0 +1,26 @@
max_batch_size: 8
input [
  {
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ 16 ]
    is_non_linear_format_io: true
  },
  {
    name: "INPUT1"
    data_type: TYPE_FP32
    dims: [ 16 ]
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ 16 ]
  },
  {
    name: "OUTPUT1"
    data_type: TYPE_FP32
    dims: [ 16 ]
  }
]
@@ -0,0 +1 @@
'INPUT0' uses a linear IO format, but 'is_non_linear_format_io' is incorrectly set to true in the model configuration.
@@ -0,0 +1,26 @@
max_batch_size: 8
input [
  {
    name: "INPUT0"
    data_type: TYPE_FP32
    dims: [ 16 ]
  },
  {
    name: "INPUT1"
    data_type: TYPE_FP32
    dims: [ 16 ]
  }
]
output [
  {
    name: "OUTPUT0"
    data_type: TYPE_FP32
    dims: [ 16 ]
  },
  {
    name: "OUTPUT1"
    data_type: TYPE_FP32
    dims: [ 16 ]
    is_non_linear_format_io: true
  }
]
@@ -0,0 +1 @@
'OUTPUT1' uses a linear IO format, but 'is_non_linear_format_io' is incorrectly set to true in the model configuration.
@@ -0,0 +1,57 @@
name: "no_config_non_linear_format_io"
platform: "tensorrt_plan"
backend: "tensorrt"
version_policy {
  latest {
    num_versions: 1
  }
}
max_batch_size: 8
input {
  name: "INPUT0"
  data_type: TYPE_FP32
  dims: -1
  dims: 2
  dims: 1
  is_non_linear_format_io: true
}
input {
  name: "INPUT1"
  data_type: TYPE_FP32
  dims: -1
  dims: 2
  dims: 1
  is_non_linear_format_io: true
}
output {
  name: "OUTPUT0"
  data_type: TYPE_FP32
  dims: -1
  dims: 2
  dims: 1
}
output {
  name: "OUTPUT1"
  data_type: TYPE_FP32
  dims: -1
  dims: 2
  dims: 1
}
optimization {
  input_pinned_memory {
    enable: true
  }
  output_pinned_memory {
    enable: true
  }
}
dynamic_batching {
  preferred_batch_size: 8
}
instance_group {
  name: "no_config_non_linear_format_io"
  kind: KIND_GPU
  count: 1
  gpus: 0
}
default_model_filename: "model.plan"
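
One hedged way to confirm what the server actually autocompleted for this model is to read the configuration back over gRPC (a sketch assuming the model is loaded and the server is listening on the default localhost:8001):

```
import tritonclient.grpc as tritongrpcclient

client = tritongrpcclient.InferenceServerClient("localhost:8001")
config = client.get_model_config("no_config_non_linear_format_io").config

# Each autocompleted input/output carries the is_non_linear_format_io flag.
for io in list(config.input) + list(config.output):
    print(io.name, io.is_non_linear_format_io)
```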
13 changes: 12 additions & 1 deletion qa/L0_model_config/test.sh
@@ -56,10 +56,12 @@ for modelpath in \
autofill_noplatform/tensorrt/bad_input_shape/1 \
autofill_noplatform/tensorrt/bad_input_type/1 \
autofill_noplatform/tensorrt/bad_input_shape_tensor/1 \
autofill_noplatform/tensorrt/bad_input_non_linear_format_io/1 \
autofill_noplatform/tensorrt/bad_output_dims/1 \
autofill_noplatform/tensorrt/bad_output_shape/1 \
autofill_noplatform/tensorrt/bad_output_type/1 \
autofill_noplatform/tensorrt/bad_output_shape_tensor/1 \
autofill_noplatform/tensorrt/bad_output_non_linear_format_io/1 \
autofill_noplatform/tensorrt/too_few_inputs/1 \
autofill_noplatform/tensorrt/too_many_inputs/1 \
autofill_noplatform/tensorrt/unknown_input/1 \
@@ -92,6 +94,14 @@ for modelpath in \
$modelpath/.
done

# Copy TensorRT plans with non-linear format IO into the test model repositories.
for modelpath in \
autofill_noplatform_success/tensorrt/no_config_non_linear_format_io/1 ; do
mkdir -p $modelpath
cp /data/inferenceserver/${REPO_VERSION}/qa_trt_format_model_repository/plan_CHW32_LINEAR_float32_float32_float32/1/model.plan \
$modelpath/.
done

# Copy variable-sized TensorRT plans into the test model repositories.
for modelpath in \
autofill_noplatform_success/tensorrt/no_name_platform_variable/1 \
@@ -593,7 +603,8 @@ for TARGET_DIR in `ls -d autofill_noplatform_success/*/*`; do
# that the directory is an entire model repository.
rm -fr models && mkdir models
if [ -f ${TARGET_DIR}/config.pbtxt ] || [ "$TARGET" = "no_config" ] \
|| [ "$TARGET" = "no_config_variable" ] || [ "$TARGET" = "no_config_shape_tensor" ] ; then
|| [ "$TARGET" = "no_config_variable" ] || [ "$TARGET" = "no_config_shape_tensor" ] \
|| [ "$TARGET" = "no_config_non_linear_format_io" ] ; then
cp -r ${TARGET_DIR} models/.
else
cp -r ${TARGET_DIR}/* models/.
2 changes: 1 addition & 1 deletion qa/L0_trt_reformat_free/test.sh
@@ -75,7 +75,7 @@ if [ $? -ne 0 ]; then
cat $CLIENT_LOG
RET=1
else
check_test_results $TEST_RESULT_FILE 4
check_test_results $TEST_RESULT_FILE 6
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Result Verification Failed\n***"
80 changes: 80 additions & 0 deletions qa/L0_trt_reformat_free/trt_reformat_free_test.py
@@ -37,6 +37,7 @@
import test_util as tu
import tritonclient.http as tritonhttpclient
import tritonclient.utils.shared_memory as shm
from tritonclient.utils import InferenceServerException


def div_up(a, b):
@@ -141,6 +142,41 @@ def test_nobatch_chw2_input(self):
"OUTPUT0 expected: {}, got {}".format(expected_output1_np, output1_np),
)

    def test_wrong_nobatch_chw2_input(self):
        model_name = "plan_nobatch_CHW2_LINEAR_float16_float16_float16"
        input_np = np.arange(26, dtype=np.float16).reshape((13, 2, 1))

        # Use shared memory to bypass the shape check in the client library, because
        # for a non-linear format tensor the data buffer is padded and thus the
        # data byte size may not match what is calculated from the tensor shape
        inputs = []
        inputs.append(tritonhttpclient.InferInput("INPUT0", [13, 2, 1], "FP16"))
        # Send the original size input instead of the reformatted size input.
        self.add_reformat_free_data_as_shared_memory("input0", inputs[-1], input_np)

        inputs.append(tritonhttpclient.InferInput("INPUT1", [13, 2, 1], "FP16"))
        # Send the original size input instead of the reformatted size input.
        self.add_reformat_free_data_as_shared_memory("input1", inputs[-1], input_np)

        outputs = []
        outputs.append(
            tritonhttpclient.InferRequestedOutput("OUTPUT0", binary_data=True)
        )
        outputs.append(
            tritonhttpclient.InferRequestedOutput("OUTPUT1", binary_data=True)
        )

        with self.assertRaises(InferenceServerException) as e:
            self.triton_client.infer(
                model_name=model_name, inputs=inputs, outputs=outputs
            )

        err_str = str(e.exception)
        self.assertIn(
            "input byte size mismatch for input 'INPUT0' for model 'plan_nobatch_CHW2_LINEAR_float16_float16_float16'. Expected 56, got 52",
            err_str,
        )

    def test_chw2_input(self):
        model_name = "plan_CHW2_LINEAR_float16_float16_float16"
        for bs in [1, 8]:
@@ -186,6 +222,50 @@ def test_chw2_input(self):
"OUTPUT0 expected: {}, got {}".format(expected_output1_np, output1_np),
)

    def test_wrong_chw2_input(self):
        model_name = "plan_CHW2_LINEAR_float16_float16_float16"
        for bs in [1, 8]:
            input_np = np.arange(26 * bs, dtype=np.float16).reshape((bs, 13, 2, 1))

            # Use shared memory to bypass the shape check in the client library,
            # because for a non-linear format tensor the data buffer is padded
            # and thus the data byte size may not match what is calculated from
            # the tensor shape
            inputs = []
            inputs.append(tritonhttpclient.InferInput("INPUT0", [bs, 13, 2, 1], "FP16"))
            # Send the original size input instead of the reformatted size input.
            self.add_reformat_free_data_as_shared_memory(
                "input0" + str(bs), inputs[-1], input_np
            )

            inputs.append(tritonhttpclient.InferInput("INPUT1", [bs, 13, 2, 1], "FP16"))
            # Send the original size input instead of the reformatted size input.
            self.add_reformat_free_data_as_shared_memory(
                "input1" + str(bs), inputs[-1], input_np
            )

            outputs = []
            outputs.append(
                tritonhttpclient.InferRequestedOutput("OUTPUT0", binary_data=True)
            )
            outputs.append(
                tritonhttpclient.InferRequestedOutput("OUTPUT1", binary_data=True)
            )

            with self.assertRaises(InferenceServerException) as e:
                self.triton_client.infer(
                    model_name=model_name, inputs=inputs, outputs=outputs
                )
            err_str = str(e.exception)
            # reformatted input size - (bs, 14, 2, 1) * size(float16)
            expected_size = bs * 28 * 2
            # original input size - (bs, 13, 2, 1) * size(float16)
            received_size = bs * 26 * 2
            self.assertIn(
                f"input byte size mismatch for input 'INPUT0' for model 'plan_CHW2_LINEAR_float16_float16_float16'. Expected {expected_size}, got {received_size}",
                err_str,
            )

    def test_nobatch_chw32_input(self):
        model_name = "plan_nobatch_CHW32_LINEAR_float32_float32_float32"
        input_np = np.arange(26, dtype=np.float32).reshape((13, 2, 1))
