From 6f0d2f96fd73eeb7e3f0457a3a2b0cdd63027938 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 3 Jul 2024 11:05:14 +0100 Subject: [PATCH] Fix tests --- tests/brevitas/graph/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/brevitas/graph/test_transforms.py b/tests/brevitas/graph/test_transforms.py index c58d9d828..875d5a52c 100644 --- a/tests/brevitas/graph/test_transforms.py +++ b/tests/brevitas/graph/test_transforms.py @@ -287,6 +287,6 @@ def forward(self, x): model = TestModel() assert model.conv.stride == (1, 1) - kwargs = {'stride': lambda module: 2 if module.in_channels == 3 else 1} + kwargs = {'stride': lambda module, name: 2 if module.in_channels == 3 else 1} model = ModuleToModuleByInstance(model.conv, nn.Conv2d, **kwargs).apply(model) assert model.conv.stride == (2, 2)