Commit

Unified learned round methods
pablomlago committed Nov 21, 2024
1 parent 39095c0 commit 229098c
Showing 9 changed files with 123 additions and 98 deletions.
41 changes: 22 additions & 19 deletions src/brevitas/core/function_wrapper/learned_round.py

@@ -25,20 +25,21 @@ class LearnedRoundHardSigmoid(brevitas.jit.ScriptModule):

     def __init__(self, learned_round_zeta: float = 1.1, learned_round_gamma: float = -0.1) -> None:
         super(LearnedRoundHardSigmoid, self).__init__()
-        self.float_to_int_ste = floor_ste
-        self.is_p_value = True
         self.learned_round_zeta = learned_round_zeta
         self.learned_round_gamma = learned_round_gamma

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
-        p = torch.sigmoid(x)
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        p = torch.sigmoid(p)
         p = p * (self.learned_round_zeta - self.learned_round_gamma) + self.learned_round_gamma
         p = torch.clamp(p, 0.0, 1.0)
-        if not training:
+        if not self.training:
             return p > 0.5
         return p

+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return floor_ste(x) + p
+

 class LearnedRoundSigmoid(brevitas.jit.ScriptModule):
     """
@@ -49,17 +50,19 @@ class LearnedRoundSigmoid(brevitas.jit.ScriptModule):

     def __init__(self, learned_round_temperature: float = 1.) -> None:
         super(LearnedRoundSigmoid, self).__init__()
         assert learned_round_temperature != 0, 'Temperature should be different than 0'
-        self.float_to_int_ste = floor_ste
-        self.is_p_value = True
         self.learned_round_temperature = learned_round_temperature

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
-        if not training:
-            return x > 0
-        p = torch.sigmoid(x / self.learned_round_temperature)
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        if not self.training:
+            return p > 0
+        p = torch.sigmoid(p / self.learned_round_temperature)
         return p

+    @brevitas.jit.script_method
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return floor_ste(x) + p
+

 class LearnedRoundIdentity(brevitas.jit.ScriptModule):
     """
@@ -69,12 +72,14 @@ class LearnedRoundIdentity(brevitas.jit.ScriptModule):

     def __init__(self) -> None:
         super(LearnedRoundIdentity, self).__init__()
-        self.float_to_int_ste = round_ste
-        self.is_p_value = False

     @brevitas.jit.script_method
-    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
-        return x
+    def forward(self, p: torch.Tensor) -> torch.Tensor:
+        return p
+
+    @brevitas.jit.script_method
+    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
+        return round_ste(x + p)


 class LearnedRoundSte(brevitas.jit.ScriptModule):
@@ -97,12 +102,10 @@ def __init__(

     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        float_to_int_ste = self.learned_round_impl.float_to_int_ste
-        is_p_value = self.learned_round_impl.is_p_value
-        p = self.learned_round_impl(self.value, self.training)
+        p = self.learned_round_impl(self.value)
         p = self.tensor_slicer(p)
         p = (p.to(x.dtype)).view_as(x)
-        return float_to_int_ste(x) + p if is_p_value else float_to_int_ste(x + p)
+        return self.learned_round_impl.round_forward(x, p)

     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
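The net effect of this file's changes: the per-implementation float_to_int_ste / is_p_value attributes disappear, and each learned-round implementation instead exposes round_forward(x, p), so LearnedRoundSte.forward simply chains the two calls. A minimal standalone sketch of the resulting protocol (the HardSigmoidRound class below is a hypothetical stand-in, and Brevitas' floor_ste is approximated with torch.floor, so the snippet runs without Brevitas):

import torch

class HardSigmoidRound:
    # Hypothetical stand-in mirroring LearnedRoundHardSigmoid after this commit.
    def __init__(self, zeta: float = 1.1, gamma: float = -0.1, training: bool = True):
        self.zeta, self.gamma, self.training = zeta, gamma, training

    def __call__(self, v: torch.Tensor) -> torch.Tensor:
        # Map the learned parameter v to a rounding probability p in [0, 1].
        p = torch.sigmoid(v) * (self.zeta - self.gamma) + self.gamma
        p = torch.clamp(p, 0.0, 1.0)
        return p if self.training else (p > 0.5).float()

    def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        # floor_ste(x) + p in Brevitas; plain floor is used here for illustration.
        return torch.floor(x) + p

impl = HardSigmoidRound()
x = torch.randn(4) * 5            # values to be rounded, e.g. weight / scale
v = torch.zeros_like(x)           # learned rounding parameter
print(impl.round_forward(x, impl(v)))   # floor(x) + 0.5 at initialization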
113 changes: 63 additions & 50 deletions src/brevitas_examples/common/learned_round/learned_round_method.py

@@ -3,7 +3,7 @@

 from abc import ABC
 from abc import abstractmethod
-from typing import Dict, Generator, List, Optional, Tuple, Type
+from typing import Callable, Dict, Generator, List, Optional, Tuple, Type

 import torch
 from torch import nn
@@ -35,18 +35,72 @@ def format_loss_components(self, *args) -> str:
         pass


+def learned_round_value_init_non_linear(
+        layer: nn.Module,
+        learned_round_zeta: float = 1.1,
+        learned_round_gamma: float = -0.1,
+        **learned_round_impl_kwargs,
+) -> torch.Tensor:
+    floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale)
+    delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight
+    value = -torch.log((learned_round_zeta - learned_round_gamma) /
+                       (delta - learned_round_gamma) - 1)
+    return value
+
+
+def learned_round_value_init_linear(
+        layer: nn.Module,
+        **learned_round_impl_kwargs,
+) -> torch.Tensor:
+    value = torch.zeros_like(layer.weight.data)
+    return value
+
+
+LEARNED_ROUND_VALUE_INIT_MAP = {
+    LearnedRoundImplType.HARD_SIGMOID.value: learned_round_value_init_non_linear,
+    LearnedRoundImplType.SIGMOID.value: learned_round_value_init_non_linear,
+    LearnedRoundImplType.IDENTITY.value: learned_round_value_init_linear,}
+
+
 class LearnedRound(ABC):

+    def __init__(
+            self,
+            learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID,
+            learned_round_value_init_fn: Optional[Callable] = None,
+            **learned_round_impl_kwargs,
+    ) -> None:
+        self.learned_round_impl_type = learned_round_impl_type
+        self.learned_round_value_init_fn = learned_round_value_init_fn
+        self.learned_round_impl_kwargs = learned_round_impl_kwargs
+
+    def learned_round_value_init(
+            self,
+            layer: nn.Module,
+    ) -> torch.Tensor:
+        # A custom initialization function for the learned round parameter can be passed
+        if self.learned_round_value_init_fn is not None:
+            return self.learned_round_value_init_fn(layer, **self.learned_round_impl_kwargs)
+        # If not provided, the default function, as defined in LEARNED_ROUND_VALUE_INIT_MAP
+        # is leveraged
+        return LEARNED_ROUND_VALUE_INIT_MAP[self.learned_round_impl_type.value](
+            layer, **self.learned_round_impl_kwargs)
+
+    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
+        value = self.learned_round_value_init(layer)
+        layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
+            float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
+            learned_round_impl_type=self.learned_round_impl_type,
+            learned_round_init=value,
+            **self.learned_round_impl_kwargs,
+        )
+        layer.weight_quant.init_tensor_quant(preserve_state_dict=True)
+
     def insert_learned_round_quantizers(self, model: nn.Module) -> None:
         for module in model.modules():
             if isinstance(module, QuantWBIOL) and len(
                     self.return_learned_round_quantizers(module)) == 0:
                 self._insert_learned_round_quantizer_to_layer(module)
-                module.weight_quant.init_tensor_quant(preserve_state_dict=True)

-    @abstractmethod
-    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
-        pass
-
     def return_learned_round_quantizers(self, block: nn.Module) -> List[nn.Module]:
         return [module for module in block.modules() if isinstance(module, LearnedRoundSte)]
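The non-linear initializer above is the inverse of the stretched sigmoid used by the hard-sigmoid rounding: with delta the fractional part of weight / scale, choosing v = -log((zeta - gamma) / (delta - gamma) - 1) makes the soft rounding reproduce delta exactly at initialization, so the soft-quantized weights start from their floating-point values. A quick standalone check of that identity (plain PyTorch, zeta and gamma as in the defaults above):

import torch

zeta, gamma = 1.1, -0.1
delta = torch.rand(1000)                                   # fractional part of weight / scale
v = -torch.log((zeta - gamma) / (delta - gamma) - 1)       # learned_round_value_init_non_linear
p = torch.clamp(torch.sigmoid(v) * (zeta - gamma) + gamma, 0.0, 1.0)
print(torch.allclose(p, delta, atol=1e-5))                 # True: floor(w/s) + p == w/s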
@@ -80,9 +134,9 @@ def __init__(
             warmup: float = 0.2,
             decay_start: float = 0.0,
             **kwargs) -> None:
-        # AdaRound operates in a layer-wise manner, so integrity needs to be checked
-        assert isinstance(module, QuantWBIOL), "AdaRound can only accept a single QuantWBIOL layer."
-        assert len(learned_round_modules) == 1, "AdaRound can only accept a single learned round module."
+        # This loss operates in a layer-wise manner, so integrity needs to be checked
+        assert isinstance(module, QuantWBIOL), "Regularised MSE loss can only accept a single QuantWBIOL layer."
+        assert len(learned_round_modules) == 1, "Regularised MSE loss can only accept a single learned round module."

         self.weight = weight
         self.module = module
@@ -119,33 +173,6 @@ def format_loss_components(self, loss: float, rec_loss: float, round_loss: float
             b)


-class AdaRound(LearnedRound):
-
-    def __init__(
-            self,
-            learned_round_zeta: float = 1.1,
-            learned_round_gamma: float = -0.1,
-            learned_round_impl_type: LearnedRoundImplType = LearnedRoundImplType.HARD_SIGMOID,
-            **kwargs,
-    ) -> None:
-        # Quantiser-related configuration
-        self.learned_round_zeta = learned_round_zeta
-        self.learned_round_gamma = learned_round_gamma
-        self.learned_round_impl_type = learned_round_impl_type
-
-    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
-        floor_weight = torch.floor(layer.weight.data / layer.quant_weight().scale)
-        delta = (layer.weight.data / layer.quant_weight().scale) - floor_weight
-        value = -torch.log((self.learned_round_zeta - self.learned_round_gamma) /
-                           (delta - self.learned_round_gamma) - 1)
-        layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
-            float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
-            learned_round_impl_type=self.learned_round_impl_type,
-            learned_round_gamma=self.learned_round_gamma,
-            learned_round_zeta=self.learned_round_zeta,
-            learned_round_init=value)
-
-
 class MSELoss(LearnedRoundLoss):

     def __init__(self, block: nn.Module, learned_round_modules: List[nn.Module], **kwargs) -> None:
@@ -157,17 +184,3 @@ def __call__(self, pred: torch.Tensor, tgt: torch.Tensor) -> Tuple[torch.Tensor,

     def format_loss_components(self, loss: float) -> str:
         return "Loss = {:.4f}".format(loss)
-
-
-class AutoRound(LearnedRound):
-
-    def __init__(self, **kwargs) -> None:
-        pass
-
-    def _insert_learned_round_quantizer_to_layer(self, layer: nn.Module) -> None:
-        value = torch.zeros_like(layer.weight.data)
-        layer.weight_quant.quant_injector = layer.weight_quant.quant_injector.let(
-            float_to_int_impl_type=FloatToIntImplType.LEARNED_ROUND,
-            learned_round_impl_type=LearnedRoundImplType.IDENTITY,
-            learned_round_init=value,
-        )
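With the AdaRound and AutoRound subclasses removed, both behaviours are now obtained by configuring the single LearnedRound class with an implementation type (and, optionally, a custom value-init function). A sketch of the two equivalent constructions under the new API, using only names visible in this diff:

from brevitas.inject.enum import LearnedRoundImplType
from brevitas_examples.common.learned_round.learned_round_method import LearnedRound

# Former AdaRound: hard-sigmoid reparameterisation with sigmoid-inverse initialization.
ada_style = LearnedRound(learned_round_impl_type=LearnedRoundImplType.HARD_SIGMOID)
# Former AutoRound: identity reparameterisation with zero initialization.
auto_style = LearnedRound(learned_round_impl_type=LearnedRoundImplType.IDENTITY)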
(additional changed file; file name not shown in this view)

@@ -38,11 +38,11 @@
 from torch.utils.data.dataloader import DataLoader

 from brevitas import config
+from brevitas.inject.enum import LearnedRoundImplType
 from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
 from brevitas.optim.sign_sgd import SignSGD
 from brevitas.quant_tensor import QuantTensor
-from brevitas_examples.common.learned_round.learned_round_method import AdaRound
-from brevitas_examples.common.learned_round.learned_round_method import AutoRound
+from brevitas_examples.common.learned_round.learned_round_method import LearnedRound
 from brevitas_examples.common.learned_round.learned_round_method import MSELoss
 from brevitas_examples.common.learned_round.learned_round_method import RegularisedMSELoss
 from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer
@@ -59,8 +59,9 @@ def is_layer(module: nn.Module, module_name: str) -> bool:


 LEARNED_ROUND_MAP = {
-    "auto_round": AutoRound,
-    "ada_round": AdaRound,}
+    "linear_round": LearnedRoundImplType.IDENTITY,
+    "hard_sigmoid_round": LearnedRoundImplType.HARD_SIGMOID,
+    "sigmoid_round": LearnedRoundImplType.SIGMOID,}
 LEARNED_ROUND_LOSS_MAP = {
     "mse": MSELoss,
     "regularised_mse": RegularisedMSELoss,}
@@ -152,7 +153,7 @@ def apply_learned_round(
         model: nn.Module,
         calibration_loader: DataLoader,
         iters: int = 1000,
-        learned_round: str = "ada_round",
+        learned_round: str = "hard_sigmoid_round",
         learned_round_loss: str = "regularised_mse",
         optimizer: str = "adam",
         lr_scheduler: Optional[str] = None,
@@ -169,7 +170,7 @@
 ) -> None:
     if learned_round not in LEARNED_ROUND_MAP:
         raise ValueError(f"Learned round method {learned_round} is not available.")
-    learned_round = LEARNED_ROUND_MAP[learned_round]()
+    learned_round = LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round])

     if learned_round_loss not in LEARNED_ROUND_LOSS_MAP:
         raise ValueError(f"Learned round loss {learned_round_loss} is not available.")
@@ -221,4 +222,4 @@ def apply_learned_round(
         cache=cache,
         block_check_fn=block_check_fn,
         keep_gpu=True,
-    )
+    )
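At the example entry points the string options are renamed to match the implementation types: the old ada_round corresponds to hard_sigmoid_round and the old auto_round to linear_round. A hedged call sketch of the ImageNet apply_learned_round above (model and calibration_loader are assumed to exist; only arguments visible in the diff are used):

apply_learned_round(
    model,
    calibration_loader,
    iters=1000,
    learned_round="hard_sigmoid_round",     # previously "ada_round"
    learned_round_loss="regularised_mse",
)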
(additional changed file; file name not shown in this view)

@@ -162,7 +162,7 @@ def validate_args(args):
         '--learned-round',
         default=None,
         type=str,
-        choices=[None, 'ada_round', 'auto_round'],
+        choices=[None, 'linear_round', 'hard_sigmoid_round', 'sigmoid_round'],
         help='Learned round type (default: None)')
     parser.add_argument(
         '--learned-round-loss',
@@ -430,7 +430,7 @@ def main():
             equalize_merge_bias=args.graph_eq_merge_bias,
             merge_bn=not args.calibrate_bn)
     elif args.target_backend == 'fx' or args.target_backend == 'layerwise':
-        if args.learned_round != "auto_round":
+        if args.learned_round_mode != "blockwise":
             model = preprocess_for_quantize(
                 model,
                 equalize_iters=args.graph_eq_iterations,
3 changes: 2 additions & 1 deletion src/brevitas_examples/llm/benchmark/llm_benchmark.py

@@ -112,7 +112,8 @@ def unique(sequence):
     'export_prefix': [None],  # Path prefix to use for the various export flows.
     'checkpoint_name': [None],  # Filename to save checkpoint.
     'fuse_sequences': [False],  # Whether to merge the dataset sequences.
-    'learned_round': [None, "auto_round"],  # Whether to use learned round. If `None`, RTN is used.
+    'learned_round': [None,
+                      "linear_round"],  # Whether to use learned round. If `None`, RTN is used.
 }

 parser = argparse.ArgumentParser(description='PyTorch LLM PTQ Validation')
8 changes: 4 additions & 4 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py

@@ -14,15 +14,15 @@
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from transformers.models.opt.modeling_opt import OPTDecoderLayer

+from brevitas.inject.enum import LearnedRoundImplType
 from brevitas.optim.sign_sgd import SignSGD
-from brevitas_examples.common.learned_round.learned_round_method import AutoRound
 from brevitas_examples.common.learned_round.learned_round_method import LearnedRound
 from brevitas_examples.common.learned_round.learned_round_method import LearnedRoundLoss
 from brevitas_examples.common.learned_round.learned_round_method import MSELoss
 from brevitas_examples.common.learned_round.learned_round_optimizer import LearnedRoundOptimizer

 LEARNED_ROUND_MAP = {
-    "auto_round": AutoRound,}
+    "linear_round": LearnedRoundImplType.IDENTITY,}
 LEARNED_ROUND_LOSS_MAP = {
     "mse": MSELoss,}
 OPTIMIZER_MAP = {
@@ -161,7 +161,7 @@ def apply_learned_round(
         model: nn.Module,
         calibration_loader: DataLoader,
         iters: int = 200,
-        learned_round: str = "auto_round",
+        learned_round: str = "linear_round",
         learned_round_loss: str = "mse",
         optimizer: str = "sign_sgd",
         lr_scheduler: Optional[str] = "linear",
@@ -178,7 +178,7 @@
 ) -> None:
     if learned_round not in LEARNED_ROUND_MAP:
         raise ValueError(f"Learned round method {learned_round} is not available.")
-    learned_round = LEARNED_ROUND_MAP[learned_round]()
+    learned_round = LearnedRound(learned_round_impl_type=LEARNED_ROUND_MAP[learned_round])

     if learned_round_loss not in LEARNED_ROUND_LOSS_MAP:
         raise ValueError(f"Learned round loss {learned_round_loss} is not available.")
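For the LLM flow only the identity (AutoRound-style) option is exposed, now under the name linear_round. A usage sketch against the new defaults (model and calibration_loader are placeholders for a quantized model and its calibration DataLoader; the keyword values are the defaults shown in the diff):

from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

apply_learned_round(
    model,                          # quantized LLM, e.g. an OPT or Llama model
    calibration_loader,             # DataLoader of tokenized calibration samples
    iters=200,
    learned_round="linear_round",   # previously "auto_round"
    learned_round_loss="mse",
    optimizer="sign_sgd",
    lr_scheduler="linear",
)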
2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/main.py

@@ -684,7 +684,7 @@ def parse_args(args):
     parser.add_argument(
         '--learned-round',
         default=None,
-        choices=[None, 'auto_round'],
+        choices=[None, 'linear_round'],
         help='Whether to use learned round. If `None`, RTN is used (default: %(default)s)')
     return parser.parse_args(args)
