
Commit

Fix tests
Giuseppe5 committed Nov 1, 2023
1 parent 81758b8 commit c1d945a
Showing 1 changed file with 80 additions and 90 deletions.
170 changes: 80 additions & 90 deletions tests/brevitas/graph/test_calibration.py
@@ -37,7 +37,7 @@ def reference_implementation_scale_factors_po2(
     return scale


-@given(inp=float_tensor_random_size_st(min_val=9.999999747378752e-06, max_val=1e4))
+@given(inp=float_tensor_random_size_st())
 def test_scale_factors_ptq_calibration_po2(inp):

     class TestModel(nn.Module):
@@ -52,22 +52,12 @@ def forward(self, x):

     model = TestModel()
     model.eval()
-
-    print("------------")
-    print(model.act.quant_act_scale())
-    print(inp)

     with torch.no_grad():
         with calibration_mode(model):
-            print("Forward")
             model(inp)
-            print(model(inp))
-            print("End forward")

     expected_scale = reference_implementation_scale_factors_po2(inp)
-    print(expected_scale)
     scale = model.act.quant_act_scale()
-    print(scale)
-    print("------------")

     assert torch.allclose(expected_scale, scale)

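For context on the hunk above: the test pushes random tensors through `calibration_mode` and compares Brevitas's calibrated power-of-two activation scale against `reference_implementation_scale_factors_po2`, whose body sits above this diff. Below is a minimal sketch of what such a reference typically computes, assuming signed 8-bit symmetric quantization with an absolute-max statistic; the helper name, bit width, and rounding direction are assumptions for illustration, not Brevitas API.

import torch

def po2_scale_sketch(inp: torch.Tensor, bit_width: int = 8) -> torch.Tensor:
    # Assumption: signed symmetric quantization, so the largest representable
    # integer is 2**(bit_width - 1) - 1 (127 for 8 bits).
    int_max = 2 ** (bit_width - 1) - 1
    # Calibration collects an absolute-max statistic over the observed input
    # (assumed nonzero here; a real implementation would clamp or guard).
    threshold = inp.abs().max()
    # The floating-point scale is then restricted to a power of two; the
    # ceil rounding is an assumption of this sketch.
    return torch.exp2(torch.ceil(torch.log2(threshold / int_max)))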
@@ -94,107 +84,107 @@ def forward(self, x):
     assert model.training == False


-# class TestBiasCorrection():
+class TestBiasCorrection():

-#     @fixture
-#     def models(self):
+    @fixture
+    def models(self):

-#         class MyModel(nn.Module):
+        class MyModel(nn.Module):

-#             def __init__(self) -> None:
-#                 super().__init__()
-#                 self.module_list = nn.ModuleList([
-#                     nn.Linear(IN_CH, OUT_CH, bias=False), nn.Linear(OUT_CH, OUT_CH, bias=False)])
+            def __init__(self) -> None:
+                super().__init__()
+                self.module_list = nn.ModuleList([
+                    nn.Linear(IN_CH, OUT_CH, bias=False), nn.Linear(OUT_CH, OUT_CH, bias=False)])

-#             def forward(self, inp):
-#                 out_0 = self.module_list[0](inp)
-#                 out_1 = self.module_list[1](out_0)
-#                 return torch.cat((out_0, out_1))
+            def forward(self, inp):
+                out_0 = self.module_list[0](inp)
+                out_1 = self.module_list[1](out_0)
+                return torch.cat((out_0, out_1))

-#         class MyQuantModel(nn.Module):
+        class MyQuantModel(nn.Module):

-#             def __init__(self) -> None:
-#                 super().__init__()
-#                 self.module_list = nn.ModuleList([
-#                     qnn.QuantLinear(IN_CH, OUT_CH, bias=False),
-#                     qnn.QuantLinear(OUT_CH, OUT_CH, bias=False)])
+            def __init__(self) -> None:
+                super().__init__()
+                self.module_list = nn.ModuleList([
+                    qnn.QuantLinear(IN_CH, OUT_CH, bias=False),
+                    qnn.QuantLinear(OUT_CH, OUT_CH, bias=False)])

-#             def forward(self, inp):
-#                 out_0 = self.module_list[0](inp)
-#                 out_1 = self.module_list[1](out_0)
-#                 return torch.cat((out_0, out_1))
+            def forward(self, inp):
+                out_0 = self.module_list[0](inp)
+                out_1 = self.module_list[1](out_0)
+                return torch.cat((out_0, out_1))

-#         quant_model = MyQuantModel()
-#         model = MyModel()
+        quant_model = MyQuantModel()
+        model = MyModel()

-#         quant_model.module_list[0].weight.data = model.module_list[0].weight.data
-#         quant_model.module_list[1].weight.data = model.module_list[1].weight.data
-#         model.eval()
-#         quant_model.eval()
+        quant_model.module_list[0].weight.data = model.module_list[0].weight.data
+        quant_model.module_list[1].weight.data = model.module_list[1].weight.data
+        model.eval()
+        quant_model.eval()

-#         return model, quant_model
+        return model, quant_model

-#     def test_bias_correction_results(self, models):
-#         fp_model, quant_model = models
-#         num_layers = len(quant_model.module_list)
+    def test_bias_correction_results(self, models):
+        fp_model, quant_model = models
+        num_layers = len(quant_model.module_list)

-#         # Generate 2 random inputs (i.e., batch_size=2)
-#         inp_list = [torch.randn(BATCH, IN_CH), torch.randn(BATCH, IN_CH)]
-#         fp_outs = torch.zeros(len(inp_list), num_layers, OUT_CH)
-#         quant_outs = torch.zeros(len(inp_list), num_layers, OUT_CH)
+        # Generate 2 random inputs (i.e., batch_size=2)
+        inp_list = [torch.randn(BATCH, IN_CH), torch.randn(BATCH, IN_CH)]
+        fp_outs = torch.zeros(len(inp_list), num_layers, OUT_CH)
+        quant_outs = torch.zeros(len(inp_list), num_layers, OUT_CH)

-#         error = torch.zeros(num_layers, OUT_CH)
+        error = torch.zeros(num_layers, OUT_CH)

-#         # Reference implementation of bias correction
-#         for b, inp in enumerate(inp_list):
-#             fp_outs[b, :, :] = fp_model(inp)
+        # Reference implementation of bias correction
+        for b, inp in enumerate(inp_list):
+            fp_outs[b, :, :] = fp_model(inp)

-#             quant_outs[b, 0, :] = quant_model.module_list[0](inp)
-#             quant_outs[b, 1, :] = quant_model.module_list[1](
-#                 fp_outs[b, 0, :])  # The second layer takes as input the "corrected" output
-#             error += fp_outs[b] - quant_outs[b]
+            quant_outs[b, 0, :] = quant_model.module_list[0](inp)
+            quant_outs[b, 1, :] = quant_model.module_list[1](
+                fp_outs[b, 0, :])  # The second layer takes as input the "corrected" output
+            error += fp_outs[b] - quant_outs[b]

-#         with bias_correction_mode(quant_model):
-#             for inp in inp_list:
-#                 quant_model(inp)
+        with bias_correction_mode(quant_model):
+            for inp in inp_list:
+                quant_model(inp)

-#         assert quant_model.module_list[0].bias is not None
-#         assert quant_model.module_list[1].bias is not None
-#         assert torch.allclose(quant_model.module_list[0].bias, error[0] / len(inp_list))
-#         assert torch.allclose(quant_model.module_list[1].bias, error[1] / len(inp_list))
+        assert quant_model.module_list[0].bias is not None
+        assert quant_model.module_list[1].bias is not None
+        assert torch.allclose(quant_model.module_list[0].bias, error[0] / len(inp_list))
+        assert torch.allclose(quant_model.module_list[1].bias, error[1] / len(inp_list))

-#     def test_bias_correction_hook(self, models):
-#         fp_model, quant_model = models
-#         num_layers = len(quant_model.module_list)
+    def test_bias_correction_hook(self, models):
+        fp_model, quant_model = models
+        num_layers = len(quant_model.module_list)

-#         # Generate 2 random inputs (i.e., batch_size=2)
-#         inp_list = [torch.randn(BATCH, IN_CH), torch.randn(BATCH, IN_CH)]
+        # Generate 2 random inputs (i.e., batch_size=2)
+        inp_list = [torch.randn(BATCH, IN_CH), torch.randn(BATCH, IN_CH)]

-#         inputs = []
-#         outputs = []
+        inputs = []
+        outputs = []

-#         # If the user tries to modify the output with the forward_hook, this will be ignored
-#         # because overridden by our own forward_hook
-#         def simple_hook(mod, inp, out):
-#             inputs.append(*inp)
-#             outputs.append(*out)
+        # If the user tries to modify the output with the forward_hook, this will be ignored
+        # because overridden by our own forward_hook
+        def simple_hook(mod, inp, out):
+            inputs.append(*inp)
+            outputs.append(*out)

-#         fp_outs = torch.zeros(len(inp_list), num_layers, OUT_CH)
+        fp_outs = torch.zeros(len(inp_list), num_layers, OUT_CH)

-#         for b, inp in enumerate(inp_list):
-#             fp_outs[b, :, :] = fp_model(inp)
+        for b, inp in enumerate(inp_list):
+            fp_outs[b, :, :] = fp_model(inp)

-#         quant_model.module_list[1].register_forward_hook(
-#             simple_hook)  # Register hook on the second layer
+        quant_model.module_list[1].register_forward_hook(
+            simple_hook)  # Register hook on the second layer

-#         with bias_correction_mode(quant_model):
-#             for inp in inp_list:
-#                 quant_model(inp)
+        with bias_correction_mode(quant_model):
+            for inp in inp_list:
+                quant_model(inp)

-#         assert len(
-#             outputs
-#         ) == 2  # Forward hook called only once per input, even though we performed 3 "forwards" per input
-#         assert (inputs[0] == fp_outs[0, 0, :]).all(
-#         )  # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
-#         assert (inputs[1] == fp_outs[1, 0, :]).all(
-#         )  # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
+        assert len(
+            outputs
+        ) == 2  # Forward hook called only once per input, even though we performed 3 "forwards" per input
+        assert (inputs[0] == fp_outs[0, 0, :]).all(
+        )  # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer
+        assert (inputs[1] == fp_outs[1, 0, :]).all(
+        )  # In bias_correction mode, the input to each layer is equal to the FP output of the previous layer

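The arithmetic checked by test_bias_correction_results can be restated compactly: for each layer, the correction bias is the mean, over calibration inputs, of (floating-point output minus quantized output), where each quantized layer is evaluated on the floating-point output of the previous layer, before any correction is applied. Below is a standalone sketch of that rule; the helper name expected_correction_bias is illustrative, and plain callables stand in for the layers.

import torch

def expected_correction_bias(fp_layers, quant_layers, inputs):
    # Running per-layer sum of (FP output - quantized output).
    errors = None
    for inp in inputs:
        # Floating-point activations: the input, then each layer's output.
        fp_acts = [inp]
        for layer in fp_layers:
            fp_acts.append(layer(fp_acts[-1]))
        # Each quantized layer sees the FP output of the previous layer,
        # mirroring the "corrected" inputs used during bias correction.
        diffs = [fp_acts[i + 1] - q(fp_acts[i]) for i, q in enumerate(quant_layers)]
        errors = diffs if errors is None else [e + d for e, d in zip(errors, diffs)]
    # The correction bias is the mean error over the calibration inputs.
    return [e / len(inputs) for e in errors]

With the fixture above, expected_correction_bias(fp_model.module_list, quant_model.module_list, inp_list) should reproduce error[i] / len(inp_list) for the uncorrected quantized model. The hook test then pins down the mechanics: during bias_correction_mode each layer is re-run internally, yet a user-registered forward hook fires exactly once per input and observes the floating-point input from the previous layer.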