Skip to content

Commit

Permalink
Fix (graph/quant): Bugfix in blacklist matching in find_module (#1021)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser authored Sep 6, 2024
1 parent b889bb2 commit b714943
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/brevitas/graph/quantize_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def find_module(
else:
for name, module in model.named_children():
full_name = prefix + '.' + name if prefix != '' else name
if name_blacklist is not None and name in name_blacklist:
if name_blacklist is not None and full_name in name_blacklist:
continue
find_module(module, layer_map, module_to_replace, name_blacklist, full_name)

Expand Down
39 changes: 39 additions & 0 deletions tests/brevitas/graph/test_quantize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest_cases
import torch.nn as nn

from brevitas.graph.quantize import layerwise_quantize


@pytest_cases.parametrize(
'kwargs',
[
{
'model': nn.Sequential(nn.Linear(2, 3)),
'name_blacklist': [],
'key': '0',
'expected': "<class 'brevitas.nn.quant_linear.QuantLinear'>"},
{
'model': nn.Sequential(nn.Linear(2, 3)),
'name_blacklist': ['0'],
'key': '0',
'expected': "<class 'torch.nn.modules.linear.Linear'>"},
{
'model': nn.Sequential(nn.Sequential(nn.Linear(2, 3))),
'name_blacklist': ['0.0'],
'key': '0.0',
'expected': "<class 'torch.nn.modules.linear.Linear'>"},])
def test_layerwise_quantize_blacklist(kwargs):
key = kwargs['key']
exp = kwargs['expected']
del kwargs['key']
del kwargs['expected']
qmodel = layerwise_quantize(**kwargs)
checked = False
found_names = []
for n, m in qmodel.named_modules():
found_names.append(n)
if n == key:
mt = str(type(m))
assert mt == exp, f"Expect module {n} to be type: {exp}, found type {mt}"
checked = True
assert checked, f"Layer named {key} not found. Layer names are: {found_names}"

0 comments on commit b714943

Please sign in to comment.