diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py
index 2d3e76aeb..cfb1cfa5c 100644
--- a/src/brevitas/core/function_wrapper/learned_round.py
+++ b/src/brevitas/core/function_wrapper/learned_round.py
@@ -25,20 +25,21 @@ class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule):
 
     def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float = -0.1) -> None:
         super(LearnedRoundHardSigmoid, self).__init__()
-        self.float_to_int_ste = floor_ste
-        self.is_p_value = True
         self.learned_round_zeta = learned_round_zeta
         self.learned_round_gamma = learned_round_gamma
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
-        p = torch.sigmoid(x)
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        p = torch.sigmoid(p)
         p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma
         p = torch.clamp(p, 0.0, 1.0)
-        if not training:
+        if not self.training:
             return p > 0.5
         return p
 
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return floor_ste(x) + p
+
 
 class LearnedRoundSigmoid(brevitas.jit.ScriptModule):
     """
@@ -49,17 +50,19 @@ class LearnedRoundSigmoid(brevitas.jit.ScriptModule):
     def __init__(self, learned_round_temperature: float = 1.) -> None:
         super(LearnedRoundSigmoid, self).__init__()
         assert learned_round_temperature != 0, 'Temperature should be different than 0'
-        self.float_to_int_ste = floor_ste
-        self.is_p_value = True
         self.learned_round_temperature = learned_round_temperature
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
-        if not training:
-            return x > 0
-        p = torch.sigmoid(x / self.learned_round_temperature)
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        if not self.training:
+            return p > 0
+        p = torch.sigmoid(p / self.learned_round_temperature)
         return p
 
+    @brevitas.jit.script_method
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return floor_ste(x) + p
+
 
 class LearnedRoundIdentity(brevitas.jit.ScriptModule):
     """
@@ -69,12 +72,14 @@
 
     def __init__(self) -> None:
         super(LearnedRoundIdentity, self).__init__()
-        self.float_to_int_ste = round_ste
-        self.is_p_value = False
 
     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
-        return x
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        return p
+
+    @brevitas.jit.script_method
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return round_ste(x + p)
 
 
 class LearnedRoundSte(brevitas.jit.ScriptModule):
@@ -97,12 +102,10 @@ def __init__(
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        float_to_int_ste = self.learned_round_impl.float_to_int_ste
-        is_p_value = self.learned_round_impl.is_p_value
-        p = self.learned_round_impl(self.value, self.training)
+        p = self.learned_round_impl(self.value)
         p = self.tensor_slicer(p)
         p = (p.to(x.dtype)).view_as(x)
-        return float_to_int_ste(x) + p if is_p_value else float_to_int_ste(x + p)
+        return self.learned_round_impl.round_forward(x, p)
 
     def _load_from_state_dict(
             self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
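For reference, a minimal sketch (not part of the diff; plain PyTorch, with the straight-through `floor_ste` reduced to `torch.floor`, so the STE gradient is dropped) of the control flow that replaces the old `float_to_int_ste`/`is_p_value` attributes: each implementation now produces `p` in `forward` and combines it with the input in `round_forward`.

```python
import torch

# Hard-sigmoid path, mirroring LearnedRoundHardSigmoid (zeta/gamma defaults from the diff).
def hard_sigmoid_p(value: torch.Tensor, zeta: float = 1.1, gamma: float = -0.1,
                   training: bool = True) -> torch.Tensor:
    p = torch.clamp(torch.sigmoid(value) * (zeta - gamma) + gamma, 0.0, 1.0)
    # Eval mode collapses the soft probability into a hard 0/1 decision.
    return p if training else (p > 0.5).to(value.dtype)

def round_forward(x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    # Stand-in for floor_ste(x) + p; the real module keeps the STE gradient.
    return torch.floor(x) + p

x = torch.randn(4)
value = torch.zeros(4, requires_grad=True)
soft = round_forward(x, hard_sigmoid_p(value, training=True))   # differentiable in value
hard = round_forward(x, hard_sigmoid_p(value, training=False))  # integers only
```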
diff --git a/src/brevitas_examples/common/learned_round/learned_round_method.py b/src/brevitas_examples/common/learned_round/learned_round_method.py
index 04e7b3818..ae6ab8392 100644
--- a/src/brevitas_examples/common/learned_round/learned_round_method.py
+++ b/src/brevitas_examples/common/learned_round/learned_round_method.py
@@ -3,7 +3,7 @@
 
 from abc import ABC
 from abc import abstractmethod
-from typing import Dict, Generator, List, Optional, Tuple, Type
+from typing import Callable, Dict, Generator, List, Optional, Tuple, Type
 
 import torch
 from torch import nn
@@ -35,18 +35,72 @@ def format_loss_components(self, *args) -> str:
         pass
 
 
+def learned_round_value_init_non_linear(
+        layer: nn.Module,
+        learned_round_zeta: float = 1.1,
+        learned_round_gamma: float = -0.1,
+        **learned_round_impl_kwargs,
+) -> torch.Tensor:
+    floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale)
+    delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight
+    value = -torch.log((learned_round_zeta - learned_round_gamma) /
+                       (delta - learned_round_gamma) - 1)
+    return value
+
+
+def learned_round_value_init_linear(
+        layer: nn.Module,
+        **learned_round_impl_kwargs,
+) -> torch.Tensor:
+    value = torch.zeros_like(layer.weight.data)
+    return value
+
+
+LEARNED_ROUND_VALUE_INIT_MAP = {
+    LearnedRoundImplType.HARD_SIGMOID.value: learned_round_value_init_non_linear,
+    LearnedRoundImplType.SIGMOID.value: learned_round_value_init_non_linear,
+    LearnedRoundImplType.IDENTITY.value: learned_round_value_init_linear,}
+
+
 class LearnedRound(ABC):
 
+    def __init__(
+            self,
+            learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID,
+            learned_round_value_init_fn: Optional[Callable] = None,
+            **learned_round_impl_kwargs,
+    ) -> None:
+        self.learned_round_impl_type = learned_round_impl_type
+        self.learned_round_value_init_fn = learned_round_value_init_fn
+        self.learned_round_impl_kwargs = learned_round_impl_kwargs
+
+    def learned_round_value_init(
+            self,
+            layer: nn.Module,
+    ) -> torch.Tensor:
+        # A custom initialization function for the learned round parameter can be passed
+        if self.learned_round_value_init_fn is not None:
+            return self.learned_round_value_init_fn(layer, **self.learned_round_impl_kwargs)
+        # If not provided, the default function, as defined in LEARNED_ROUND_VALUE_INIT_MAP
+        # is leveraged
+        return LEARNED_ROUND_VALUE_INIT_MAP[self.learned_round_impl_type.value](
+            layer, **self.learned_round_impl_kwargs)
+
+    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
+        value = self.learned_round_value_init(layer)
+        layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
+            float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
+            learned_round_impl_type=self.learned_round_impl_type,
+            learned_round_init=value,
+            **self.learned_round_impl_kwargs,
+        )
+        layer.weight_quant.init_tensor_quant(preserve_state_dict=True)
+
     def insert_learned_round_quantizers(self, model: nn.Module) -> None:
         for module in model.modules():
             if isinstance(module, QuantWBIOL) and len(
                     self.return_learned_round_quantizers(module)) == 0:
                 self._insert_learned_round_quantizer_to_layer(module)
-                module.weight_quant.init_tensor_quant(preserve_state_dict=True)
-
-    @abstractmethod
-    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
-        pass
 
     def return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]:
         return [module for module in block.modules() if isinstance(module, LearnedRoundSte)]
@@ -80,9 +134,9 @@ def __init__(
             warmup: float = 0.2,
             decay_start: float = 0.0,
             **kwargs) -> None:
-        # AdaRound operates in a layer-wise manner, so integrity needs to be checked
-        assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer."
-        assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module."
+        # This loss operates in a layer-wise manner, so integrity needs to be checked
+        assert isinstance(module, QuantWBIOL), "Regularised MSE loss can only accept a single QuantWBIOL layer."
+        assert len(learned_round_modules) == 1, "Regularised MSE loss can only accept a single learned round module."
 
         self.weight = weight
         self.module = module
@@ -119,33 +173,6 @@ def format_loss_components(self, loss: float, rec_loss: float, round_loss: float
             b)
 
 
-class AdaRound(LearnedRound):
-
-    def __init__(
-            self,
-            learned_round_zeta: float = 1.1,
-            learned_round_gamma: float = -0.1,
-            learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID,
-            **kwargs,
-    ) -> None:
-        # Quantiser-related configuration
-        self.learned_round_zeta = learned_round_zeta
-        self.learned_round_gamma = learned_round_gamma
-        self.learned_round_impl_type = learned_round_impl_type
-
-    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
-        floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale)
-        delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight
-        value = -torch.log((self.learned_round_zeta - self.learned_round_gamma) /
-                           (delta - self.learned_round_gamma) - 1)
-        layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
-            float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
-            learned_round_impl_type=self.learned_round_impl_type,
-            learned_round_gamma=self.learned_round_gamma,
-            learned_round_zeta=self.learned_round_zeta,
-            learned_round_init=value)
-
-
 class MSELoss(LearnedRoundLoss):
 
     def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None:
@@ -157,17 +184,3 @@ def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor,
 
     def format_loss_components(self, loss: float) -> str:
         return "Loss = {:.4f}".format(loss)
-
-
-class AutoRound(LearnedRound):
-
-    def __init__(self, **kwargs) -> None:
-        pass
-
-    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
-        value = torch.zeros_like(layer.weight.data)
-        layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
-            float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
-            learned_round_impl_type=LearnedRoundImplType.IDENTITY,
-            learned_round_init=value,
-        )
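A standalone sanity check (assumed code, not in the diff; `layer.weight / scale` replaced by a plain tensor) that `learned_round_value_init_non_linear` is the exact inverse of the clamped hard sigmoid in `learned_round.py`, so soft rounding starts at the float rounding residue:

```python
import torch

zeta, gamma = 1.1, -0.1                      # defaults shared by init and quantizer
w_over_s = torch.tensor([0.3, -1.7, 2.25])   # stand-in for layer.weight / scale
delta = w_over_s - torch.floor(w_over_s)     # fractional part, in [0, 1)

# Non-linear init from the diff: inverse of the rectified sigmoid.
value = -torch.log((zeta - gamma) / (delta - gamma) - 1)

# Applying the hard sigmoid from learned_round.py recovers delta, so the
# learned quantizer initially reproduces plain float rounding behaviour.
p = torch.clamp(torch.sigmoid(value) * (zeta - gamma) + gamma, 0.0, 1.0)
assert torch.allclose(p, delta, atol=1e-6)
```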
QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer." - assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module." + # This loss operates in a layer-wise manner, so integrity needs to be checked + assert isinstance(module, QuantWBIOL), "Regularised MSE loss can only accept a single QuantWBIOL layer." + assert len(learned_round_modules) == 1, "Regularised MSE loss can only accept a single learned round module." self.weight = weight self.module = module @@ -119,33 +173,6 @@ def format_loss_components(self, loss: float, rec_loss: float, round_loss: float b) -class AdaRound(LearnedRound): - - def __init__( - self, - learned_round_zeta: float = 1.1, - learned_round_gamma: float = -0.1, - learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID, - **kwargs, - ) -> None: - # Quantiser-related configuration - self.learned_round_zeta = learned_round_zeta - self.learned_round_gamma = learned_round_gamma - self.learned_round_impl_type = learned_round_impl_type - - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: - floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale) - delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight - value = -torch.log((self.learned_round_zeta - self.learned_round_gamma) / - (delta - self.learned_round_gamma) - 1) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=self.learned_round_impl_type, - learned_round_gamma=self.learned_round_gamma, - learned_round_zeta=self.learned_round_zeta, - learned_round_init=value) - - class MSELoss(LearnedRoundLoss): def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None: @@ -157,17 +184,3 @@ def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor, def format_loss_components(self, loss: float) -> str: return "Loss = {:.4f}".format(loss) - - -class AutoRound(LearnedRound): - - def __init__(self, **kwargs) -> None: - pass - - def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None: - value = torch.zeros_like(layer.weight.data) - layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let( - float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND, - learned_round_impl_type=LearnedRoundImplType.IDENTITY, - learned_round_init=value, - ) diff --git a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py index 8d83d6510..aa943f473 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py +++ b/src/brevitas_examples/imagenet_classification/ptq/learned_round_utils.py @@ -38,11 +38,11 @@ from torch.utils.data.dataloader import DataLoader from brevitas import config +from brevitas.inject.enum import LearnedRoundImplType from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.optim.sign_sgd import SignSGD from brevitas.quant_tensor import QuantTensor -from brevitas_examples.common.learned_round.learned_round_method import AdaRound -from brevitas_examples.common.learned_round.learned_round_method import AutoRound +from brevitas_examples.common.learned_round.learned_round_method import LearnedRound from brevitas_examples.common.learned_round.learned_round_method import MSELoss from brevitas_examples.common.learned_round.learned_round_method import 
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 65b9e07c6..348213bdb 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -162,7 +162,7 @@ def validate_args(args):
     '--learned-round',
     default=None,
     type=str,
-    choices=[None, 'ada_round', 'auto_round'],
+    choices=[None, 'linear_round', 'hard_sigmoid_round', 'sigmoid_round'],
    help='Learned round type (default: None)')
 parser.add_argument(
     '--learned-round-loss',
@@ -430,7 +430,7 @@ def main():
             equalize_merge_bias=args.graph_eq_merge_bias,
             merge_bn=not args.calibrate_bn)
     elif args.target_backend == 'fx' or args.target_backend == 'layerwise':
-        if args.learned_round != "auto_round":
+        if args.learned_round_mode != "blockwise":
             model = preprocess_for_quantize(
                 model,
                 equalize_iters=args.graph_eq_iterations,
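The CLI choices rename one-for-one onto the removed classes. A migration sketch (assumed equivalences, inferred from LEARNED_ROUND_MAP above and the deleted class bodies; the zeta/gamma keywords are forwarded to the quantizer through `**learned_round_impl_kwargs`):

```python
from brevitas.inject.enum import LearnedRoundImplType
from brevitas_examples.common.learned_round.learned_round_method import LearnedRound

# AutoRound() / --learned-round auto_round  ->  linear_round
auto_round = LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY)

# AdaRound() / --learned-round ada_round  ->  hard_sigmoid_round
ada_round = LearnedRound(
    learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID,
    learned_round_zeta=1.1,
    learned_round_gamma=-0.1)
```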
diff --git a/src/brevitas_examples/llm/benchmark/llm_benchmark.py b/src/brevitas_examples/llm/benchmark/llm_benchmark.py
index dec0e81c3..c21036be5 100644
--- a/src/brevitas_examples/llm/benchmark/llm_benchmark.py
+++ b/src/brevitas_examples/llm/benchmark/llm_benchmark.py
@@ -112,7 +112,8 @@ def unique(sequence):
     'export_prefix': [None],  # Path prefix to use for the various export flows.
     'checkpoint_name': [None],  # Filename to save checkpoint.
     'fuse_sequences': [False],  # Whether to merge the dataset sequences.
-    'learned_round': [None, "auto_round"],  # Whether to use learned round. If `None`, RTN is used.
+    'learned_round': [None,
+                      "linear_round"],  # Whether to use learned round. If `None`, RTN is used.
 }
 
 parser = argparse.ArgumentParser(description='PyTorch LLM PTQ Validation')
diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
index bf2e565cc..f402ce211 100644
--- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
+++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -14,15 +14,15 @@
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 
+from brevitas.inject.enum import LearnedRoundImplType
 from brevitas.optim.sign_sgd import SignSGD
-from brevitas_examples.common.learned_round.learned_round_method import AutoRound
 from brevitas_examples.common.learned_round.learned_round_method import LearnedRound
 from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss
 from brevitas_examples.common.learned_round.learned_round_method import MSELoss
 from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer
 
 LEARNED_ROUND_MAP = {
-    "auto_round": AutoRound,}
+    "linear_round": LearnedRoundImplType.IDENTITY,}
 LEARNED_ROUND_LOSS_MAP = {
     "mse": MSELoss,}
 OPTIMIZER_MAP = {
@@ -161,7 +161,7 @@ def apply_learned_round(
         model: nn.Module,
         calibration_loader: DataLoader,
         iters: int = 200,
-        learned_round: str = "auto_round",
+        learned_round: str = "linear_round",
         learned_round_loss: str = "mse",
         optimizer: str = "sign_sgd",
         lr_scheduler: Optional[str] = "linear",
@@ -178,7 +178,7 @@
 ) -> None:
     if learned_round not in LEARNED_ROUND_MAP:
         raise ValueError(f"Learned round method {learned_round} is not available.")
-    learned_round = LEARNED_ROUND_MAP[learned_round]()
+    learned_round = LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round])
 
     if learned_round_loss not in LEARNED_ROUND_LOSS_MAP:
         raise ValueError(f"Learned round loss {learned_round_loss} is not available.")
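The analogous, equally hypothetical call on the LLM side, where only the identity implementation is exposed and the defaults differ (200 iterations, SignSGD, a linear schedule, per the signature above); `model` and `calibration_loader` are again assumed:

```python
from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

apply_learned_round(
    model,                         # assumed: a quantized HF causal LM
    calibration_loader,            # assumed: tokenized calibration batches
    iters=200,
    learned_round="linear_round",  # was "auto_round"
    learned_round_loss="mse",
    optimizer="sign_sgd",
    lr_scheduler="linear")
```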
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index bb5f915fc..574d1280d 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -684,7 +684,7 @@ def parse_args(args):
     parser.add_argument(
         '--learned-round',
         default=None,
-        choices=[None, 'auto_round'],
+        choices=[None, 'linear_round'],
         help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
 
     return parser.parse_args(args)
diff --git a/tests/brevitas/core/test_float_to_int.py b/tests/brevitas/core/test_float_to_int.py
index 41b74b4d6..2701e78a3 100644
--- a/tests/brevitas/core/test_float_to_int.py
+++ b/tests/brevitas/core/test_float_to_int.py
@@ -22,16 +22,14 @@
     LearnedRoundSigmoid(),  # Sigmoid Implementation
     LearnedRoundSigmoid(learned_round_temperature=2.),  # Sigmoid + Temperature
     LearnedRoundHardSigmoid(),  # Hard Sigmoid
-    LearnedRoundIdentity(),  # AutoRound Implement
-]
+    LearnedRoundIdentity(),]
 
 
 class TestLearnedRound():
 
     def instantiate_learnedround_float_to_int_impl(self, impl, weights, value):
         impl = LearnedRoundSte(impl, torch.full(weights.shape, 0.))
-        # For methods with p_value=False, it is required that value is within [-0.5, 0.5]
-        if not impl.learned_round_impl.is_p_value:
+        if isinstance(impl.learned_round_impl, LearnedRoundIdentity):
             min_value, max_value = torch.min(value), torch.max(value)
             # Prevent division by zero when all the elements of the tensor are the same
             if max_value - min_value < 1e-8:
@@ -61,8 +59,7 @@ def test_learnedround(self, impl, training, weights_value):
         out = impl(weights)
         # The FP values and its quantized values must differ by at most +/- 1
         assert torch.all(torch.abs(out - weights) <= 1)
-        # For is_p_value=True, the rounding can be soft while training=True
-        if impl.learned_round_impl.is_p_value:
+        if not isinstance(impl.learned_round_impl, LearnedRoundIdentity):
             if training:
                 # Soft quantization. All values are at most distant +/- 1 from the nearest integer
                 assert torch.all(torch.abs(out - torch.round(out)) <= 1)
@@ -70,7 +67,7 @@ def test_learnedround(self, impl, training, weights_value):
                 # Hard quantization. All values are integers
                 assert torch.allclose(out, torch.round(out))
         else:
-            # All values should be integers when is_p_value=False
+            # All values should be integers for LearnedRoundIdentity
             assert torch.allclose(out, torch.round(out))
 
     @given(
@@ -87,8 +84,10 @@ def test_learnedround_float_to_int_impl_hard_sigmoid(
             learned_round_zeta=learned_round_zeta,
             learned_round_gamma=learned_round_gamma,
         )
-        value_eval = learned_round_hard_sigmoid(value, training=False)
-        value_train = learned_round_hard_sigmoid(value, training=True)
+        learned_round_hard_sigmoid.train(False)
+        value_eval = learned_round_hard_sigmoid(value)
+        learned_round_hard_sigmoid.train(True)
+        value_train = learned_round_hard_sigmoid(value)
 
         out_eval = weight + value_eval
         out_train = weight + (value_train > 0.5)
@@ -109,7 +108,7 @@ def learnedround_float_to_int_impl(self, impl):
 
     def test_learnedround_load_dict(self, learnedround_float_to_int_impl):
         config.IGNORE_MISSING_KEYS = True
-        impl, _ = learnedround_float_to_int_impl
+        impl, _, _ = learnedround_float_to_int_impl
         quant_conv = qnn.QuantConv2d(IN_CH, OUT_CH, KERNEL_SIZE, weight_float_to_int_impl=impl)
         fp_conv = torch.nn.Conv2d(IN_CH, OUT_CH, KERNEL_SIZE)
         try:
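The updated test exercises the new mode handling: soft versus hard behaviour now follows `nn.Module.training` rather than an explicit `training` argument. A minimal sketch of the same toggle outside the test harness:

```python
import torch

from brevitas.core.function_wrapper.learned_round import LearnedRoundHardSigmoid

impl = LearnedRoundHardSigmoid()
value = torch.randn(8)

impl.train(True)
soft = impl(value)    # soft probabilities in [0, 1]

impl.train(False)
hard = impl(value)    # hard decisions (p > 0.5)
```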
diff --git a/tests/brevitas_examples/test_learned_round_utils.py b/tests/brevitas_examples/test_learned_round_utils.py
index e9501668e..751d7790d 100644
--- a/tests/brevitas_examples/test_learned_round_utils.py
+++ b/tests/brevitas_examples/test_learned_round_utils.py
@@ -16,11 +16,11 @@
 
 from brevitas import config
 from brevitas.core.function_wrapper.learned_round import LearnedRoundSte
+from brevitas.inject.enum import LearnedRoundImplType
 import brevitas.nn as qnn
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 from brevitas.quant_tensor.base_quant_tensor import QuantTensor
-from brevitas_examples.common.learned_round.learned_round_method import AdaRound
-from brevitas_examples.common.learned_round.learned_round_method import AutoRound
+from brevitas_examples.common.learned_round.learned_round_method import LearnedRound
 from brevitas_examples.common.learned_round.learned_round_optimizer import get_blocks
 from brevitas_examples.common.learned_round.learned_round_optimizer import save_inputs_output
 
@@ -302,7 +302,11 @@ def model_forward(model, inputs):
         for cache_output, gt_output in zip(cache.output, fp_outs if disable_quant else quant_outs):
             _compare_tensors(cache_output, gt_output, disable_quant, keep_gpu)
 
-    @pytest.mark.parametrize("learned_round", [AutoRound(), AdaRound()])
+    @pytest.mark.parametrize(
+        "learned_round",
+        [
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY),
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)])
     def test_insert_learned_round_quantizers(self, quant_model, learned_round):
         block = quant_model.in_proj_mlp
         learned_round.insert_learned_round_quantizers(block)
@@ -313,7 +317,11 @@ def test_insert_learned_round_quantizers(self, quant_model, learned_round):
         assert isinstance(
             module.weight_quant.tensor_quant.int_quant.float_to_int_impl, LearnedRoundSte)
 
-    @pytest.mark.parametrize("learned_round", [AutoRound(), AdaRound()])
+    @pytest.mark.parametrize(
+        "learned_round",
+        [
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY),
+            LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)])
     @pytest.mark.parametrize(
         "block_strs, num_round_modules", [([], 0), (["hidden_mlp"], 2), (["in_proj_mlp", "out_proj_mlp"], 4)])
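The flow those tests pin down, as a hedged usage sketch (`block` is an assumed module containing QuantWBIOL layers):

```python
from brevitas.inject.enum import LearnedRoundImplType
from brevitas_examples.common.learned_round.learned_round_method import LearnedRound

learned_round = LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)
learned_round.insert_learned_round_quantizers(block)  # swaps in a LearnedRoundSte per layer
round_modules = learned_round.return_learned_round_quantizers(block)
```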