diff --git a/src/brevitas/graph/standardize.py b/src/brevitas/graph/standardize.py index 7bbbb201d..93e99eac9 100644 --- a/src/brevitas/graph/standardize.py +++ b/src/brevitas/graph/standardize.py @@ -59,7 +59,7 @@ def match_node(self, node: Node) -> bool: is_adaptive_2d_mean = ((2, 3) in node.args or [2, 3] in node.args or 'dim' in node.kwargs and (node.kwargs['dim'] == (2, 3) or node.kwargs['dim'] == [2, 3])) - is_adaptive_2d_mean = is_adaptive_2d_mean and not node.kwargs['keepdim'] + is_adaptive_2d_mean = is_adaptive_2d_mean and not node.kwargs.get('keepdim', False) return spr and is_adaptive_2d_mean def move_node_args_to_kwargs(self, node: Node):