diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index c3d667aee..7bbbb201d 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -59,6 +59,7 @@ def match_node(self, node: Node) -> bool: is_adaptive_2d_mean = ((2, 3) in node.args or [2, 3] in node.args or 'dim' in node.kwargs and (node.kwargs['dim'] == (2, 3) or node.kwargs['dim'] == [2, 3])) + is_adaptive_2d_mean = is_adaptive_2d_mean and not node.kwargs['keepdim'] return spr and is_adaptive_2d_mean def move_node_args_to_kwargs(self, node: Node):