From b4e92873cefcc15c9b861993143baf1f2fdf509d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 21 Jun 2024 17:17:15 +0200 Subject: [PATCH] Feat (graph/equalize): upcast during equalization computation (#970) --- src/brevitas/graph/equalize.py | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index fa63bf80d..31f295f5c 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -474,9 +474,11 @@ def _no_equalize(): return _no_equalize() scale_fn = _select_scale_computation_fn(scale_computation_type) - sink_weights = {name: transpose(m.weight.cpu(), axis) for name, (m, axis) in sink_axes.items()} - srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=dtype) - sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=dtype) + sink_weights = { + name: transpose(m.weight.cpu().to(torch.float32), axis) + for name, (m, axis) in sink_axes.items()} + srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=torch.float32) + sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=torch.float32) for k, v in sink_weights.items(): # Sinks can be partially equalized, thus we need to select # only the channels we are interested in @@ -493,11 +495,13 @@ def _no_equalize(): # weight equalization if merge_bias: src_weights = { - name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias).cpu() + name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, + m.bias).cpu().to(torch.float32) for name, (m, axis) in src_axes.items()} else: src_weights = { - name: transpose(m.weight.cpu(), axis) for name, (m, axis) in src_axes.items()} + name: transpose(m.weight.cpu().to(torch.float32), axis) + for name, (m, axis) in src_axes.items()} for k, v in src_weights.items(): # Srcs are always fully equalized, thus we simply need to apply the offset to position them # correctly with respect to the other srcs matrices. @@ -516,8 +520,10 @@ def _no_equalize(): list_of_act_val = list_of_act_val = [ transpose(act_val, act_axis) for act_val in list_of_act_val] srcs_range_act = scale_fn( - torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], - 1)).cpu() + torch.cat([ + act_val.reshape(act_val.size(0), -1).cpu().to(torch.float32) + for act_val in list_of_act_val], + 1)) if list_of_act_val is not None: if co_optimize_act_weights and len(src_axes) > 0: @@ -536,9 +542,9 @@ def _no_equalize(): # which is the no-op equivalent for equalization. channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON) sinks_range = torch.where( - channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), sinks_range) + channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), sinks_range) srcs_range = torch.where( - channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), srcs_range) + channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), srcs_range) srcs_range = torch.pow(srcs_range, alpha) sinks_range = torch.pow(sinks_range, 1 - alpha) @@ -548,7 +554,8 @@ def _no_equalize(): if list_of_act_val is not None and list_of_insert_mul_node_fn is not None: device = list_of_act_val[0].device for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn): - insert_mul_node_fn(inverse_scaling_factors.to(device=device), act_val_shape, act_axis) + insert_mul_node_fn( + inverse_scaling_factors.to(device=device, dtype=dtype), act_val_shape, act_axis) if len(src_axes) > 0: for name, (module, axis) in src_axes.items(): module_device = module.weight.device @@ -556,7 +563,7 @@ def _no_equalize(): channel_start = indexes.offset + indexes.start channel_end = indexes.offset + indexes.end partial_inverse_scale = inverse_scaling_factors[channel_start:channel_end].to( - device=module_device) + device=module_device, dtype=dtype) if hasattr(module, 'bias') and module.bias is not None: _update_weights( module, module.bias * partial_inverse_scale.view_as(module.bias), attr='bias') @@ -578,7 +585,7 @@ def _no_equalize(): # one (i.e., no equalization) partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset + channel_range] - partial_scaling = partial_scaling.to(device=module_device) + partial_scaling = partial_scaling.to(device=module_device, dtype=dtype) _update_weights( module, module.weight * torch.reshape(partial_scaling, sink_broadcast_size), @@ -983,7 +990,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k self.batch_dim_act_map[name] = batch_dim - input_scales = self.scale_fn(x, dim=batch_dim) + dtype = x.dtype + input_scales = self.scale_fn(x.to(torch.float32), dim=batch_dim).to(dtype) if name not in self.float_act_map: self.float_act_map[name] = input_scales else: