Enhance: Quant Tensor Test #894
Changes from 5 commits

@@ -0,0 +1,107 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from enum import Enum

import pytest
import torch

from brevitas.inject.enum import QuantType
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import QuantTensor


class Operator(Enum):
    ADD = 0
    SUBTRACT = 1
    DIVIDE = 2
    MULTIPLY = 3
    MATMUL = 4


# QuantTensor isn't meant to be initialized directly; doing so produces an
# invalid tensor, so create one indirectly, e.g. via QuantIdentity.
def to_quant_tensor(input: torch.Tensor) -> QuantTensor:
    mod = QuantIdentity(bit_width=8, quant_type=QuantType.INT, return_quant_tensor=True)
    return mod(input)


def test_quant_tensor_init():
    x = torch.ones(4, 4)
    quant_tensor = to_quant_tensor(x)
    normal_tensor = torch.Tensor(x)

    assert torch.isclose(normal_tensor, quant_tensor, atol=0.1).all().item()


@pytest.mark.parametrize(
    'op', [Operator.ADD, Operator.SUBTRACT, Operator.DIVIDE, Operator.MULTIPLY, Operator.MATMUL])
def test_quant_tensor_operators(op):
    x = torch.ones(4, 4)

    a = torch.Tensor(x)
    b = torch.Tensor(x)
    qa = to_quant_tensor(a)
    qb = to_quant_tensor(b)

    if op == Operator.ADD:
        normal = a + b
        quant = qa + qb
    elif op == Operator.SUBTRACT:
        normal = a - b
        quant = qa - qb
    elif op == Operator.DIVIDE:
        normal = a / b
        quant = qa / qb
    elif op == Operator.MULTIPLY:
        normal = a * b
        quant = qa * qb
    elif op == Operator.MATMUL:
        normal = a @ b
        # the @ matmul operator is not implemented for QuantTensor

Reviewer: What's the difference between `@` and `torch.matmul` here?

Author: I don't believe there is a difference, so it's probably something we should create an issue to implement.
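For context on that thread: in Python, `a @ b` dispatches to the operands' __matmul__/__rmatmul__ methods, while `torch.matmul(a, b)` goes through torch's own dispatch machinery (e.g. __torch_function__), so a tensor-like class can support one without the other. A minimal sketch with a hypothetical wrapper class, not Brevitas code:

import torch


class Wrapper:
    # hypothetical tensor wrapper, used only to illustrate operator dispatch
    def __init__(self, t: torch.Tensor):
        self.t = t

    # `a @ b` resolves to a.__matmul__(b); without this method the
    # expression raises TypeError no matter what torch.matmul supports
    def __matmul__(self, other):
        return torch.matmul(self.t, other.t)


a = Wrapper(torch.ones(4, 4))
b = Wrapper(torch.ones(4, 4))
print(a @ b)  # works only because __matmul__ is defined above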
        quant = torch.matmul(qa, qb)
    else:
        # unrecognised operator
        assert False

    # tolerance set to a high value as there is considerable loss of precision

Reviewer: Comment is outdated, I believe.

    assert torch.isclose(normal, quant, atol=0.1).all().item()

Reviewer: Is tolerance still required for all operators, or are some operators more troublesome than others?

Author: I'll see if I can tighten up the tolerance.
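One way to ground that tolerance discussion, as a sketch assuming a symmetric signed 8-bit quantizer whose scale is calibrated from the input's max absolute value (which may not match QuantIdentity's actual defaults): the worst-case rounding error per element is scale / 2, so a bound can be derived instead of hard-coding 0.1.

import torch


def int8_rounding_bound(x: torch.Tensor) -> float:
    # assumed calibration: symmetric signed 8-bit, scale from the max abs value
    scale = x.abs().max().item() / 127.0
    return scale / 2.0  # worst-case per-element rounding error


print(int8_rounding_bound(torch.ones(4, 4)))  # ~0.0039 for inputs in [-1, 1]

Elementwise ops on two quantized inputs roughly double that bound, and matmul accumulates it across the reduction dimension, so some operators may genuinely need a looser tolerance than others.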


def test_quant_tensor_div_by_zero():
    a = to_quant_tensor(torch.ones(4, 4))
    b = to_quant_tensor(torch.zeros(4, 4))
    assert torch.isinf(a / b).all().item()
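A side note on the expectation here (plain torch, independent of Brevitas): floating-point division follows IEEE 754, so a nonzero value divided by 0.0 yields ±inf rather than raising, while 0.0 / 0.0 yields NaN.

import torch

print(torch.ones(2) / torch.zeros(2))    # tensor([inf, inf])
print(-torch.ones(2) / torch.zeros(2))   # tensor([-inf, -inf])
print(torch.zeros(2) / torch.zeros(2))   # tensor([nan, nan]); isinf is False here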


def test_quant_tensor_div_by_fraction():
    a = to_quant_tensor(torch.ones(4, 4))
    b = to_quant_tensor(torch.ones(4, 4) * 0.5)
    assert torch.isclose(a / b, torch.ones(4, 4) * 2, atol=0.1).all().item()


def test_quant_tensor_transpose():
    x = torch.ones(4, 4).tril()
    a = x.clone()
    b = to_quant_tensor(x)
    assert torch.isclose(a.transpose(0, 1), b.transpose(0, 1), atol=0.01).all().item()


def test_quant_tensor_view():

Reviewer: View and transpose open the discussion to a broader topic: how to deal with quant metadata under views and transposes, especially when we are doing per-channel or finer-granularity quantization. For now, I would add a TODO in both test cases saying that we still need to handle and test quant metadata.
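To make that point concrete (plain torch with a hypothetical per-channel layout, not Brevitas's actual metadata handling): if the scale broadcasts along one dimension, transposing the values without also transposing the scale dequantizes against the wrong per-element scales.

import torch

# hypothetical per-channel layout: one scale per row, broadcast over columns
int_values = torch.arange(1.0, 17.0).reshape(4, 4)
scale = torch.tensor([[0.1], [0.2], [0.3], [0.4]])  # shape (4, 1)

dequant = int_values * scale

# transposing the values alone still scales by row, but the rows
# are now the original columns, so the pairing is wrong
wrong = int_values.t() * scale
# the metadata must be transposed together with the values
right = int_values.t() * scale.t()  # scale becomes shape (1, 4)

assert torch.equal(right, dequant.t())
assert not torch.equal(wrong, dequant.t())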
    x = torch.ones(4, 4)
    a = to_quant_tensor(x)
    b = torch.Tensor(x)

    assert torch.isclose(a.view(-1), b.view(-1), atol=0.01).all().item()
    assert torch.isclose(a.view(2, -1), b.view(2, -1), atol=0.01).all().item()
    assert torch.isclose(a.view(16, -1), b.view(16, -1), atol=0.01).all().item()
    assert torch.isclose(a.view(8, 2), b.view(8, 2), atol=0.01).all().item()


def test_is_valid():
    x = torch.randn(4, 4)
    # a directly initialised QuantTensor shouldn't be valid
    invalid_quant_tensor = QuantTensor(x, None, None, None, None, None)
    assert not invalid_quant_tensor.is_valid

    valid_quant_tensor = to_quant_tensor(x)
    assert valid_quant_tensor.is_valid

Reviewer: I would say that it could be invalid if you generate it manually (likewise, it is possible to manually generate a valid QuantTensor if you carefully pick the scale factors, bit_width, values, etc.).
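As a sketch of that last point, assuming the positional signature implied by QuantTensor(x, None, None, None, None, None) above is (value, scale, zero_point, bit_width, signed, training): a hand-built QuantTensor can be made self-consistent by choosing the integer representation first, so the value lands exactly on the quantization grid. Whether this passes is_valid depends on Brevitas's actual checks.

import torch

from brevitas.quant_tensor import QuantTensor

scale = torch.tensor(0.05)
# pick integers first so value = int_value * scale lies exactly on the 8-bit grid
int_value = torch.randint(-128, 128, (4, 4)).float()
value = int_value * scale

# argument order assumed from the test above, not taken from Brevitas docs
handmade = QuantTensor(
    value,
    scale,              # scale factor
    torch.tensor(0.),   # zero_point
    torch.tensor(8.),   # bit_width
    True,               # signed
    False)              # training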