Fix (bias_corr): load quant models after bias correction (#868)
costigt-dev authored Feb 21, 2024
1 parent caf0338 commit 6251f9e
Showing 2 changed files with 83 additions and 5 deletions.
37 changes: 32 additions & 5 deletions src/brevitas/graph/calibrate.py
@@ -24,7 +24,11 @@
from .base import Transform

__all__ = [
'ClipFloatWeights', 'DisableEnableQuantization', 'bias_correction_mode', 'calibration_mode']
'ClipFloatWeights',
'DisableEnableQuantization',
'bias_correction_mode',
'calibration_mode',
'load_quant_model']

_PARAM_PROXIES = (WeightQuantProxyFromInjector, BiasQuantProxyFromInjector)

@@ -85,11 +89,33 @@ def __exit__(self, type, value, traceback):
self.model, is_training=self.previous_training_state, quantization_enabled=True)


class load_quant_model:

def __init__(self, model):
self.model = model
self.tracked_modules = []

def __enter__(self):
for module in self.model.modules():
if issubclass(type(module), QuantWBIOL):
if module.bias is None:
module.register_parameter(
'bias',
nn.Parameter(torch.empty(module.weight.shape[0])).to(module.weight.device))
self.tracked_modules.append(module)

def __exit__(self, type, value, traceback):
for module in self.tracked_modules:
# empty tensor has a numel result of 0
if torch.numel(module.bias) == 0:
module.bias = None


class bias_correction_mode:

def __init__(self, model, enabled=True):
def __init__(self, model, enabled=True, skip_if_no_bias=False):
self.model = model
self.bias_correction = _BiasCorrection()
self.bias_correction = _BiasCorrection(skip_if_no_bias=skip_if_no_bias)
self.enabled = enabled
self.hooks = []

@@ -209,14 +235,15 @@ class _BiasCorrection(DisableEnableQuantization):

LAYERS = (QuantWBIOL,)

def __init__(self, layers=LAYERS):
def __init__(self, layers=LAYERS, skip_if_no_bias=False):
super(_BiasCorrection, self).__init__()
self.layers = layers
self.iterations = {}
self.correction_map = {}
self.float_mean_map = {}
self.collect_float_mean_hooks = []
self.correct_bias_hooks = []
self.skip_if_no_bias = skip_if_no_bias

def compute_mean(self, inp, transpose_dim):
inp = inp.transpose(0, transpose_dim)
@@ -248,7 +275,7 @@ def apply_correction(self, model):
correction = self.correction_map[name] / self.iterations[name]
if module.bias is not None:
module.bias.data += correction
else:
elif self.skip_if_no_bias is False:
module.register_parameter(
'bias', nn.Parameter(correction).to(module.weight.device))

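For context, here is a minimal usage sketch assembled from the new tests: bias correction can register a bias on quant layers that were built with bias=False, so a freshly constructed copy of the same network cannot load the corrected state dict directly. Wrapping load_state_dict in the new load_quant_model context manager pre-registers placeholder biases so the load succeeds, and any placeholder that still holds zero elements is dropped again on exit. The IN_CH/OUT_CH sizes and the make_model helper are illustrative, not part of the commit.

import torch
import torch.nn as nn

import brevitas.nn as qnn
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import load_quant_model

IN_CH, OUT_CH = 8, 16  # illustrative sizes, not from the commit


def make_model():
    # The layer is deliberately created without a bias.
    return nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))


model = make_model()

# Bias correction attaches a bias parameter to the bias-less quant layer.
with bias_correction_mode(model):
    model(torch.randn(1, IN_CH))

# A fresh instance still has bias=None, so load the corrected state dict
# inside load_quant_model, which registers placeholder biases first.
new_model = make_model()
with load_quant_model(new_model):
    new_model.load_state_dict(model.state_dict())

Alternatively, bias_correction_mode(model, skip_if_no_bias=True) leaves bias-less layers untouched during correction, as exercised by test_bias_correction_flag below.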
51 changes: 51 additions & 0 deletions tests/brevitas/graph/test_calibration.py
@@ -10,6 +10,7 @@

from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import load_quant_model
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
@@ -189,3 +190,53 @@ def simple_hook(mod, inp, out):
) # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
assert (inputs[1] == fp_outs[1, 0, :]).all(
) # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer


def test_import_bias_correction():

class SimpleQuantLinearNet(nn.Module):

def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))

def forward(self, inp):
return self.net(inp)

model = SimpleQuantLinearNet()

with bias_correction_mode(model):
model(torch.randn((1, IN_CH)))

for m in model.modules():
if isinstance(m, qnn.QuantLinear):
assert m.bias is not None

new_model = SimpleQuantLinearNet()
with load_quant_model(new_model):
new_model.load_state_dict(model.state_dict())

for m in new_model.modules():
if isinstance(m, qnn.QuantLinear):
assert m.bias is not None


def test_bias_correction_flag():

class SimpleQuantLinearNet(nn.Module):

def __init__(self) -> None:
super().__init__()
self.net = nn.Sequential(qnn.QuantLinear(IN_CH, OUT_CH, bias=False))

def forward(self, inp):
return self.net(inp)

model = SimpleQuantLinearNet()

with bias_correction_mode(model, skip_if_no_bias=True):
model(torch.randn((1, IN_CH)))

for m in model.modules():
if isinstance(m, qnn.QuantLinear):
assert m.bias is None
