biased mul
Giuseppe5 committed Mar 25, 2024
1 parent e52db0a commit e31ed76
Showing 1 changed file with 10 additions and 5 deletions.
src/brevitas/nn/quant_scale_bias.py: 15 changes (10 additions & 5 deletions)
@@ -86,20 +86,25 @@ def inner_forward_impl(
         quant_weight = quant_weight.view(self.runtime_shape)
         quant_bias = quant_bias.view(self.runtime_shape)
 
+        def biased_mul(input, weight, bias):
+            out = torch.mul(input, weight)
+            if bias is not None:
+                out += bias
+            return out
+
+        # TODO: when implementing new types of QuantTensor, this should be revised
         if isinstance(input, QuantTensor):
             from brevitas.quant_tensor.torch_handler import quant_layer
 
             output_tensor = quant_layer(
-                torch.mul,
+                biased_mul,
                 input,
                 quant_weight,
-                bias=None,
+                bias=quant_bias,
                 external_acc_bit_width_fn=self.max_acc_bit_width)
         else:
-            output_tensor = torch.mul(input, quant_weight)
+            output_tensor = biased_mul(input, quant_weight, quant_bias)
 
-        if quant_bias is not None:
-            output_tensor += quant_bias
         return output_tensor
 
     def max_acc_bit_width(self, input_bit_width, weight_bit_width, weight_shape):
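For context, the change folds the bias add into the op handed to quant_layer, so the bias is passed through quant_layer's bias keyword instead of being added to the returned tensor afterwards; on the non-quantized path the helper reproduces the old mul-then-add behavior. A minimal standalone sketch of that equivalence on plain torch.Tensor inputs follows; the shapes and values are illustrative, not taken from the commit:

    import torch

    def biased_mul(input, weight, bias):
        # Same helper as in the diff: elementwise multiply,
        # then an optional in-place bias add.
        out = torch.mul(input, weight)
        if bias is not None:
            out += bias
        return out

    # Illustrative per-channel scale and shift of a (2, 3) activation.
    x = torch.randn(2, 3)
    w = torch.randn(3)
    b = torch.randn(3)

    fused = biased_mul(x, w, b)
    two_step = torch.mul(x, w) + b  # the pre-commit two-step behavior
    assert torch.allclose(fused, two_step)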
