From bb192885cb00a4bf973cd3a0abf7e57b2512f547 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 14 Jun 2024 15:53:37 +0100
Subject: [PATCH 1/2] Feat (graph/equalize): upcast during equalization
 computation

---
 src/brevitas/graph/equalize.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index fa63bf80d..601081676 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -474,9 +474,11 @@ def _no_equalize():
         return _no_equalize()
 
     scale_fn = _select_scale_computation_fn(scale_computation_type)
-    sink_weights = {name: transpose(m.weight.cpu(), axis) for name, (m, axis) in sink_axes.items()}
-    srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=dtype)
-    sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=dtype)
+    sink_weights = {
+        name: transpose(m.weight.cpu().to(torch.float32), axis)
+        for name, (m, axis) in sink_axes.items()}
+    srcs_range = -1 * torch.ones(max_shape_srcs, device='cpu', dtype=torch.float32)
+    sinks_range = -1 * torch.ones(max_shape_sinks, device='cpu', dtype=torch.float32)
     for k, v in sink_weights.items():
         # Sinks can be partially equalized, thus we need to select
         # only the channels we are interested in
@@ -493,11 +495,13 @@ def _no_equalize():
     # weight equalization
     if merge_bias:
         src_weights = {
-            name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage, m.bias).cpu()
+            name: _combine_weights_bias(transpose(m.weight, axis), bias_shrinkage,
+                                        m.bias).cpu().to(torch.float32)
             for name, (m, axis) in src_axes.items()}
     else:
         src_weights = {
-            name: transpose(m.weight.cpu(), axis) for name, (m, axis) in src_axes.items()}
+            name: transpose(m.weight.cpu().to(torch.float32), axis)
+            for name, (m, axis) in src_axes.items()}
     for k, v in src_weights.items():
         # Srcs are always fully equalized, thus we simply need to apply the offset to position them
         # correctly with respect to the other srcs matrices.
@@ -516,8 +520,10 @@ def _no_equalize():
         list_of_act_val = list_of_act_val = [
             transpose(act_val, act_axis) for act_val in list_of_act_val]
         srcs_range_act = scale_fn(
-            torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val],
-                      1)).cpu()
+            torch.cat([
+                act_val.reshape(act_val.size(0), -1).cpu().to(torch.float32)
+                for act_val in list_of_act_val],
+                      1))
 
     if list_of_act_val is not None:
         if co_optimize_act_weights and len(src_axes) > 0:
@@ -536,9 +542,9 @@ def _no_equalize():
     # which is the no-op equivalent for equalization.
     channelwise_no_equalize = (sinks_range <= EPSILON) | (srcs_range <= EPSILON)
     sinks_range = torch.where(
-        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), sinks_range)
+        channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), sinks_range)
     srcs_range = torch.where(
-        channelwise_no_equalize, torch.tensor(1., dtype=dtype, device='cpu'), srcs_range)
+        channelwise_no_equalize, torch.tensor(1., dtype=torch.float32, device='cpu'), srcs_range)
 
     srcs_range = torch.pow(srcs_range, alpha)
     sinks_range = torch.pow(sinks_range, 1 - alpha)
@@ -548,7 +554,8 @@ def _no_equalize():
     if list_of_act_val is not None and list_of_insert_mul_node_fn is not None:
         device = list_of_act_val[0].device
         for act_val_shape, insert_mul_node_fn in zip(list_of_act_val_shapes, list_of_insert_mul_node_fn):
-            insert_mul_node_fn(inverse_scaling_factors.to(device=device), act_val_shape, act_axis)
+            insert_mul_node_fn(
+                inverse_scaling_factors.to(device=device, dtype=dtype), act_val_shape, act_axis)
     if len(src_axes) > 0:
         for name, (module, axis) in src_axes.items():
             module_device = module.weight.device
@@ -556,7 +563,7 @@ def _no_equalize():
             channel_start = indexes.offset + indexes.start
             channel_end = indexes.offset + indexes.end
             partial_inverse_scale = inverse_scaling_factors[channel_start:channel_end].to(
-                device=module_device)
+                device=module_device, dtype=dtype)
             if hasattr(module, 'bias') and module.bias is not None:
                 _update_weights(
                     module, module.bias * partial_inverse_scale.view_as(module.bias), attr='bias')
@@ -578,7 +585,7 @@ def _no_equalize():
             # one (i.e., no equalization)
             partial_scaling[indexes.start:indexes.end] = scaling_factors[indexes.offset:indexes.offset +
                                                                          channel_range]
-            partial_scaling = partial_scaling.to(device=module_device)
+            partial_scaling = partial_scaling.to(device=module_device, dtype=dtype)
             _update_weights(
                 module,
                 module.weight * torch.reshape(partial_scaling, sink_broadcast_size),

From 382d33811283e9a9b493a135aa5980cfa464c276 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 20 Jun 2024 13:15:10 +0100
Subject: [PATCH 2/2] missing upcasts

---
 src/brevitas/graph/equalize.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py
index 601081676..31f295f5c 100644
--- a/src/brevitas/graph/equalize.py
+++ b/src/brevitas/graph/equalize.py
@@ -990,7 +990,8 @@ def forward_stats_hook(self, module, *args, name, batch_dim=0, use_inp=True, **k
 
         self.batch_dim_act_map[name] = batch_dim
 
-        input_scales = self.scale_fn(x, dim=batch_dim)
+        dtype = x.dtype
+        input_scales = self.scale_fn(x.to(torch.float32), dim=batch_dim).to(dtype)
         if name not in self.float_act_map:
             self.float_act_map[name] = input_scales
         else:
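
Both patches apply the same pattern: range statistics are computed in float32 regardless of the model dtype (half-precision reductions can overflow or lose resolution on narrow channels), and the resulting scaling factors are cast back to the original dtype only when they are written into the graph. The second patch repeats it for the activation hook: `x.to(torch.float32)` before `scale_fn`, `.to(dtype)` after. Below is a minimal, self-contained sketch of the pattern; it is not Brevitas code, and the names `equalization_scales`, `weight`, and `alpha` are illustrative only.

import torch

def equalization_scales(weight: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    # Remember the incoming dtype (e.g. torch.float16) so the result can be
    # cast back at the end, as the patches do with `.to(dtype=dtype)`.
    dtype = weight.dtype
    # Upcast before the reduction: float16/bfloat16 max/abs statistics can
    # overflow or collapse for channels with very small ranges.
    ranges = weight.to(torch.float32).abs().amax(dim=1)
    # Replace near-zero channel ranges with 1.0 (a no-op scale), mirroring
    # the EPSILON guard in the patch.
    ranges = torch.where(ranges <= 1e-9, torch.ones_like(ranges), ranges)
    scales = torch.pow(ranges, alpha)
    # Downcast only when the factors are applied back to the model.
    return scales.to(dtype)

w = torch.randn(8, 16, dtype=torch.float16)
print(equalization_scales(w).dtype)  # torch.float16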