diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 3a865b282..0f38b4af1 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -75,7 +75,8 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe # when used as input to tracing (e.g. during ONNX export) if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and len(inp) == len(QuantTensor._fields) and all([isinstance(t, Tensor) for t in inp])): - inp = QuantTensor(*inp) + raise + # inp = QuantTensor(*inp) if not torch._C._get_tracing_state(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None))