From 6733ba2a4b512e580d702cc3e058a139fcc01c8c Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 21 Aug 2024 14:01:33 +0200 Subject: [PATCH] Feat(graph): better exclusion mechanism (#1003) --- src/brevitas/graph/equalize.py | 10 +++++----- src/brevitas/graph/quantize_impl.py | 6 ++++-- src/brevitas_examples/stable_diffusion/main.py | 4 ++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index 7729a3ecf..7f412148f 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -1019,8 +1019,7 @@ def __init__( self.blacklist_layers = blacklist_layers regions: List[Region] = [] - name = '' - self.find_module(model, name, regions) + self.find_module(model, regions) self.regions = regions if self.scale_computation_type == 'maxabs': @@ -1028,14 +1027,14 @@ def __init__( elif self.scale_computation_type == 'range': self.scale_fn = _channel_range - def find_module(self, model, name, regions: List): + def find_module(self, model, regions: List, prefix=''): """ Iterate through the model looking at immediate children of every module to look for supported modules. This allows us to stop the search when we meet a top-level module that is supported. """ if isinstance(model, _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)): - if self.blacklist_layers is not None and name in self.blacklist_layers: + if self.blacklist_layers is not None and prefix in self.blacklist_layers: return weight = get_weight_sink(model) eq_indexes = EqualizationIndexes(0, weight.shape[0], 0) @@ -1043,7 +1042,8 @@ def find_module(self, model, name, regions: List): regions.append(region) else: for name, module in model.named_children(): - self.find_module(module, name, regions) + full_name = prefix + '.' + name if prefix != '' else name + self.find_module(module, regions, full_name) def setup(self): for region in self.regions: diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 42696efac..ed6382907 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -503,7 +503,8 @@ def find_module( model: nn.Module, layer_map: Dict[nn.Module, Optional[Dict]], module_to_replace: List, - name_blacklist): + name_blacklist, + prefix=''): """ Iterate through the model looking at immediate children of every module to look for supported modules. This allows us to stop the search when we meet a top-level module that is supported. @@ -514,9 +515,10 @@ def find_module( module_to_replace.append(model) else: for name, module in model.named_children(): + full_name = prefix + '.' + name if prefix != '' else name if name_blacklist is not None and name in name_blacklist: continue - find_module(module, layer_map, module_to_replace, name_blacklist) + find_module(module, layer_map, module_to_replace, name_blacklist, full_name) def layerwise_layer_handler( diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 42d874b4f..a1c4fef53 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -223,10 +223,10 @@ def main(args): non_blacklist = dict() for name, _ in pipe.unet.named_modules(): if 'time_emb' in name: - blacklist.append(name.split('.')[-1]) + blacklist.append(name) else: if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)): - name_to_add = name.split('.')[-1] + name_to_add = name if name_to_add not in non_blacklist: non_blacklist[name_to_add] = 1 else: