
Enhance: Quant Tensor Test #894

Merged · 10 commits · Apr 10, 2024

107 changes: 107 additions & 0 deletions in tests/brevitas/quant_tensor/test_quant_tensor.py
@@ -0,0 +1,107 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from enum import Enum

import pytest
import torch

from brevitas.inject.enum import QuantType
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import QuantTensor


class Operator(Enum):
    ADD = 0
    SUBTRACT = 1
    DIVIDE = 2
    MULTIPLY = 3
    MATMUL = 4


# QuantTensor isn't meant to be initialized directly; it'll be invalid if you do

Collaborator: I would say that it could be invalid if you generate it manually (likewise, it is possible to generate a valid QuantTensor manually if you carefully pick scale factors, bit_width, values, etc.).
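A hedged sketch of the reviewer's point, assuming the same six positional fields that test_is_valid below passes as None (taken here as value, scale, zero_point, bit_width, signed, training): a hand-built QuantTensor can be valid if its value is snapped onto the quantization grid implied by its metadata.

```python
import torch

from brevitas.quant_tensor import QuantTensor

scale = torch.tensor(0.1)
# snap onto the grid: value must be an integer multiple of scale
# within the signed 8-bit range [-128, 127]
value = (torch.randn(4, 4) / scale).round().clamp(-128, 127) * scale
qt = QuantTensor(value, scale, torch.tensor(0.), torch.tensor(8.), True, False)
assert qt.is_valid  # hand-built but self-consistent, so this should hold
```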

# so you need to create it indirectly, for example via QuantIdentity
def to_quant_tensor(input: torch.Tensor) -> QuantTensor:
    mod = QuantIdentity(bit_width=8, quant_type=QuantType.INT, return_quant_tensor=True)
Collaborator: I believe the quant_type arg is not necessary.
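If QuantIdentity already defaults to integer quantization, the reviewer's suggestion would reduce the helper to the following (a sketch; the default-quantizer assumption is mine):

```python
mod = QuantIdentity(bit_width=8, return_quant_tensor=True)
```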

    return mod(input)


def test_quant_tensor_init():
    x = torch.ones(4, 4)
    quant_tensor = to_quant_tensor(x)
    normal_tensor = torch.Tensor(x)

    assert torch.isclose(normal_tensor, quant_tensor, atol=0.1).all().item()


@pytest.mark.parametrize(
    'op', [Operator.ADD, Operator.SUBTRACT, Operator.DIVIDE, Operator.MULTIPLY, Operator.MATMUL])
def test_quant_tensor_operators(op):
    x = torch.ones(4, 4)

    a = torch.Tensor(x)
    b = torch.Tensor(x)
    qa = to_quant_tensor(a)
    qb = to_quant_tensor(b)

    if op == Operator.ADD:
        normal = a + b
        quant = qa + qb
    elif op == Operator.SUBTRACT:
        normal = a - b
        quant = qa - qb
    elif op == Operator.DIVIDE:
        normal = a / b
        quant = qa / qb
    elif op == Operator.MULTIPLY:
        normal = a * b
        quant = qa * qb
    elif op == Operator.MATMUL:
        normal = a @ b
        # the @ (matmul) operator is not implemented for QuantTensor

Collaborator: What's the difference between @ and matmul? Also, in terms of implementation, what would we need to override to implement @?

Collaborator (author): I don't believe there is a difference, so it's probably something we should create an issue to implement.
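For context on the thread above: `a @ b` is just syntax for Python's `__matmul__` protocol, so `@` and `torch.matmul` should behave identically once the protocol is wired up. A hypothetical sketch (illustrative, not Brevitas code) of the overrides, delegating to the `torch.matmul` path this test already exercises:

```python
# Python resolves `a @ b` as a.__matmul__(b), falling back to
# b.__rmatmul__(a), so QuantTensor only needs to forward both.
class QuantTensor:  # illustrative excerpt, not the real class
    def __matmul__(self, other):
        return torch.matmul(self, other)

    def __rmatmul__(self, other):
        return torch.matmul(other, self)
```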

        quant = torch.matmul(qa, qb)
    else:
        # unrecognised operator
        assert False

    # tolerance set to a high value as there is considerable loss of precision

Collaborator: This comment is outdated, I believe.

    assert torch.isclose(normal, quant, atol=0.1).all().item()
Collaborator: Is tolerance still required for all operators, or are some operators more troublesome than others? I would try to keep a tighter bound where possible, if it's not too much headache.

Collaborator (author): I'll see if I can tighten up the tolerance.
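One way to tighten the bound, per the thread above: derive atol from the quantizers' scales instead of hard-coding 0.1. A sketch, assuming QuantTensor exposes its scale attribute; this particular bound only covers the elementwise add/subtract cases, where round-to-nearest error is at most scale/2 per operand:

```python
# each operand's worst-case rounding error is scale / 2, so their
# elementwise sum or difference can drift by at most the sum of the two
atol = (qa.scale / 2 + qb.scale / 2).item()
assert torch.isclose(normal, quant, atol=atol).all().item()
```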



def test_quant_tensor_div_by_zero():
    a = to_quant_tensor(torch.ones(4, 4))
    b = to_quant_tensor(torch.zeros(4, 4))
    assert torch.isinf(a / b).all().item()


def test_quant_tensor_div_by_fraction():
    a = to_quant_tensor(torch.ones(4, 4))
    b = to_quant_tensor(torch.ones(4, 4) * 0.5)
    assert torch.isclose(a / b, torch.ones(4, 4) * 2, atol=0.1).all().item()


def test_quant_tensor_transpose():
    x = torch.ones(4, 4).tril()
    a = x.clone()
    b = to_quant_tensor(x)
    assert torch.isclose(a.transpose(0, 1), b.transpose(0, 1), atol=0.01).all().item()


def test_quant_tensor_view():
Collaborator @Giuseppe5 (Mar 12, 2024): View and transpose open the discussion to a broader topic regarding how to deal with quant metadata in views and transposes, especially when we are doing per-channel or finer-granularity quantization.

For now, I would add a TODO in both test cases saying that we need to deal with quant metadata and test it.
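A hedged illustration of the reviewer's point using plain tensors: with per-channel quantization the scale broadcasts along one axis of the value, so transposing the value without also transposing the metadata dequantizes with the wrong per-channel factors.

```python
import torch

value = torch.ones(4, 4)
scale = torch.tensor([[0.1], [0.2], [0.4], [0.8]])  # per-row (channel) scales

right = (value * scale).t()         # dequantize, then transpose
also_right = value.t() * scale.t()  # transpose value and metadata together
wrong = value.t() * scale           # value moved, metadata left behind
assert torch.equal(right, also_right)
assert not torch.equal(wrong, right)
```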

    x = torch.ones(4, 4)
    a = to_quant_tensor(x)
    b = torch.Tensor(x)

    assert torch.isclose(a.view(-1), b.view(-1), atol=0.01).all().item()
    assert torch.isclose(a.view(2, -1), b.view(2, -1), atol=0.01).all().item()
    assert torch.isclose(a.view(16, -1), b.view(16, -1), atol=0.01).all().item()
    assert torch.isclose(a.view(8, 2), b.view(8, 2), atol=0.01).all().item()


def test_is_valid():
    x = torch.randn(4, 4)
    # a directly initialised QuantTensor shouldn't be valid
    invalid_quant_tensor = QuantTensor(x, None, None, None, None, None)
    assert not invalid_quant_tensor.is_valid

    valid_quant_tensor = to_quant_tensor(x)
    assert valid_quant_tensor.is_valid