From e7f1600e8b5a64da714da3572cacaf32252c8342 Mon Sep 17 00:00:00 2001
From: hjjq
Date: Wed, 6 Sep 2023 13:44:52 -0400
Subject: [PATCH] fix test

---
 tests/operators/test_conv2d.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/tests/operators/test_conv2d.py b/tests/operators/test_conv2d.py
index b7a88f8bf..fa0ec1897 100644
--- a/tests/operators/test_conv2d.py
+++ b/tests/operators/test_conv2d.py
@@ -52,6 +52,7 @@ def torch_conv2d(
         [1, 32, 32, 32, 64, 1, 1],  # kernel 1,
     ],
 )
+@pytest.mark.parametrize("padding", [[0, 0, 0], [1, 1, 0]])
 @pytest.mark.parametrize("groups", [1, 2, 4])
 @pytest.mark.parametrize("stride", [[1, 1], [2, 3]])
 @pytest.mark.parametrize("dilations", [[1, 1], [2, 3]])
@@ -59,16 +60,18 @@ def torch_conv2d(
 @pytest.mark.parametrize(
     "device", ["cuda"]
 )  # we don't test for cpu because its quite imprecise in fp16 for larger kernel sizes
-def test_conv2d_gemm_fp16(n, c, h, w, oc, kx, ky, groups, stride, dilations, parallel_k, device):
+def test_conv2d_gemm_fp16(n, c, h, w, oc, kx, ky, padding, groups, stride, dilations, parallel_k, device):
     tol = 0.8
+    padh, padw, padc = padding
     check_binary(
         a_shape=[n, c, h, w],
         b_shape=[oc, c // groups, kx, ky],
-        numpy_op=lambda data, weight: torch_conv2d(data, weight, [0, 0], stride, dilations, groups),
+        numpy_op=lambda data, weight: torch_conv2d(data, weight, [padh, padw], stride, dilations, groups),
         hidet_op=lambda data, weight: ops.transpose(
             ops.conv2d_gemm_fp16_channel_last(
                 ops.transpose(data, [0, 2, 3, 1]),
                 weight,
+                padding=padding,
                 stride=stride,
                 dilations=dilations,
                 groups=groups,
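
Note (not part of the patch): a minimal sketch of what the new three-element padding triple appears to mean for the channel-last op, based only on how the test unpacks it as padh, padw, padc. The pad_nhwc helper below is hypothetical and purely illustrative; both parametrized cases keep padc at 0, so the PyTorch reference, which pads only H and W via [padh, padw], stays comparable.

    import numpy as np

    # Hypothetical helper, not part of the patch: illustrates the assumed meaning of
    # the 3-element padding triple [padh, padw, padc] for an NHWC (channel-last) input.
    def pad_nhwc(data, padding):
        padh, padw, padc = padding
        # Pad height, width, and channel dims symmetrically; batch dim is untouched.
        return np.pad(data, ((0, 0), (padh, padh), (padw, padw), (padc, padc)))

    x = np.zeros((1, 32, 32, 64), dtype=np.float16)
    print(pad_nhwc(x, [1, 1, 0]).shape)  # (1, 34, 34, 64): H and W padded, channels unchanged

With padc fixed at 0, the spatial padding seen by the Hidet op and by the torch_conv2d reference is the same, which is what lets the test compare the two outputs directly.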