Tests (Channel-Splitting/Equalize): adding more tests and cleaning up
fabianandresgrob committed Jan 23, 2024
1 parent 9fc32d0 commit 4ad3cdf
Showing 7 changed files with 189 additions and 167 deletions.
src/brevitas/graph/channel_splitting.py
@@ -1,13 +1,11 @@
import math
from typing import Dict, List, Set, Tuple, Union
import warnings

import torch
import torch.nn as nn

from brevitas.fx import GraphModule
from brevitas.graph.base import GraphTransform
from brevitas.graph.equalize import _batch_norm
from brevitas.graph.equalize import _channel_maxabs
from brevitas.graph.equalize import _extract_regions
from brevitas.graph.equalize import _get_input_axis
@@ -18,6 +16,9 @@
_conv = (
nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)

_unsupported_layers = (
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm, nn.MultiheadAttention)


def _channels_to_split(
sources: Dict[str, nn.Module],
@@ -64,15 +65,15 @@ def transpose_tensor(tensor: torch.Tensor, axis: int):
return tensor.permute(shape)


@torch.no_grad
def _split_channels(module, channels_to_split, split_input=False, split_factor=0.5) -> None:
"""
Splits the channels `channels_to_split` of the `weights`.
`split_input` specifies whether to split Input or Output channels.
Can also be used to duplicate a channel, just set split_factor to 1.
Returns: None
Given a module, this method splits the weight channels as proposed in https://arxiv.org/abs/1901.09504.
`split_factor` determines how to split the channels, `channels_to_split` is a list of channel indices.
If `split_input=True`, the input channels of the module are split, otherwise the output channels.
"""
weight = torch.clone(module.weight.data)
bias = torch.clone(module.bias.data) if module.bias is not None else None
weight = module.weight.data
bias = module.bias.data if module.bias is not None else None
num_added_channels = len(channels_to_split)

_get_axis = _get_input_axis if split_input else _get_output_axis
@@ -91,28 +92,29 @@ def _split_channels(module, channels_to_split, split_input=False, split_factor=0
weight_t = torch.cat([weight_t, split_channel], dim=0)

if bias is not None and not split_input:
channel = bias[id:id + 1] * split_factor
bias = torch.cat((bias[:id], channel, bias[id + 1:], channel))
bias[id] *= split_factor
split_channel = bias[id:id + 1]
bias = torch.cat((bias, split_channel))

# reshape weight_t back to orig shape with the added channels
del orig_shape[axis]
weight_t = weight_t.reshape(weight_t.size(0), *orig_shape)
weight_t = transpose_tensor(weight_t, axis)
module.weight.data = weight_t
if bias is not None:
module.bias.data = bias

if isinstance(module, _conv):
if split_input:
module.in_channels += num_added_channels
else:
module.out_channels += num_added_channels
else:
elif isinstance(module, nn.Linear):
if split_input:
module.in_features += num_added_channels
else:
module.out_features += num_added_channels

if bias is not None:
module.bias.data = bias


def _split_channels_region(
sources: Dict[str, nn.Module],
@@ -130,6 +132,7 @@ def _split_channels_region(
# input channels are split in half, output channels duplicated
for name, module in sinks.items():
_split_channels(module, channels_to_split, split_input=True)

for name, module in sources.items():
# duplicating output_channels for all modules in the source
_split_channels(module, channels_to_split, split_factor=1, split_input=False)
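
As a quick illustration of why the loop above splits sink input channels with the default factor of 0.5 while duplicating source output channels with split_factor=1, here is a minimal standalone sketch in plain PyTorch (not the Brevitas API; the channel index and layer sizes are arbitrary) showing that the combined transform preserves the network function:

import torch
import torch.nn as nn

torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 3), nn.Linear(3, 2)  # fc1 acts as the source, fc2 as the sink
x = torch.randn(1, 4)
ref = fc2(fc1(x))

with torch.no_grad():
    c = 1  # hypothetical channel picked by the maxabs criterion
    # Source: duplicate output channel c (split_factor=1).
    fc1.weight.data = torch.cat([fc1.weight.data, fc1.weight.data[c:c + 1]])
    fc1.bias.data = torch.cat([fc1.bias.data, fc1.bias.data[c:c + 1]])
    # Sink: halve input channel c and append the halved copy (split_factor=0.5).
    fc2.weight.data[:, c] *= 0.5
    fc2.weight.data = torch.cat([fc2.weight.data, fc2.weight.data[:, c:c + 1]], dim=1)

print(torch.allclose(ref, fc2(fc1(x)), atol=1e-6))  # True: the output is unchanged

The real _split_channels additionally flattens convolution weights along the split axis and updates in_channels/out_channels (or in_features/out_features) to match the new shapes.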
@@ -140,17 +143,17 @@ def _is_groupwise(module: nn.Module):
return isinstance(module, _conv) and module.groups > 1


def _is_batchnorm(module: nn.Module):
return isinstance(module, _batch_norm)
def _is_unsupported(module: nn.Module):
return isinstance(module, _unsupported_layers)


def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:
# groupwise convolutions are not supported so filter them out
if any(map(_is_groupwise, srcs + sinks)):
return False

# bn layers aren't allowed
if any(map(_is_batchnorm, sinks + srcs)):
# filter out unsupported layers
if any(map(_is_unsupported, sinks + srcs)):
return False

# check if OCs of sources are all equal
@@ -168,59 +171,62 @@ def _is_supported(srcs: List[nn.Module], sinks: List[nn.Module]) -> bool:

def _split(
model: GraphModule,
regions: Set[Tuple[str]],
regions: List[Region],
split_ratio: float,
split_criterion: str,
split_input: bool) -> GraphModule:
split_input: bool,
split_criterion: str = 'maxabs') -> GraphModule:
for i, region in enumerate(regions):

# check if region is suitable for channel splitting
sources = {src: region.get_module_from_name(src) for src in region.srcs_names}
sinks = {sink: region.get_module_from_name(sink) for sink in region.sinks_names}

if _is_supported(list(sources.values()), list(sinks.values())):
# get channels to split
channels_to_split = _channels_to_split(
sources=sources,
sinks=sinks,
split_criterion=split_criterion,
split_ratio=split_ratio,
split_input=split_input)
# splitting/duplicating channels
_split_channels_region(
sources=sources,
sinks=sinks,
channels_to_split=channels_to_split,
split_input=split_input)
# get channels to split
channels_to_split = _channels_to_split(
sources=sources,
sinks=sinks,
split_criterion=split_criterion,
split_ratio=split_ratio,
split_input=split_input)
# splitting/duplicating channels
_split_channels_region(
sources=sources,
sinks=sinks,
channels_to_split=channels_to_split,
split_input=split_input)

return model


def _clean_regions(regions: List[Region]):
"""
This method checks whether the list of regions is compatible with channel splitting.
If a module is in the sinks/sources of multiple regions, these regions will be removed.
Given a list of regions, this method removes all regions that are not compatible with channel splitting.
"""
# idea: map modules to their regions and check whether it appears in multiple regions
regions_to_del = set()
source_modules = dict()
sink_modules = dict()
for i, region in enumerate(regions):
# add srcs to source_modules
for src in region.srcs_names:
sources = {src: region.get_module_from_name(src) for src in region.srcs_names}
sinks = {sink: region.get_module_from_name(sink) for sink in region.sinks_names}

# a module cannot be in the sources (or sinks) of multiple regions
for src in sources.keys():
# if not yet in the dict, instantiate new list for keeping track
if src not in source_modules:
source_modules[src] = [i]
else:
# we know the module has been in sources before, so region needs to be deleted
source_modules[src].append(i)
regions_to_del.update({i, *source_modules[src]})
for sink in region.sinks_names:
regions_to_del.update({*source_modules[src]})
for sink in sinks.keys():
if sink not in sink_modules:
sink_modules[sink] = [i]
else:
sink_modules[sink].append(i)
regions_to_del.update({i, *sink_modules[sink]})
regions_to_del.update({*sink_modules[sink]})

# check for other unsupported
if not _is_supported(list(sources.values()), list(sinks.values())):
# add region to be deleted
regions_to_del.add(i)

regions = [regions[i] for i, _ in enumerate(regions) if i not in regions_to_del]
return regions
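
A standalone sketch of the bookkeeping above, using plain dicts and strings instead of the Brevitas Region class (purely illustrative): a module may legitimately be the sink of one region and the source of the next, but as soon as it appears as a source (or as a sink) of two different regions, every region involved in the clash is dropped:

from collections import defaultdict

regions = [
    {'srcs': ['conv0'], 'sinks': ['conv1']},
    {'srcs': ['conv1'], 'sinks': ['conv2']},  # fine: conv1 is a sink above, a source here
    {'srcs': ['conv1'], 'sinks': ['conv3']},  # clash: conv1 is a source of two regions
]

to_delete = set()
seen_as_src, seen_as_sink = defaultdict(list), defaultdict(list)
for i, region in enumerate(regions):
    for src in region['srcs']:
        seen_as_src[src].append(i)
        if len(seen_as_src[src]) > 1:
            to_delete.update(seen_as_src[src])
    for sink in region['sinks']:
        seen_as_sink[sink].append(i)
        if len(seen_as_sink[sink]) > 1:
            to_delete.update(seen_as_sink[sink])

cleaned = [r for i, r in enumerate(regions) if i not in to_delete]
print(len(cleaned))  # 1 -- only the first region survives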
16 changes: 8 additions & 8 deletions src/brevitas/graph/equalize.py
@@ -100,9 +100,9 @@ class EqualizationIndexes:

# Required for being hashable
@dataclass(eq=True, frozen=True)
class WeightBiasTuple:
weight: nn.Module = None
bias: nn.Module = None
class WeightBiasWrapper:
weight: torch.Tensor = None
bias: torch.Tensor = None


# Required for being hashable
@@ -430,7 +430,7 @@ def _no_equalize():
# For MultiheadAttention, we support only self-attetion
if isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is not None:
# For sinks, we only need to modify the weight but not the bias
module = WeightBiasTuple(module.in_proj_weight)
module = WeightBiasWrapper(module.in_proj_weight)
elif isinstance(module, nn.MultiheadAttention) and module.in_proj_weight is None:
return _no_equalize()
sink_axes[name] = (module, axis)
@@ -452,7 +452,7 @@ def _no_equalize():

# Check if any of the axis is None, which means that the module is not supported.
# In that case, do not perform graph equalization
axes_to_check = [*src_axes.values(), *sink_axes.values()]
axes_to_check = [axis for _, axis in list(src_axes.values()) + list(sink_axes.values())]
if None in axes_to_check:
return _no_equalize()

@@ -481,7 +481,7 @@ def _no_equalize():
if any(shape_0 != shape for shape in list_of_act_val_shapes):
return _no_equalize()
list_of_act_val = [
transpose(WeightBiasTuple(act_val), act_axis) for act_val in list_of_act_val]
transpose(WeightBiasWrapper(act_val), act_axis) for act_val in list_of_act_val]
srcs_range = scale_fn(
torch.cat([act_val.reshape(act_val.size(0), -1) for act_val in list_of_act_val], 1))
else:
@@ -562,7 +562,7 @@ def _no_equalize():


def _update_weights(original_module, new_value, attr='weight'):
if isinstance(original_module, WeightBiasTuple):
if isinstance(original_module, WeightBiasWrapper):
setattr(getattr(original_module, attr), 'data', new_value)
else:
setattr(original_module, attr, nn.Parameter(new_value))
@@ -645,7 +645,7 @@ def get_weight_sink(module):
transpose = lambda weight, axis: weight if axis == 0 else weight.transpose(0, 1)
if isinstance(module, nn.MultiheadAttention) and not hasattr(module, 'in_proj_weight'):
raise RuntimeError("Configuration for Multiheadattention not supported")
weight = WeightBiasTuple(module.in_proj_weight).weight if isinstance(
weight = WeightBiasWrapper(module.in_proj_weight).weight if isinstance(
module, nn.MultiheadAttention) else module.weight
axis = _get_input_axis(module)
weight = transpose(weight, axis)
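
For readers wondering why the renamed WeightBiasWrapper is a frozen dataclass rather than a module: it lets a bare tensor expose .weight/.bias so the equalization code can treat nn.MultiheadAttention's fused in_proj_weight like any other sink. A small usage sketch (shapes assume PyTorch defaults; not taken from the commit):

import torch.nn as nn
from brevitas.graph.equalize import WeightBiasWrapper  # name introduced by this commit

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
wrapped = WeightBiasWrapper(mha.in_proj_weight)
print(wrapped.weight.shape)  # torch.Size([24, 8]): stacked q, k and v projections
print(wrapped.bias)          # None -- sinks only need the weight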
6 changes: 2 additions & 4 deletions src/brevitas/graph/quantize.py
@@ -8,6 +8,7 @@
from brevitas.core.scaling.standalone import ParameterScaling
from brevitas.fx.brevitas_tracer import symbolic_trace
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.channel_splitting import RegionwiseChannelSplitting
from brevitas.graph.equalize import EqualizeGraph
from brevitas.graph.fixed_point import CollapseConsecutiveConcats
from brevitas.graph.fixed_point import MergeBatchNorm
@@ -26,7 +27,6 @@
from brevitas.graph.standardize import TorchFunctionalToModule
from brevitas.nn import quant_layer
import brevitas.nn as qnn
from brevitas.ptq_algorithms.channel_splitting import RegionwiseChannelSplitting
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8ActPerTensorFloatMinMaxInit
from brevitas.quant import Int8WeightPerTensorFloat
@@ -267,10 +267,8 @@ def preprocess_for_quantize(
equalize_scale_computation: str = 'maxabs',
channel_splitting=False,
channel_splitting_ratio=0.02,
channel_splitting_grid_aware=False,
channel_splitting_split_input=True,
channel_splitting_criterion: str = 'maxabs',
channel_splitting_weight_bit_width=8):
channel_splitting_criterion: str = 'maxabs'):

training_state = model.training
model.eval()
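
For context, a hedged usage sketch of the trimmed preprocess_for_quantize signature: only keyword names visible in the hunk above are used, the remaining arguments are left at their defaults, and resnet18 is an arbitrary stand-in model. The main() hunks that follow appear to come from the ImageNet PTQ example script, which forwards the corresponding command-line flags to this call:

from torchvision.models import resnet18
from brevitas.graph.quantize import preprocess_for_quantize

model = resnet18(weights=None).eval()
model = preprocess_for_quantize(
    model,
    channel_splitting=True,
    channel_splitting_ratio=0.02,        # split roughly the 2% most extreme channels
    channel_splitting_split_input=True,  # halve sink inputs, duplicate source outputs
    channel_splitting_criterion='maxabs')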
@@ -363,7 +363,7 @@ def main():
model,
equalize_iters=args.graph_eq_iterations,
equalize_merge_bias=args.graph_eq_merge_bias,
merge_bn=not args.calibrate_bn,
merge_bn=args.merge_bn,
channel_splitting=args.channel_splitting,
channel_splitting_grid_aware=args.grid_aware,
channel_splitting_split_input=args.split_input,
@@ -436,9 +436,9 @@ def main():
iters=args.learned_round_iters,
optimizer_lr=args.learned_round_lr)

# if args.calibrate_bn:
# print("Calibrate BN:")
# calibrate_bn(calib_loader, quant_model)
if args.calibrate_bn:
print("Calibrate BN:")
calibrate_bn(calib_loader, quant_model)

if args.bias_corr:
print("Applying bias correction:")
47 changes: 46 additions & 1 deletion tests/brevitas/graph/equalization_fixtures.py
@@ -301,6 +301,49 @@ def forward(self, x):
return ResidualSrcsAndSinkModel


@pytest_cases.fixture
def convgroupconv_model():

class ConvGroupConvModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(3, 16, kernel_size=3)
self.conv_0 = nn.Conv2d(16, 32, kernel_size=1, groups=2)
self.conv_1 = nn.Conv2d(32, 64, kernel_size=1, groups=4)
self.relu = nn.ReLU()

def forward(self, x):
x = self.conv(x)
x = self.relu(x)
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return ConvGroupConvModel


@pytest_cases.fixture
def convtranspose_model():

class ConvTransposeModel(nn.Module):

def __init__(self) -> None:
super().__init__()
self.relu = nn.ReLU()
self.conv_0 = nn.ConvTranspose2d(in_channels=3, out_channels=8, kernel_size=3)
self.conv_1 = nn.ConvTranspose2d(in_channels=8, out_channels=32, kernel_size=3)

def forward(self, x):
x = self.conv_0(x)
x = self.relu(x)
x = self.conv_1(x)
return x

return ConvTransposeModel


list_of_fixtures = [
'residual_model',
'srcsinkconflict_model',
@@ -309,7 +352,9 @@ def forward(self, x):
'convdepthconv_model',
'linearmha_model',
'mhalinear_model',
'layernormmha_model']
'layernormmha_model',
'convgroupconv_model',
'convtranspose_model']

toy_model = fixture_union('toy_model', list_of_fixtures, ids=list_of_fixtures)
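
As a reminder of how these class-returning fixtures are consumed, a hypothetical test sketch (not part of the commit; the input size and assertion are made up):

import torch

def test_convgroupconv_forward(convgroupconv_model):
    model = convgroupconv_model()  # the fixture returns the class, not an instance
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 16, 16))
    assert out.shape[1] == 64  # the last grouped conv has 64 output channels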
