diff --git a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py index cf9dde3ce..0f2939acc 100644 --- a/src/brevitas_examples/common/learned_round/learned_round_optimizer.py +++ b/src/brevitas_examples/common/learned_round/learned_round_optimizer.py @@ -201,6 +201,7 @@ from torch.optim.lr_scheduler import LinearLR from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer +from torch.optim.sgd import SGD from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import RandomSampler from tqdm import tqdm @@ -211,6 +212,7 @@ from brevitas.graph.calibrate import DisableEnableQuantization from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.optim.sign_sgd import SignSGD +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase from brevitas_examples.common.accelerate_utils.accelerate import offload_model from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks from brevitas_examples.common.learned_round.learned_round_method import LearnedRound @@ -238,6 +240,24 @@ def _get_blocks(module: nn.Module): return blocks +def return_scale_parameters(block: nn.Module) -> List[nn.Parameter]: + + scale_parameters = [] + + def _get_scale_parameters(module: nn.Module): + for module_child in module.children(): + if isinstance(module, WeightQuantProxyFromInjectorBase): + for submodule_name, submodule in module_child.named_parameters(): + if submodule_name.endswith('scaling_impl.value'): + scale_parameters.append(submodule) + else: + _get_scale_parameters(module_child) + + # Run recursion from block + _get_scale_parameters(block) + return scale_parameters + + class StopFwdException(Exception): """Used to throw and catch an exception to stop traversing the graph.""" pass @@ -350,10 +370,12 @@ def __init__( learned_round_loss_class: Type[LearnedRoundLoss], *, optimizer_class: Type[Optimizer] = SignSGD, + scale_optimizer_class: Type[Optimizer] = SGD, lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR, optimizer_lr: float = 5e-3, batch_size: float = 8, iters: int = 200, + learn_scale: bool = False, use_best_model: bool = True, use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, @@ -365,10 +387,12 @@ def __init__( ) -> None: self.learned_round = learned_round self.optimizer_class = optimizer_class + self.scale_optimizer_class = scale_optimizer_class self.lr_scheduler_class = lr_scheduler_class self.optimizer_lr = optimizer_lr self.batch_size = batch_size self.iters = iters + self.learn_scale = learn_scale self.use_best_model = use_best_model self.use_amp = use_amp self.amp_dtype = amp_dtype @@ -399,11 +423,25 @@ def _collect_round_params(self, block: nn.Module) -> Dict: params[n] = copy.deepcopy(m.state_dict()) return params - def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None: - optimizer.step() - optimizer.zero_grad() - if lr_scheduler: - lr_scheduler.step() + def _optim_step(self, *optimizers: Optimizer) -> None: + for optimizer in optimizers: + if optimizer: + optimizer.step() + optimizer.zero_grad() + + def _lr_sched_step(self, *lr_schedulers: LRScheduler) -> None: + for lr_scheduler in lr_schedulers: + if lr_scheduler: + lr_scheduler.step() + + def _step(self, optimizers: List[Optimizer], lr_schedulers: List[LRScheduler]) -> None: + for optimizer in optimizers: + if optimizer: + optimizer.step() + optimizer.zero_grad() + for lr_scheduler in lr_schedulers: + if lr_scheduler: + lr_scheduler.step() def _populate_cache( self, @@ -448,6 +486,7 @@ def _optimize_learned_round_block( cache: Cache, block_loss: LearnedRoundLoss, block_forward: Callable, + scale_params: Optional[nn.Parameter] = None, ) -> Tuple[float, float, int]: # Move block to GPU if available if torch.cuda.is_available(): @@ -474,6 +513,22 @@ def _optimize_learned_round_block( self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs) if self.lr_scheduler_class else None) + # Initialize optimizer/LR scheduler for the scale parameters if enabled + if self.learn_scale and scale_params is not None: + optimizer_scale = self.scale_optimizer_class( + scale_params, + lr=self.optimizer_lr, + momentum=0.9, + **self.optimizer_kwargs, + ) + lr_scheduler_scale = ( + self.lr_scheduler_class( + optimizer_scale, start_factor=1, end_factor=0, total_iters=600) + if self.lr_scheduler_class else None) + else: + optimizer_scale = None + lr_scheduler_scale = None + # Variables needed for printing best_loss = torch.finfo(torch.float).max init_loss = -1.0 @@ -482,7 +537,7 @@ def _optimize_learned_round_block( # Dictionary to store the rounding parameters yielding the lowest # training loss optimal_rounding_params = {} - + torch.autograd.set_detect_anomaly(True) n_samples = len(cache) pbar = tqdm(range(self.iters), desc='') for i in pbar: @@ -512,7 +567,8 @@ def _optimize_learned_round_block( # Scale loss and perform gradient step loss = loss * self.loss_scaling_factor loss.backward() - self._step(optimizer, lr_scheduler) + self._optim_step(optimizer, optimizer_scale) + self._lr_sched_step(lr_scheduler, lr_scheduler_scale) # Update progress bar pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components))) @@ -696,6 +752,9 @@ def apply_learned_round( # Remove hooks needed to offload the model blocks to cpu remove_hooks(model) + # Retrieve scales + scale_params = return_scale_parameters(block) + # The parameters of the block that are not part of the rounding quantizers # need to be frozen, as only the rounding needs to be optimized. block.eval() @@ -707,6 +766,10 @@ def apply_learned_round( block_learned_round_module.train() for params in block_learned_round_module.parameters(): params.requires_grad = True + # As well as the scale parameters, if enabled + if self.learn_scale: + for params in scale_params: + params.requires_grad = True # Move block to GPU if available if torch.cuda.is_available(): @@ -729,6 +792,7 @@ def apply_learned_round( cache=cache, block_loss=block_loss, block_forward=block_forward, + scale_params=scale_params, ) print( @@ -741,6 +805,8 @@ def apply_learned_round( block_learned_round_module.eval() for params in block_learned_round_module.parameters(): params.requires_grad = False + for params in scale_params: + params.requires_grad = False # Move the block back to CPU block.cpu() diff --git a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py index dd0842702..bf2e565cc 100644 --- a/src/brevitas_examples/llm/llm_quant/learned_round_utils.py +++ b/src/brevitas_examples/llm/llm_quant/learned_round_utils.py @@ -167,6 +167,7 @@ def apply_learned_round( lr_scheduler: Optional[str] = "linear", optimizer_lr: float = 5e-3, batch_size: int = 8, + learn_scale: bool = False, use_best_model: bool = True, use_amp: bool = True, amp_dtype: torch.dtype = torch.float16, @@ -203,6 +204,7 @@ def apply_learned_round( optimizer_lr=optimizer_lr, batch_size=batch_size, iters=iters, + learn_scale=learn_scale, use_best_model=use_best_model, use_amp=use_amp, amp_dtype=amp_dtype, diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index a92d253fa..bb5f915fc 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -371,7 +371,12 @@ def main(args): if args.learned_round: print("Applying learned round...") remove_hooks(model) - apply_learned_round(model, calibration_loader) + apply_learned_round( + model, + calibration_loader, + iters=args.learned_round_iters, + learn_scale=args.learned_round_scale, + ) print("Learned round applied.") model = offload_model(model) @@ -560,6 +565,15 @@ def parse_args(args): type=int, default=64, help='Group size for per_group input quantization. Default: 64.') + parser.add_argument( + '--learned-round-iters', + type=int, + default=200, + help='Number of iterations for learned round. Default: 200.') + parser.add_argument( + '--learned-round-scale', + action='store_true', + help='Learned scale factor together with round.') parser.add_argument( '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.') parser.add_argument(