Skip to content

Commit

Permalink
Fix (quant_tensor): fix typing and remove unused checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 19, 2024
1 parent bb4feb2 commit abebad2
Showing 1 changed file with 12 additions and 19 deletions.
31 changes: 12 additions & 19 deletions src/brevitas/quant_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

class QuantTensorBase(NamedTuple):
value: Tensor
scale: Optional[Tensor]
zero_point: Optional[Tensor]
bit_width: Optional[Tensor]
signed_t: Optional[Tensor]
training_t: Optional[Tensor]
scale: Tensor
zero_point: Tensor
bit_width: Tensor
signed_t: Tensor
training_t: Tensor


def _unpack_quant_tensor(input_data):
Expand Down Expand Up @@ -61,17 +61,11 @@ def __new__(cls, value, scale, zero_point, bit_width, signed, training):

@property
def signed(self):
if self.signed_t is not None:
return self.signed_t.item()
else:
return None
return self.signed_t.item()

@property
def training(self):
if self.training_t is not None:
return self.training_t.item()
else:
return None
return self.training_t.item()

def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
Expand Down Expand Up @@ -129,8 +123,7 @@ def device(self):
value_device = self.value.device
is_same_device = True
for t in [self.scale, self.zero_point, self.bit_width]:
if t is not None:
is_same_device &= value_device == t.device
is_same_device &= value_device == t.device
if not is_same_device:
raise RuntimeError("Value and metadata are on different devices")
return value_device
Expand Down Expand Up @@ -193,13 +186,13 @@ def is_zero_zero_point(tensor):
return (tensor.zero_point == 0.).all()

def check_scaling_factors_same(self, other):
if self.training is not None and self.training:
if self.training:
return True
if not torch.allclose(self.scale, other.scale):
raise RuntimeError("Scaling factors are different")

def check_zero_points_same(self, other):
if self.training is not None and self.training:
if self.training:
return True
if not torch.allclose(self.zero_point, other.zero_point):
raise RuntimeError("Zero points are different")
Expand All @@ -226,7 +219,7 @@ def transpose(self, *args, **kwargs):
tensor_meta = {
'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width}
for k, tm in tensor_meta.items():
if tm is not None and len(value.shape) == len(tm.shape):
if len(value.shape) == len(tm.shape):
tensor_meta[k] = tm.transpose(*args, **kwargs)
return self.set(value=value, **tensor_meta)

Expand All @@ -235,7 +228,7 @@ def permute(self, *args, **kwargs):
tensor_meta = {
'scale': self.scale, 'zero_point': self.zero_point, 'bit_width': self.bit_width}
for k, tm in tensor_meta.items():
if tm is not None and len(value.shape) == len(tm.shape):
if len(value.shape) == len(tm.shape):
tensor_meta[k] = tm.permute(*args, **kwargs)
return self.set(value=value, **tensor_meta)

Expand Down

0 comments on commit abebad2

Please sign in to comment.