Skip to content

Commit

Permalink
fix test
Browse files — browse the repository at this point in the history
  • Loading branch information
hjjq committed Sep 6, 2023
1 parent 6fbae2c commit e7f1600
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/operators/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,26 @@ def torch_conv2d(
[1, 32, 32, 32, 64, 1, 1], # kernel 1,
],
)
@pytest.mark.parametrize("padding", [[0, 0, 0], [1, 1, 0]])
@pytest.mark.parametrize("groups", [1, 2, 4])
@pytest.mark.parametrize("stride", [[1, 1], [2, 3]])
@pytest.mark.parametrize("dilations", [[1, 1], [2, 3]])
@pytest.mark.parametrize("parallel_k", [1, 2, 3])
@pytest.mark.parametrize(
"device", ["cuda"]
) # we don't test for cpu because its quite imprecise in fp16 for larger kernel sizes
def test_conv2d_gemm_fp16(n, c, h, w, oc, kx, ky, groups, stride, dilations, parallel_k, device):
def test_conv2d_gemm_fp16(n, c, h, w, oc, kx, ky, padding, groups, stride, dilations, parallel_k, device):
tol = 0.8
padh, padw, padc = padding
check_binary(
a_shape=[n, c, h, w],
b_shape=[oc, c // groups, kx, ky],
numpy_op=lambda data, weight: torch_conv2d(data, weight, [0, 0], stride, dilations, groups),
numpy_op=lambda data, weight: torch_conv2d(data, weight, [padh, padw], stride, dilations, groups),
hidet_op=lambda data, weight: ops.transpose(
ops.conv2d_gemm_fp16_channel_last(
ops.transpose(data, [0, 2, 3, 1]),
weight,
padding=padding,
stride=stride,
dilations=dilations,
groups=groups,
Expand Down

0 comments on commit e7f1600

Please sign in to comment.