From 6fc2761ae76e74f0d6cf9af948923691913b3ba7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 26 Sep 2024 17:56:59 +0900 Subject: [PATCH] Tighten assert condition in graph break tests (#458) Part of #452. --- test/nn/models/test_compile.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/nn/models/test_compile.py b/test/nn/models/test_compile.py index ec53c0d7f..dc22527cf 100644 --- a/test/nn/models/test_compile.py +++ b/test/nn/models/test_compile.py @@ -14,7 +14,7 @@ from torch_frame.testing import withPackage -@withPackage("torch>=2.1.0") +@withPackage("torch>=2.5.0") @pytest.mark.parametrize( "model_cls, model_kwargs, stypes, expected_graph_breaks", [ @@ -34,7 +34,7 @@ gamma=0.1, ), None, - 7, + 2, id="TabNet", ), pytest.param( @@ -47,21 +47,21 @@ ffn_dropout=0.5, ), None, - 4, + 0, id="TabTransformer", ), pytest.param( Trompt, dict(channels=8, num_prompts=2), None, - 16, + 4, id="Trompt", ), pytest.param( ExcelFormer, dict(in_channels=8, num_cols=3, num_heads=1), [stype.numerical], - 4, + 1, id="ExcelFormer", ), ], @@ -89,4 +89,5 @@ def test_compile_graph_break( **model_kwargs, ) explanation = torch._dynamo.explain(model)(tf) - assert explanation.graph_break_count <= expected_graph_breaks + graph_breaks = explanation.graph_break_count + assert graph_breaks == expected_graph_breaks