diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index ad849c5e7..0fd01a2eb 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -290,13 +290,13 @@ def apply_correction(self, model): module.bias.data += correction elif self.skip_if_no_bias is False: # When accelerate is enabled, bring tensors onto the device to avoid allocating a meta parameter. - if hasattr(self.layer, 'allocate_params'): - self.layer.allocate_params(self.layer) + if hasattr(module, 'allocate_params'): + module.allocate_params(module) module.register_parameter( 'bias', nn.Parameter(correction).to(module.weight.device)) # Offload params again - if hasattr(self.layer, 'offload_params'): - self.layer.offload_params(self.layer) + if hasattr(module, 'offload_params'): + module.offload_params(module) def compute_correct_bias(self, module, inp, name): inp = self.unpack_input(inp)