diff --git a/src/brevitas/quant_tensor/int_quant_tensor.py b/src/brevitas/quant_tensor/int_quant_tensor.py index 7391a9ba8..25685e7c3 100644 --- a/src/brevitas/quant_tensor/int_quant_tensor.py +++ b/src/brevitas/quant_tensor/int_quant_tensor.py @@ -235,6 +235,9 @@ def dim(self): def add(self, other): return self + other + def unsqueeze(self, *args, **kwargs): + return self.value.unsqueeze(*args, **kwargs) + @staticmethod def cat(tensors, dim, out=None): if out is not None: