Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Unity] [Bugfix] Fix KeyError: 'padding' in _avg_pool2d implementatio…
…n (#15734) This PR fixes a bug in the avg_pool2d implementation. The bug causes a KeyError when running the provided code snippet. The error message is as below: ``` Traceback (most recent call last): ... mod = from_fx(fx_model, input_info) File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1449, in from_fx return TorchFXImporter().from_fx( File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1336, in from_fx self.env[node] = self.convert_map[func_name](node) File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 763, in _avg_pool2d stride = node.args[2] if nargs > 2 else node.kwargs["stride"] KeyError: 'stride' ``` And here is the code to reproduce this bug. ```python import torch from torch import fx from torch.nn import Module import tvm from tvm import relax from tvm.relax.frontend.torch import from_fx input_data = torch.randn([1, 1, 3, 3], dtype=torch.float32) para_1 = (2, 1) class avg_pool2d(Module): def forward(self, input): return torch.nn.functional.avg_pool2d(input, para_1,divisor_override=2,) model = avg_pool2d().float() input_data = [input_data] input_info = list(zip([list(inp.shape) for inp in input_data], [str(inp.dtype) for inp in input_data])) fx_model : torch.fx.GraphModule = fx.symbolic_trace(model) with torch.no_grad(): mod = from_fx(fx_model, input_info) ``` The issue arises due to the lack of a check for the existence of the "stride" key in the code. To resolve this bug, I have modified the code to include a check for the existence of the "stride" key before accessing it. I have tested these changes by running the provided code snippet, and the KeyError is no longer thrown. The fix ensures that the code gracefully handles cases where the "stride" key is missing.
- Loading branch information