Skip to content

Commit

Permalink
Fix for a2q
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Feb 28, 2024
1 parent 8d52868 commit a8900a0
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/brevitas/export/common/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def quant_input_zero_point(cls, module):
@classmethod
def quant_weight_zero_point(cls, module):
signed = module.weight_quant.is_signed
zero_point = module.weight_quant.zero_point()
zero_point = module.quant_weight().zero_point
bit_width = module.weight_quant.bit_width()
return cls.zero_point_with_dtype(signed, bit_width, zero_point)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def op_symbolic_kwargs(self, module: Union[QuantConv1d, QuantConv2d]):
'input_scale': module.input_quant.scale(),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module),
'weight_scale': to_0dim_if_scalar(module.weight_quant.scale().flatten()),
'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
Expand Down Expand Up @@ -148,7 +148,7 @@ def op_symbolic_kwargs(self, module: QuantLinear):
'input_scale': module.input_quant.scale(),
'input_zero_point': self.quant_input_zero_point(module),
'int_weight': self.int_weight(module).view(module.out_features, module.in_features, 1),
'weight_scale': to_0dim_if_scalar(module.weight_quant.scale().flatten()),
'weight_scale': to_0dim_if_scalar(module.quant_weight().scale.flatten()),
'weight_zero_point': to_0dim_if_scalar(self.quant_weight_zero_point(module).flatten()),
'output_scale': module.output_quant.scale(),
'output_zero_point': self.quant_output_zero_point(module),
Expand Down
2 changes: 1 addition & 1 deletion tests/brevitas/graph/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def forward(self, x):
model(inp)

expected_scale = reference_implementation_scale_factors_po2(inp)
scale = model.act.quant_act_scale()
scale = model.act.act_quant.scale()

assert torch.allclose(expected_scale, scale)

Expand Down

0 comments on commit a8900a0

Please sign in to comment.