From 35c93de62918c3163fe9356a0942d95251dbead2 Mon Sep 17 00:00:00 2001 From: pranavm Date: Wed, 18 Dec 2024 12:38:31 -0800 Subject: [PATCH] Fixes a flaky test by increasing tolerance --- tripy/tests/integration/test_conv_transpose.py | 9 ++++----- tripy/tests/integration/test_sequential.py | 3 +-- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tripy/tests/integration/test_conv_transpose.py b/tripy/tests/integration/test_conv_transpose.py index 816fe6b7f..40f34a906 100644 --- a/tripy/tests/integration/test_conv_transpose.py +++ b/tripy/tests/integration/test_conv_transpose.py @@ -280,14 +280,13 @@ def test_transposed_equivalency(self, torch_dtype, tp_dtype, eager_or_compiled): output = eager_or_compiled(conv_layer, input) output_transpose = eager_or_compiled(conv_transpose_layer, input) - rtol = 2e-7 if tp_dtype == tp.float32 else 9e-4 - assert tp.allclose(output, tp.Tensor(expected), rtol=rtol, atol=1e-5) + assert tp.allclose(output, tp.Tensor(expected), rtol=1e-2, atol=1e-4) assert output.shape == list(expected.shape) - assert tp.allclose(output_transpose, tp.Tensor(expected_transpose), rtol=rtol, atol=1e-5) + assert tp.allclose(output_transpose, tp.Tensor(expected_transpose), rtol=1e-2, atol=1e-4) assert output_transpose.shape == list(expected_transpose.shape) - assert tp.allclose(output, output_transpose, rtol=rtol, atol=1e-5) + assert tp.allclose(output, output_transpose, rtol=1e-2, atol=1e-4) assert output.shape == output_transpose.shape - assert tp.allclose(tp.Tensor(expected), tp.Tensor(expected_transpose), rtol=rtol, atol=1e-5) + assert tp.allclose(tp.Tensor(expected), tp.Tensor(expected_transpose), rtol=1e-2, atol=1e-4) assert list(expected.shape) == list(expected_transpose.shape) @pytest.mark.parametrize("test_case", test_cases_transpose_downscale) diff --git a/tripy/tests/integration/test_sequential.py b/tripy/tests/integration/test_sequential.py index 812951e13..b6f5e9c71 100644 --- a/tripy/tests/integration/test_sequential.py +++ b/tripy/tests/integration/test_sequential.py @@ -102,8 +102,7 @@ def test_nested_forward_pass_accuracy(self, eager_or_compiled): with torch.no_grad(): torch_output = torch_model(input_tensor) - rtol_ = 2e-6 - assert torch.allclose(torch.from_dlpack(tp_output), torch_output, rtol=rtol_) + assert torch.allclose(torch.from_dlpack(tp_output), torch_output, rtol=1e-4, atol=1e-4) def test_basic_state_dict_comparison(self): torch_model = torch.nn.Sequential(