diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py index 2d4c829f0..16b8a4b5f 100644 --- a/tests/brevitas/core/test_float_quant.py +++ b/tests/brevitas/core/test_float_quant.py @@ -9,6 +9,7 @@ from brevitas.core.function_wrapper import FloatClamp from brevitas.core.function_wrapper import RoundSte from brevitas.core.function_wrapper import TensorClamp +from brevitas.core.function_wrapper.misc import Identity from brevitas.core.quant.float import FloatQuant from brevitas.core.scaling import ConstScaling from brevitas.core.scaling import FloatScaling @@ -32,6 +33,7 @@ def test_float_quant_defaults(minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), float_clamp_impl=None) else: # init FloatClamp @@ -48,6 +50,7 @@ def test_float_quant_defaults(minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=float_clamp) assert isinstance(float_quant.float_to_int_impl, RoundSte) @@ -73,6 +76,7 @@ def test_float_to_quant_float(inp, minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=None) else: @@ -90,6 +94,7 @@ def test_float_to_quant_float(inp, minifloat_format): exponent_bit_width=exponent_bit_width, mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, + input_view_impl=Identity(), signed=signed, float_clamp_impl=float_clamp) expected_out, *_ = float_quant(inp) @@ -115,6 +120,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=None) @@ -132,6 +138,7 @@ def test_scaling_impls_called_once(inp, minifloat_format): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp) @@ -162,6 +169,7 @@ def test_inner_scale(inp, minifloat_format, scale): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=None) @@ -179,6 +187,7 @@ def test_inner_scale(inp, minifloat_format, scale): mantissa_bit_width=mantissa_bit_width, exponent_bias=exponent_bias, signed=signed, + input_view_impl=Identity(), scaling_impl=scaling_impl, float_scaling_impl=float_scaling_impl, float_clamp_impl=float_clamp)