diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index 3e820d71b..dd995ec6d 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -197,23 +197,23 @@ def catch_stopfwd(self, *args, **kwargs): class gpfq_mode(gpxq_mode): """ - Apply GPTQ algorithm https://arxiv.org/abs/2210.17323. + Apply GPFQ algorithm. Args: - model (Module): The model to quantize with GPTQ - inplace (bool): Wheter to apply GPTQ inplace or perform a deepcopy. Default: True + model (Module): The model to quantize with GPFQ + inplace (bool): Wheter to apply GPFQ inplace or perform a deepcopy. Default: True use_quant_activations (bool): Wheter to leave quantize activations enabled while performing - GPTQ. Default: False + GPFQ. Default: False Example: >>> with torch.no_grad(): - >>> with gptq_mode(model) as gptq: - >>> gptq_model = gptq.model - >>> for i in tqdm(range(gptq.num_layers)): + >>> with gpfq_mode(model) as gpfq: + >>> gpfq_model = gpfq.model + >>> for i in tqdm(range(gpfq.num_layers)): >>> for img, t in calib_loader: >>> img = img.cuda() - >>> gptq_model(img) - >>> gptq.update() + >>> gpfq_model(img) + >>> gpfq.update() """ def __init__(