From f7d634deb08d5bac344055ed8a400391e5fa548a Mon Sep 17 00:00:00 2001
From: nickfraser
Date: Mon, 8 Jul 2024 15:20:46 +0100
Subject: [PATCH] Fix (graph/bias_correction): Fix when layer parameters are
 offloaded to `accelerate` (#962)

* Fix (graph/bias_correction): Fix when layer parameters are offloaded to
  `accelerate`

* Fix (bias_correction): Typo fix

* Fix (bias_correction): Apply accelerate fix to entire `if/elif` block.

* fix (bias_corr/accelerate): Added comment

---
 src/brevitas/graph/calibrate.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py
index a5d7e08e8..b3117cff7 100644
--- a/src/brevitas/graph/calibrate.py
+++ b/src/brevitas/graph/calibrate.py
@@ -286,11 +286,18 @@ def apply_correction(self, model):
         for name, module in model.named_modules():
             if name in self.correction_map.keys():
                 correction = self.correction_map[name] / self.iterations[name]
+                # When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter.
+                if hasattr(module, 'allocate_params'):
+                    module.allocate_params(module)
                 if module.bias is not None:
                     module.bias.data += correction
                 elif self.skip_if_no_bias is False:
+                    # If accelerate is enabled, bias will be on the same execution device as the weights, but won't be managed properly by accelerate
                     module.register_parameter(
                         'bias', nn.Parameter(correction).to(module.weight.device))
+                # Offload params again
+                if hasattr(module, 'offload_params'):
+                    module.offload_params(module)
 
     def compute_correct_bias(self, module, inp, name):
         inp = self.unpack_input(inp)