Skip to content

Commit

Permalink
Feat(graph): better exclusion mechanism (#1003)
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 authored Aug 21, 2024
1 parent f1655b2 commit 6733ba2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 9 deletions.
10 changes: 5 additions & 5 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,31 +1019,31 @@ def __init__(
self.blacklist_layers = blacklist_layers

regions: List[Region] = []
name = ''
self.find_module(model, name, regions)
self.find_module(model, regions)
self.regions = regions

if self.scale_computation_type == 'maxabs':
self.scale_fn = _channel_maxabs
elif self.scale_computation_type == 'range':
self.scale_fn = _channel_range

def find_module(self, model, name, regions: List):
def find_module(self, model, regions: List, prefix=''):
"""
Iterate through the model looking at immediate children of every module to look for supported modules.
This allows us to stop the search when we meet a top-level module that is supported.
"""
if isinstance(model,
_supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
if self.blacklist_layers is not None and name in self.blacklist_layers:
if self.blacklist_layers is not None and prefix in self.blacklist_layers:
return
weight = get_weight_sink(model)
eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
regions.append(region)
else:
for name, module in model.named_children():
self.find_module(module, name, regions)
full_name = prefix + '.' + name if prefix != '' else name
self.find_module(module, regions, full_name)

def setup(self):
for region in self.regions:
Expand Down
6 changes: 4 additions & 2 deletions src/brevitas/graph/quantize_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ def find_module(
model: nn.Module,
layer_map: Dict[nn.Module, Optional[Dict]],
module_to_replace: List,
name_blacklist):
name_blacklist,
prefix=''):
"""
Iterate through the model looking at immediate children of every module to look for supported modules.
This allows us to stop the search when we meet a top-level module that is supported.
Expand All @@ -514,9 +515,10 @@ def find_module(
module_to_replace.append(model)
else:
for name, module in model.named_children():
full_name = prefix + '.' + name if prefix != '' else name
if name_blacklist is not None and name in name_blacklist:
continue
find_module(module, layer_map, module_to_replace, name_blacklist)
find_module(module, layer_map, module_to_replace, name_blacklist, full_name)


def layerwise_layer_handler(
Expand Down
4 changes: 2 additions & 2 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ def main(args):
non_blacklist = dict()
for name, _ in pipe.unet.named_modules():
if 'time_emb' in name:
blacklist.append(name.split('.')[-1])
blacklist.append(name)
else:
if isinstance(_, (torch.nn.Linear, torch.nn.Conv2d)):
name_to_add = name.split('.')[-1]
name_to_add = name
if name_to_add not in non_blacklist:
non_blacklist[name_to_add] = 1
else:
Expand Down

0 comments on commit 6733ba2

Please sign in to comment.