
Commit

Fix bunch of tests
Giuseppe5 committed Mar 25, 2024
1 parent b216d14 commit 4d04a99
Showing 4 changed files with 26 additions and 9 deletions.
14 changes: 8 additions & 6 deletions src/brevitas/nn/mixin/base.py
@@ -18,6 +18,7 @@
 from brevitas.inject import ExtendedInjector
 from brevitas.inject import Injector
 from brevitas.quant_tensor import _unpack_quant_tensor
+from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.torch_utils import compute_channel_view_shape

@@ -74,8 +75,9 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         # Hack to recognize a QuantTensor that has decayed to a tuple
         # when used as input to tracing (e.g. during ONNX export)
         if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and
-                len(inp) == len(QuantTensor._fields) and all([isinstance(t, Tensor) for t in inp])):
-            inp = QuantTensor(*inp)
+                len(inp) == len(IntQuantTensor._fields) and
+                all([isinstance(t, Tensor) for t in inp])):
+            inp = IntQuantTensor(*inp)
         if not torch._C._get_tracing_state():
             if isinstance(inp, QuantTensor):
                 inp = inp.set(value=inp.value.rename(None))
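
Note on the hunk above: under torch.jit tracing (e.g. during ONNX export), a NamedTuple-based quant tensor decays to a plain tuple, so the code recognizes it by arity and element types and rebuilds it. A minimal standalone sketch of that repacking logic, with a hypothetical ToyQuantTensor standing in for IntQuantTensor:

    from typing import NamedTuple

    import torch
    from torch import Tensor


    class ToyQuantTensor(NamedTuple):
        # Hypothetical two-field stand-in for IntQuantTensor, for illustration only.
        value: Tensor
        scale: Tensor


    def repack(inp):
        # Mirror of the unpack_input check: a traced NamedTuple arrives as a
        # plain tuple, so recognize it by element count and element types,
        # then rebuild the NamedTuple.
        if (torch._C._get_tracing_state() is not None and isinstance(inp, tuple) and
                len(inp) == len(ToyQuantTensor._fields) and
                all(isinstance(t, Tensor) for t in inp)):
            inp = ToyQuantTensor(*inp)
        return inp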
@@ -186,7 +188,7 @@ def pack_quant_outputs(self, quant_outputs):
         # inner layers in a deep network overrides it, so we check again.
         if self.export_mode:
             if self.return_quant_tensor and self.io_quant.is_quant_enabled:
-                return QuantTensor(
+                return IntQuantTensor(
                     quant_outputs,
                     self.io_quant.scale(),
                     self.io_quant.zero_point(),
@@ -198,7 +200,7 @@ def pack_quant_outputs(self, quant_outputs):
             seq_dim = 1 if self.cell.batch_first else 0
             if self.return_quant_tensor and self.io_quant.is_quant_enabled:
                 outputs = [
-                    QuantTensor(
+                    IntQuantTensor(
                         torch.unsqueeze(quant_output[0], dim=seq_dim),
                         quant_output[1],
                         quant_output[2],
@@ -217,7 +219,7 @@ def pack_quant_state(self, quant_state, quant):
         # inner layers in a deep network overrides it, so we check again.
         if self.export_mode:
             if self.return_quant_tensor and quant.is_quant_enabled:
-                quant_state = QuantTensor(
+                quant_state = IntQuantTensor(
                     torch.unsqueeze(quant_state, dim=0),
                     quant.scale(),
                     quant.zero_point(),
@@ -228,7 +230,7 @@ def pack_quant_state(self, quant_state, quant):
                 quant_state = torch.unsqueeze(quant_state, dim=0)
         else:
             if self.return_quant_tensor and quant.is_quant_enabled:
-                quant_state = QuantTensor(
+                quant_state = IntQuantTensor(
                     torch.unsqueeze(quant_state[0], dim=0),
                     quant_state[1],
                     quant_state[2],
4 changes: 2 additions & 2 deletions src/brevitas/quant_tensor/torch_handler.py
@@ -47,8 +47,8 @@ def transpose_handler(inp, *args, **kwargs):

 @implements(torch.cat)
 def cat_handler(*args, **kwargs):
-    from brevitas.quant_tensor import QuantTensor
-    return QuantTensor.cat(*args, **kwargs)
+    from brevitas.quant_tensor import IntQuantTensor
+    return IntQuantTensor.cat(*args, **kwargs)


 @implements(F.pad)
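Note on cat_handler: the @implements decorator registers handlers that PyTorch's __torch_function__ protocol consults, so torch.cat on quant tensors routes to IntQuantTensor.cat. A minimal sketch of that dispatch pattern, assuming a registry plus a __torch_function__ override (not Brevitas's verbatim code):

    import torch

    HANDLED_FUNCTIONS = {}


    def implements(torch_function):
        # Register a handler for a given torch function, mimicking torch_handler.py.
        def decorator(handler):
            HANDLED_FUNCTIONS[torch_function] = handler
            return handler
        return decorator


    class DispatchingTensor:
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # PyTorch invokes this when a torch.* function receives this type;
            # look up the registered handler, otherwise defer to the default.
            kwargs = kwargs or {}
            if func in HANDLED_FUNCTIONS:
                return HANDLED_FUNCTIONS[func](*args, **kwargs)
            return NotImplemented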
3 changes: 2 additions & 1 deletion tests/brevitas/nn/nn_quantizers_fixture.py
@@ -33,6 +33,7 @@
 from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
 from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloat
+from brevitas.quant_tensor import IntQuantTensor
 from brevitas.quant_tensor import QuantTensor

 SEED = 123456
@@ -169,7 +170,7 @@ def forward(self, x):
         raise RuntimeError("Unsupported operation")

     if input_quantized:
-        quant_inp = QuantTensor(
+        quant_inp = IntQuantTensor(
             torch.randint(-128, 127, in_size) * 0.128, 0.128, 0., 8., True, is_training)
     else:
         quant_inp = torch.randn(in_size)
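Note on the fixture change: the six positional arguments are assumed to map to value, scale, zero_point, bit_width, signed, and training, reading against pack_quant_outputs above. Multiplying the random integers by the scale (0.128) puts every value exactly on the 8-bit grid with zero point 0. A hedged standalone sketch of the same construction (requires brevitas; the field order is an assumption):

    import torch

    from brevitas.quant_tensor import IntQuantTensor

    # Assumed field order: value, scale, zero_point, bit_width, signed, training.
    # randint * scale keeps every entry exactly representable at 8 bits with
    # scale 0.128 and zero point 0.
    value = torch.randint(-128, 127, (1, 3)) * 0.128
    quant_inp = IntQuantTensor(value, 0.128, 0., 8., True, False)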
14 changes: 14 additions & 0 deletions tests/brevitas/test_quant_tensor.py
@@ -0,0 +1,14 @@
+import torch
+
+from brevitas.quant_tensor import IntQuantTensor
+from brevitas.quant_tensor import QuantTensor
+
+
+def test_qt_structure():
+    qt = IntQuantTensor(
+        torch.randn(10), torch.randn(1), torch.tensor(0.), torch.tensor(8.), True, False)
+    assert isinstance(qt, IntQuantTensor)
+    assert isinstance(qt, QuantTensor)
+    assert isinstance(qt, tuple)
+    assert hasattr(qt, '_fields')
+    assert len(qt._fields) == 6
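
The new test pins down the structure the rest of this commit relies on: IntQuantTensor is a QuantTensor, behaves as a plain tuple, and exposes six named fields. One practical consequence, sketched with assumed field names (taken from the constructor argument order above, not verified against the Brevitas source):

    import torch

    from brevitas.quant_tensor import IntQuantTensor

    qt = IntQuantTensor(
        torch.randn(10), torch.randn(1), torch.tensor(0.), torch.tensor(8.), True, False)

    # Because qt is tuple-like with six fields, positional unpacking works.
    # These names are assumptions based on the constructor argument order.
    value, scale, zero_point, bit_width, signed, training = qt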
