Skip to content

Commit

Permalink
New QuantTensor Structure
Browse files (browse the repository at this point in the history)
  • Loading branch information
Giuseppe5 committed Mar 25, 2024
1 parent e31ed76 commit b216d14
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 62 deletions.
4 changes: 2 additions & 2 deletions src/brevitas/fx/value_tracer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,7 +57,7 @@
import torch.utils._pytree as pytree

from brevitas import torch_version
from brevitas.quant_tensor import QuantTensorBase
from brevitas.quant_tensor import QuantTensor

from . import *
from . import _assert_is_none
Expand All @@ -82,7 +82,7 @@
from . import ScopeContextManager

_UNSET = object()
extended_base_types = base_types + (QuantTensorBase,)
extended_base_types = base_types + (QuantTensor,)

FRAME_FILES = [
'fx/brevitas_tracer.py',
Expand Down
30 changes: 17 additions & 13 deletions src/brevitas/proxy/parameter_quant.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,6 +15,7 @@
from brevitas import config
from brevitas.function import max_int
from brevitas.inject import BaseInjector as Injector
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

Expand Down Expand Up @@ -103,11 +104,11 @@ def bit_width(self):
bit_width = self.__call__(self.tracked_parameter_list[0]).bit_width
return bit_width

def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width = impl(x)
return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
else: # quantization disabled
return x

Expand All @@ -128,11 +129,11 @@ def pre_zero_point(self):
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = output_tuple
return pre_zero_point

def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x)
return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
else: # quantization disabled
return x

Expand All @@ -157,11 +158,12 @@ def pre_zero_point(self):
raise NotImplementedError

def forward(
self,
x: torch.Tensor,
quant_input: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]:
self,
x: torch.Tensor,
quant_input: Optional[Union[Tensor,
IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]:
if isinstance(quant_input,
QuantTensor) and not self.training and self.cache_inference_quant_act:
IntQuantTensor) and not self.training and self.cache_inference_quant_act:
cached_inp = _CachedIO(quant_input.detach(), self.cache_quant_io_metadata_only)
self._cached_act = cached_inp

Expand All @@ -170,14 +172,14 @@ def forward(
assert self._cached_act is not None, "No cached quant input found. Enable caching and perform a forward pass"
quant_input = self._cached_act
else:
assert isinstance(quant_input, QuantTensor), "Input must be quantized"
assert isinstance(quant_input, IntQuantTensor), "Input must be quantized"

input_bit_width = quant_input.bit_width
input_is_signed = quant_input.signed

impl = self.export_handler if self.export_mode else self.tensor_quant
out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
return QuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
else: # quantization disabled
return x

Expand Down Expand Up @@ -236,7 +238,7 @@ def bit_width(self):

def forward(self,
x: Tensor,
input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
input_scale: Optional[Tensor] = None) -> Union[Tensor, IntQuantTensor]:
out = x
if self.is_quant_enabled:
impl = self.export_handler if self.export_mode else self.tensor_quant
Expand All @@ -251,10 +253,12 @@ def forward(self,
else:
out, out_scale, out_zp, out_bit_width = impl(x)

out = QuantTensor(out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
out = IntQuantTensor(
out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
else:
out = x
if isinstance(out, QuantTensor) and not self.training and self.cache_inference_quant_bias:
if isinstance(out,
IntQuantTensor) and not self.training and self.cache_inference_quant_bias:
cached_bias = _CachedIO(out.detach(), metadata_only=False)
self._cached_bias = cached_bias
return out
26 changes: 14 additions & 12 deletions src/brevitas/proxy/runtime_quant.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing_extensions import runtime_checkable

import brevitas
from brevitas.quant_tensor import IntQuantTensor
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

Expand Down Expand Up @@ -166,11 +167,11 @@ def bit_width(self, force_eval=True):
elif self._cached_act is None:
return None

def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
def forward(self, x: Union[Tensor, IntQuantTensor]) -> Union[Tensor, IntQuantTensor]:
out = x
if self.fused_activation_quant_proxy is not None:
y = x
if isinstance(y, QuantTensor):
if isinstance(y, IntQuantTensor):
y = y.value

if self.export_mode:
Expand All @@ -180,15 +181,15 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
y = self.fused_activation_quant_proxy.activation_impl(y)
else:
y = self.fused_activation_quant_proxy(y)
# If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
# If y is an empty IntQuantTensor, we need to check if this is a passthrough proxy,
# otherwise return a simple Tensor
if isinstance(y, tuple) and not any(map(lambda f: f is None, y)):
out = QuantTensor(*y, signed=self.is_signed, training=self.training)
out = IntQuantTensor(*y, signed=self.is_signed, training=self.training)
elif self.is_passthrough_act: # preserve scale/zp/bit/sign even without output quant
if isinstance(y, tuple):
y = y[0]
if isinstance(x, QuantTensor):
out = QuantTensor(
if isinstance(x, IntQuantTensor):
out = IntQuantTensor(
y, x.scale, x.zero_point, x.bit_width, x.signed, self.training)
else:
out = y
Expand All @@ -199,7 +200,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
else:
# If fused activation quant proxy is not enabled, return the input
out = x
if not self.training and self.cache_inference_quant_act and isinstance(out, QuantTensor):
if not self.training and self.cache_inference_quant_act and isinstance(out, IntQuantTensor):
cached_out = _CachedIO(out.detach(), self.cache_quant_io_metadata_only)
self._cached_act = cached_out
return out
Expand All @@ -216,11 +217,11 @@ def zero_point(self, force_eval=True):

class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol):

def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
if self.is_quant_enabled:
out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width)
out_value, out_scale, out_zp, out_bit_width = out_tuple
return QuantTensor(
return IntQuantTensor(
out_value, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
return x

Expand All @@ -232,19 +233,20 @@ def bit_width(self):
return None
zhs = self._zero_hw_sentinel()
# Signed might or might not be defined. We just care about retrieving the bitwidth
empty_imp = QuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
empty_imp = IntQuantTensor(zhs, zhs, zhs, zhs, signed=True, training=self.training)
bit_width = self.__call__(empty_imp).bit_width
return bit_width

def forward(self, x: QuantTensor) -> Union[Tensor, QuantTensor]:
def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
if self.is_quant_enabled:
if self.export_mode:
out_tuple = self.export_handler(
x.value, x.scale, x.zero_point, x.bit_width, x.signed)
else:
out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
out_value, out_scale, out_zp, out_bit_width = out_tuple
return QuantTensor(out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
return IntQuantTensor(
out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
else:
return x

Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/quant_tensor/__init__.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from .base_quant_tensor import *
from .base_quant_tensor import _unpack_quant_tensor
from .base_quant_tensor import QuantTensorBase
from .int_quant_tensor import QuantTensor
from .int_quant_tensor import *
8 changes: 6 additions & 2 deletions src/brevitas/quant_tensor/base_quant_tensor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,7 +3,11 @@
from torch import Tensor


class QuantTensorBase(NamedTuple):
class QuantTensor:
pass


class IntTensorBase(NamedTuple):
value: Tensor
scale: Optional[Tensor]
zero_point: Optional[Tensor]
Expand All @@ -13,7 +17,7 @@ class QuantTensorBase(NamedTuple):


def _unpack_quant_tensor(input_data):
if isinstance(input_data, QuantTensorBase):
if isinstance(input_data, QuantTensor):
return input_data.value
elif isinstance(input_data, tuple):
return tuple([_unpack_quant_tensor(v) for v in input_data])
Expand Down
Loading

0 comments on commit b216d14

Please sign in to comment.