Enable scale tuning in learned round
pablomlago committed Nov 21, 2024
1 parent 65271c0 commit 39095c0
Showing 3 changed files with 90 additions and 8 deletions.
@@ -201,6 +201,7 @@
from torch.optim.lr_scheduler import LinearLR
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from torch.optim.sgd import SGD
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import RandomSampler
from tqdm import tqdm
@@ -211,6 +212,7 @@
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.optim.sign_sgd import SignSGD
from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase
from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
from brevitas_examples.common.learned_round.learned_round_method import LearnedRound
@@ -238,6 +240,24 @@ def _get_blocks(module: nn.Module):
return blocks


def return_scale_parameters(block: nn.Module) -> List[nn.Parameter]:
    """Collect the learnable scale tensors of the weight quantizers within a block."""
    scale_parameters = []

    def _get_scale_parameters(module: nn.Module):
        for module_child in module.children():
            if isinstance(module_child, WeightQuantProxyFromInjectorBase):
                # Weight quant proxy found: gather its learnable scale tensors
                for param_name, param in module_child.named_parameters():
                    if param_name.endswith('scaling_impl.value'):
                        scale_parameters.append(param)
            else:
                _get_scale_parameters(module_child)

    # Run recursion from block
    _get_scale_parameters(block)
    return scale_parameters


class StopFwdException(Exception):
"""Used to throw and catch an exception to stop traversing the graph."""
pass
@@ -350,10 +370,12 @@ def __init__(
learned_round_loss_class: Type[LearnedRoundLoss],
*,
optimizer_class: Type[Optimizer] = SignSGD,
scale_optimizer_class: Type[Optimizer] = SGD,
lr_scheduler_class: Optional[Type[LRScheduler]] = LinearLR,
optimizer_lr: float = 5e-3,
batch_size: int = 8,
iters: int = 200,
learn_scale: bool = False,
use_best_model: bool = True,
use_amp: bool = True,
amp_dtype: torch.dtype = torch.float16,
@@ -365,10 +387,12 @@ def __init__(
) -> None:
self.learned_round = learned_round
self.optimizer_class = optimizer_class
self.scale_optimizer_class = scale_optimizer_class
self.lr_scheduler_class = lr_scheduler_class
self.optimizer_lr = optimizer_lr
self.batch_size = batch_size
self.iters = iters
self.learn_scale = learn_scale
self.use_best_model = use_best_model
self.use_amp = use_amp
self.amp_dtype = amp_dtype
@@ -399,11 +423,25 @@ def _collect_round_params(self, block: nn.Module) -> Dict:
params[n] = copy.deepcopy(m.state_dict())
return params

def _step(self, optimizer: Optimizer, lr_scheduler: LRScheduler) -> None:
optimizer.step()
optimizer.zero_grad()
if lr_scheduler:
lr_scheduler.step()
def _optim_step(self, *optimizers: Optimizer) -> None:
for optimizer in optimizers:
if optimizer:
optimizer.step()
optimizer.zero_grad()

def _lr_sched_step(self, *lr_schedulers: LRScheduler) -> None:
for lr_scheduler in lr_schedulers:
if lr_scheduler:
lr_scheduler.step()

def _step(self, optimizers: List[Optimizer], lr_schedulers: List[LRScheduler]) -> None:
for optimizer in optimizers:
if optimizer:
optimizer.step()
optimizer.zero_grad()
for lr_scheduler in lr_schedulers:
if lr_scheduler:
lr_scheduler.step()

def _populate_cache(
self,
@@ -448,6 +486,7 @@ def _optimize_learned_round_block(
cache: Cache,
block_loss: LearnedRoundLoss,
block_forward: Callable,
scale_params: Optional[List[nn.Parameter]] = None,
) -> Tuple[float, float, int]:
# Move block to GPU if available
if torch.cuda.is_available():
@@ -474,6 +513,22 @@
self.lr_scheduler_class(optimizer, **self.lr_scheduler_kwargs)
if self.lr_scheduler_class else None)

# Initialize optimizer/LR scheduler for the scale parameters if enabled
if self.learn_scale and scale_params is not None:
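# Note: unlike the rounding parameters (SignSGD by default), the scales are tuned
# with scale_optimizer_class (plain SGD with momentum by default) on their own LR schedule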
optimizer_scale = self.scale_optimizer_class(
scale_params,
lr=self.optimizer_lr,
momentum=0.9,
**self.optimizer_kwargs,
)
lr_scheduler_scale = (
self.lr_scheduler_class(
optimizer_scale, start_factor=1, end_factor=0, total_iters=600)
if self.lr_scheduler_class else None)
else:
optimizer_scale = None
lr_scheduler_scale = None

# Variables needed for printing
best_loss = torch.finfo(torch.float).max
init_loss = -1.0
@@ -482,7 +537,7 @@
# Dictionary to store the rounding parameters yielding the lowest
# training loss
optimal_rounding_params = {}

torch.autograd.set_detect_anomaly(True)  # anomaly detection helps debug NaN/inf gradients, but adds runtime overhead
n_samples = len(cache)
pbar = tqdm(range(self.iters), desc='')
for i in pbar:
@@ -512,7 +567,8 @@
# Scale loss and perform gradient step
loss = loss * self.loss_scaling_factor
loss.backward()
self._step(optimizer, lr_scheduler)
self._optim_step(optimizer, optimizer_scale)
self._lr_sched_step(lr_scheduler, lr_scheduler_scale)

# Update progress bar
pbar.set_description("{}".format(block_loss.format_loss_components(*loss_components)))
@@ -696,6 +752,9 @@ def apply_learned_round(
# Remove hooks needed to offload the model blocks to cpu
remove_hooks(model)

# Retrieve the scale parameters of the block's weight quantizers
scale_params = return_scale_parameters(block)

# The parameters of the block that are not part of the rounding quantizers
# need to be frozen, as only the rounding (and, optionally, the scales) is optimized.
block.eval()
@@ -707,6 +766,10 @@
block_learned_round_module.train()
for params in block_learned_round_module.parameters():
params.requires_grad = True
# As well as the scale parameters, if enabled
if self.learn_scale:
for params in scale_params:
params.requires_grad = True

# Move block to GPU if available
if torch.cuda.is_available():
@@ -729,6 +792,7 @@
cache=cache,
block_loss=block_loss,
block_forward=block_forward,
scale_params=scale_params,
)

print(
@@ -741,6 +805,8 @@
block_learned_round_module.eval()
for params in block_learned_round_module.parameters():
params.requires_grad = False
for params in scale_params:
params.requires_grad = False

# Move the block back to CPU
block.cpu()
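For illustration (not part of the diff): a minimal, self-contained PyTorch sketch of the two-optimizer pattern introduced above, where the rounding parameters and the scale parameters each get their own optimizer and LR schedule and both are stepped after every backward pass. The tensors, names, and loss below are placeholders, SignSGD is swapped for plain SGD so the snippet runs without Brevitas, and only the scale-scheduler settings (momentum 0.9, linear decay to zero over 600 iterations) are taken from the diff itself.

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR

# Placeholder stand-ins for the learned-round values and the scaling_impl.value tensors
round_params = [torch.nn.Parameter(torch.zeros(16))]
scale_params = [torch.nn.Parameter(torch.ones(16))]

optimizer_round = SGD(round_params, lr=5e-3)                # the diff uses SignSGD here
optimizer_scale = SGD(scale_params, lr=5e-3, momentum=0.9)  # mirrors scale_optimizer_class
lr_sched_round = LinearLR(optimizer_round, start_factor=1.0, end_factor=0.0, total_iters=200)
lr_sched_scale = LinearLR(optimizer_scale, start_factor=1.0, end_factor=0.0, total_iters=600)

for _ in range(200):
    # Placeholder loss; the real code compares quantized and float block outputs
    loss = (round_params[0] - scale_params[0]).pow(2).sum()
    loss.backward()
    # Mirrors _optim_step: step and zero every active optimizer
    for opt in (optimizer_round, optimizer_scale):
        opt.step()
        opt.zero_grad()
    # Mirrors _lr_sched_step: advance every active scheduler
    for sched in (lr_sched_round, lr_sched_scale):
        sched.step()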
2 changes: 2 additions & 0 deletions src/brevitas_examples/llm/llm_quant/learned_round_utils.py
@@ -167,6 +167,7 @@ def apply_learned_round(
lr_scheduler: Optional[str] = "linear",
optimizer_lr: float = 5e-3,
batch_size: int = 8,
learn_scale: bool = False,
use_best_model: bool = True,
use_amp: bool = True,
amp_dtype: torch.dtype = torch.float16,
@@ -203,6 +204,7 @@
optimizer_lr=optimizer_lr,
batch_size=batch_size,
iters=iters,
learn_scale=learn_scale,
use_best_model=use_best_model,
use_amp=use_amp,
amp_dtype=amp_dtype,
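For illustration (not part of the diff): a hedged sketch of calling the updated entry point with the new learn_scale argument. The import path is inferred from the file path shown above, and model and calibration_loader are assumed to already exist (as they do in main.py below).

from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round

apply_learned_round(
    model,
    calibration_loader,
    iters=200,         # matches the new --learned-round-iters default
    learn_scale=True,  # also tune the weight-quantizer scales during learned rounding
)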
16 changes: 15 additions & 1 deletion src/brevitas_examples/llm/main.py
@@ -371,7 +371,12 @@ def main(args):
if args.learned_round:
print("Applying learned round...")
remove_hooks(model)
apply_learned_round(model, calibration_loader)
apply_learned_round(
model,
calibration_loader,
iters=args.learned_round_iters,
learn_scale=args.learned_round_scale,
)
print("Learned round applied.")

model = offload_model(model)
@@ -560,6 +565,15 @@
type=int,
default=64,
help='Group size for per_group input quantization. Default: 64.')
parser.add_argument(
'--learned-round-iters',
type=int,
default=200,
help='Number of iterations for learned round. Default: 200.')
parser.add_argument(
'--learned-round-scale',
action='store_true',
help='Learn the scale factors together with the rounding parameters.')
parser.add_argument(
'--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.')
parser.add_argument(
