diff --git a/src/brevitas/quant/solver/act.py b/src/brevitas/quant/solver/act.py index 345239089..35f771c54 100644 --- a/src/brevitas/quant/solver/act.py +++ b/src/brevitas/quant/solver/act.py @@ -1,6 +1,8 @@ # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +from warnings import warn + import torch from torch import nn from torch import Tensor @@ -111,6 +113,33 @@ def scaling_shape(scaling_per_output): elif scaling_per_output == ScalingPerOutputType.TENSOR: return SCALAR_SHAPE + @value + def group_dim(module=None, group_size=None): + # Avoid circular import + from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer + + if group_size is not None and module is not None: + if isinstance(module, QuantWeightBiasInputOutputLayer): + if isinstance(module, nn.Linear): + return -1 + elif isinstance(module, + (nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d)): + warn( + "Group dim is being selected assuming batched input. Using unbatched input will fail and requires manually specification of group_dim" + ) + # We are assuming batched input + return 1 + else: + raise RuntimeError("Cannot determine automatically group_dim. Please specify") + else: + raise RuntimeError( + f"Cannot determine automatically group_dim for {type(module)}. Please specify") + class SolveActScalingPerOutputChannelShape(ExtendedInjector):