Skip to content

Commit

Permalink
[Unity] [Bugfix] Fix KeyError: 'padding' in _avg_pool2d implementatio…
Browse files Browse the repository at this point in the history
…n (#15734)

This PR fixes a bug in the avg_pool2d implementation. The bug causes a KeyError when running the provided code snippet. The error message is as below:
```
Traceback (most recent call last):
    ...
    mod = from_fx(fx_model, input_info)
  File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1449, in from_fx
    return TorchFXImporter().from_fx(
  File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 1336, in from_fx
    self.env[node] = self.convert_map[func_name](node)
  File "/workplace/software/tvm/tvm/python/tvm/relax/frontend/torch/fx_translator.py", line 763, in _avg_pool2d
    stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
KeyError: 'stride'
```
And here is the code to reproduce this bug.
```python
import torch
from torch import fx
from torch.nn import Module
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_fx

input_data = torch.randn([1, 1, 3, 3], dtype=torch.float32)
para_1 = (2, 1)
class avg_pool2d(Module):
    def forward(self, input):
        return torch.nn.functional.avg_pool2d(input, para_1,divisor_override=2,)
model = avg_pool2d().float()
input_data = [input_data]
input_info = list(zip([list(inp.shape) for inp in input_data], [str(inp.dtype) for inp in input_data]))
fx_model : torch.fx.GraphModule = fx.symbolic_trace(model)
with torch.no_grad():
    mod = from_fx(fx_model, input_info)
```
The issue arises due to the lack of a check for the existence of the "stride" key in the code. To resolve this bug, I have modified the code to include a check for the existence of the "stride" key before accessing it. 

I have tested these changes by running the provided code snippet, and the KeyError is no longer thrown. The fix ensures that the code gracefully handles cases where the "stride" key is missing.
  • Loading branch information
Thrsu authored Sep 14, 2023
1 parent 40b9a92 commit bd51b5b
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
21 changes: 18 additions & 3 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,9 +859,24 @@ def _avg_pool2d(self, node: fx.node.Node) -> relax.Var:
else:
nargs = len(node.args)
kernel = node.args[1] if nargs > 1 else node.kwargs["kernel_size"]
stride = node.args[2] if nargs > 2 else node.kwargs["stride"]
padding = node.args[3] if nargs > 3 else node.kwargs["padding"]
ceil_mode = node.args[4] if nargs > 4 else node.kwargs["ceil_mode"]
if nargs > 2:
stride = node.args[2]
elif "stride" in node.kwargs.keys():
stride = node.kwargs["stride"]
else:
stride = None
if nargs > 3:
padding = node.args[3]
elif "padding" in node.kwargs.keys():
padding = node.kwargs["padding"]
else:
padding = 0
if nargs > 4:
ceil_mode = node.args[4]
elif "ceil_mode" in node.kwargs.keys():
ceil_mode = node.kwargs["ceil_mode"]
else:
ceil_mode = False

stride = kernel if stride is None else stride

Expand Down
24 changes: 24 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,9 +823,33 @@ def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
R.output(gv)
return gv

class AvgPool2d4(Module):
def forward(self, input):
return torch.nn.functional.avg_pool2d(input, kernel_size=[2, 1], divisor_override=2)

@tvm.script.ir_module
class expected3:
@R.function
def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")):
with R.dataflow():
lv = R.nn.avg_pool2d(
input_1,
pool_size=[2, 1],
strides=[2, 1],
dilation=[1, 1],
padding=[0, 0, 0, 0],
ceil_mode=False,
layout="NCHW",
out_layout="NCHW",
)
gv = lv
R.output(gv)
return gv

verify_model(AvgPool2d(), input_info, {}, expected1)
verify_model(AvgPool2d2(), input_info, {}, expected2)
verify_model(AvgPool2d3(), input_info, {}, expected2)
verify_model(AvgPool2d4(), input_info, {}, expected3)


def test_adaptive_avgpool2d():
Expand Down

0 comments on commit bd51b5b

Please sign in to comment.