From bd51b5b592c7dbba4eec21086464b847971a7c86 Mon Sep 17 00:00:00 2001 From: Thrsu <89128704+Thrsu@users.noreply.github.com> Date: Fri, 15 Sep 2023 01:34:16 +0800 Subject: [PATCH] [Unity] [Bugfix] Fix KeyError: 'padding' in _avg_pool2d implementation (#15734) This PR fixes a bug in the avg_pool2d implementation. The bug causes a KeyError when running the provided code snippet. The error message is as below: ``` Traceback (most recent call last): ... mod = from_fx(fx_model, input_info) File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1449, in from_fx return TorchFXImporter().from_fx( File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1336, in from_fx self.env[node] = self.convert_map[func_name](node) File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 763, in _avg_pool2d stride = node.args[2] if nargs > 2 else node.kwargs["stride"] KeyError: 'stride' ``` And here is the code to reproduce this bug. ```python import torch from torch import fx from torch.nn import Module import tvm from tvm import relax from tvm.relax.frontend.torch import from_fx input_data = torch.randn([1, 1, 3, 3], dtype=torch.float32) para_1 = (2, 1) class avg_pool2d(Module): def forward(self, input): return torch.nn.functional.avg_pool2d(input, para_1,divisor_override=2,) model = avg_pool2d().float() input_data = [input_data] input_info = list(zip([list(inp.shape) for inp in input_data], [str(inp.dtype) for inp in input_data])) fx_model : torch.fx.GraphModule = fx.symbolic_trace(model) with torch.no_grad(): mod = from_fx(fx_model, input_info) ``` The issue arises due to the lack of a check for the existence of the "stride" key in the code. To resolve this bug, I have modified the code to include a check for the existence of the "stride" key before accessing it. I have tested these changes by running the provided code snippet, and the KeyError is no longer thrown. The fix ensures that the code gracefully handles cases where the "stride" key is missing. --- .../tvm/relax/frontend/torch/fx_translator.py | 21 +++++++++++++--- tests/python/relax/test_frontend_from_fx.py | 24 +++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index b5cee77d11..a5c2a68cd8 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -859,9 +859,24 @@ def _avg_pool2d(self, node: fx.node.Node) -> relax.Var: else: nargs = len(node.args) kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"] - stride = node.args[2] if nargs > 2 else node.kwargs["stride"] - padding = node.args[3] if nargs > 3 else node.kwargs["padding"] - ceil_mode = node.args[4] if nargs > 4 else node.kwargs["ceil_mode"] + if nargs > 2: + stride = node.args[2] + elif "stride" in node.kwargs.keys(): + stride = node.kwargs["stride"] + else: + stride = None + if nargs > 3: + padding = node.args[3] + elif "padding" in node.kwargs.keys(): + padding = node.kwargs["padding"] + else: + padding = 0 + if nargs > 4: + ceil_mode = node.args[4] + elif "ceil_mode" in node.kwargs.keys(): + ceil_mode = node.kwargs["ceil_mode"] + else: + ceil_mode = False stride = kernel if stride is None else stride diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index ec312767b4..36ef25b025 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -823,9 +823,33 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): R.output(gv) return gv + class AvgPool2d4(Module): + def forward(self, input): + return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2) + + @tvm.script.ir_module + class expected3: + @R.function + def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv = R.nn.avg_pool2d( + input_1, + pool_size=[2, 1], + strides=[2, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + ceil_mode=False, + layout="NCHW", + out_layout="NCHW", + ) + gv = lv + R.output(gv) + return gv + verify_model(AvgPool2d(), input_info, {}, expected1) verify_model(AvgPool2d2(), input_info, {}, expected2) verify_model(AvgPool2d3(), input_info, {}, expected2) + verify_model(AvgPool2d4(), input_info, {}, expected3) def test_adaptive_avgpool2d():