Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance: Quant Tensor Test #894

Merged
merged 10 commits into from
Apr 10, 2024
4 changes: 2 additions & 2 deletions tests/brevitas/quant_tensor/test_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_quant_tensor_transpose():

def test_quant_tensor_view():
Copy link
Collaborator

@Giuseppe5 Giuseppe5 Mar 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

View and transpose open the discussion to a broader topic regarding how to deal with quant metadata views and transpose, especially in the case where we are doing per channel or finer granularity quantizations.

For now, I would add a TODO in both test case that says that we need to deal with quant metadata and test it

x = torch.ones(4, 4)
a = QuantTensor(x)
a = to_quant_tensor(x)
b = torch.Tensor(x)

assert torch.isclose(a.view(-1), b.view(-1), atol=0.01).all().item()
Expand All @@ -100,7 +100,7 @@ def test_quant_tensor_view():
def test_is_valid():
x = torch.randn(4, 4)
# directly initialised QuantTensor shouldn't be valid
invalid_quant_tensor = QuantTensor(x)
invalid_quant_tensor = QuantTensor(x, None, None, None, None, None)
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
assert invalid_quant_tensor.is_valid == False

valid_quant_tensor = to_quant_tensor(x)
Expand Down
Loading