diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 2e0fe12d0..44cc7262e 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -163,8 +163,7 @@
                 'sym': Fp8e4m3DynamicOCPActPerTensorFloat}}}}}}}
 
 
-def quantize_model(
-        model,
+def generate_quantizers(
         dtype,
         weight_bit_width,
         weight_param_method,
@@ -174,7 +173,6 @@ def quantize_model(
         weight_group_size,
         quantize_weight_zero_point,
         weight_quant_format='int',
-        name_blacklist=None,
         input_bit_width=None,
         input_quant_format='',
         input_scale_precision=None,
@@ -184,7 +182,6 @@ def quantize_model(
         input_quant_granularity=None,
         input_group_size=None,
         quantize_input_zero_point=False,
-        quantize_embedding=False,
         use_ocp=False,
         device=None,
         weight_kwargs=None,
@@ -200,20 +197,20 @@ def quantize_model(
         weight_float_format = {
             'exponent_bit_width': int(weight_quant_format[1]),
             'mantissa_bit_width': int(weight_quant_format[3])}
+        ocp_weight_format = weight_quant_format
+        weight_quant_format = 'float'
         if use_ocp:
             weight_quant_format += '_ocp'
-            ocp_weight_format = weight_quant_format
-            weight_quant_format = 'float'
     else:
         weight_float_format = {}
     if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
         input_float_format = {
             'exponent_bit_width': int(input_quant_format[1]),
             'mantissa_bit_width': int(input_quant_format[3])}
+        ocp_input_format = input_quant_format
+        input_quant_format = 'float'
         if use_ocp:
             input_quant_format += '_ocp'
-            ocp_input_format = input_quant_format
-            input_quant_format = 'float'
     else:
         input_float_format = {}
 
@@ -230,15 +227,15 @@ def quantize_model(
             input_scale_type][input_quant_type]
     elif input_bit_width is not None:
         if ocp_input_format:
-            input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][input_scale_type][
+            input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ocp_input_format][
                 input_scale_precision][input_param_method][input_quant_granularity][
                     input_quant_type]
             # Some activations in MHA should always be symmetric
-            sym_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][
-                input_scale_type][input_scale_precision][input_param_method][
+            sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+                ocp_input_format][input_scale_precision][input_param_method][
                     input_quant_granularity]['sym']
-            linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][
-                input_scale_type][input_scale_precision][input_param_method][
+            linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
+                ocp_input_format][input_scale_precision][input_param_method][
                     input_quant_granularity][input_quant_type]
         else:
             input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
@@ -365,6 +362,21 @@ def quantize_model(
             linear_input_quant = linear_input_quant.let(
                 **{
                     'group_dim': -1, 'group_size': input_group_size})
+    return linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant
+
+
+def generate_quant_maps(
+        linear_input_quant,
+        weight_quant,
+        input_quant,
+        q_scaled_quant,
+        k_transposed_quant,
+        v_quant,
+        attn_output_weights_quant,
+        dtype,
+        device,
+        input_quant_format,
+        quantize_embedding):
 
     quant_linear_kwargs = {
         'input_quant': linear_input_quant,
@@ -380,7 +392,7 @@ def quantize_model(
         'in_proj_bias_quant': None,
         'softmax_input_quant': None,
         'attn_output_weights_quant': attn_output_weights_quant,
-        'attn_output_weights_signed': input_quant_format == 'float',
+        'attn_output_weights_signed': 'float' in input_quant_format,
         'q_scaled_quant': q_scaled_quant,
         'k_transposed_quant': k_transposed_quant,
         'v_quant': v_quant,
@@ -406,7 +418,71 @@ def quantize_model(
     if quantize_embedding:
         quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device}
         layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs)
+    return layer_map
+
+
+def quantize_model(
+        model,
+        dtype,
+        weight_bit_width,
+        weight_param_method,
+        weight_scale_precision,
+        weight_quant_type,
+        weight_quant_granularity,
+        weight_group_size,
+        quantize_weight_zero_point,
+        weight_quant_format='int',
+        name_blacklist=None,
+        input_bit_width=None,
+        input_quant_format='',
+        input_scale_precision=None,
+        input_scale_type=None,
+        input_param_method=None,
+        input_quant_type=None,
+        input_quant_granularity=None,
+        input_group_size=None,
+        quantize_input_zero_point=False,
+        quantize_embedding=False,
+        use_ocp=False,
+        device=None,
+        weight_kwargs=None,
+        input_kwargs=None):
+    linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant = generate_quantizers(
+        dtype,
+        weight_bit_width,
+        weight_param_method,
+        weight_scale_precision,
+        weight_quant_type,
+        weight_quant_granularity,
+        weight_group_size,
+        quantize_weight_zero_point,
+        weight_quant_format,
+        input_bit_width,
+        input_quant_format,
+        input_scale_precision,
+        input_scale_type,
+        input_param_method,
+        input_quant_type,
+        input_quant_granularity,
+        input_group_size,
+        quantize_input_zero_point,
+        use_ocp,
+        device,
+        weight_kwargs,
+        input_kwargs)
+    layer_map = generate_quant_maps(
+        linear_input_quant,
+        weight_quant,
+        input_quant,
+        q_scaled_quant,
+        k_transposed_quant,
+        v_quant,
+        attn_output_weights_quant,
+        dtype,
+        device,
+        input_quant_format,
+        quantize_embedding)
 
     model = layerwise_quantize(
         model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist)
     return model
diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md
index ad2dc57a8..e1f98b7da 100644
--- a/src/brevitas_examples/stable_diffusion/README.md
+++ b/src/brevitas_examples/stable_diffusion/README.md
@@ -30,7 +30,7 @@ Activation quantization is optional, and disabled by default. To enable, set bot
 
 We support ONNX integer export, and we are planning to release soon export for floating point quantization (e.g., FP8).
 
-To export the model with fp16 scale factors, enable `export-cuda-float16`. This will performing the tracing necessary for export on GPU, leaving the model in fp16.
-If the flag is not enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16.
+To export the model with fp16 scale factors, leave `export-cpu-float32` disabled (the default). This will perform the tracing necessary for export on GPU, leaving the model in fp16.
+If the flag is enabled, the model is moved to CPU and cast to float32 before export, because fp16 kernels are missing on CPU.
 
 To use MLPerf inference setup, check and install the correct requirements specified in the `requirements.txt` file under mlperf_evaluation.
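
For reference, the behaviour controlled by `export-cpu-float32` corresponds to the sketch below (the helper name is hypothetical; `pipe` is assumed to be a diffusers `StableDiffusionPipeline` whose UNet has already been quantized with Brevitas):

```python
import torch


def prepare_unet_for_export(pipe, export_cpu_float32: bool):
    # Mirrors the export path in stable_diffusion/main.py: with the flag enabled the
    # UNet is moved to CPU and cast to float32 (fp16 kernels are missing on CPU);
    # otherwise it stays on GPU in fp16 and tracing happens there.
    if export_cpu_float32:
        pipe.unet.to('cpu').to(torch.float32)
    pipe.unet.eval()
    device = next(iter(pipe.unet.parameters())).device
    dtype = next(iter(pipe.unet.parameters())).dtype
    return device, dtype
```
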
@@ -70,7 +70,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
                [--gptq | --no-gptq]
                [--bias-correction | --no-bias-correction]
                [--dtype {float32,float16,bfloat16}]
                [--attention-slicing | --no-attention-slicing]
-               [--export-target {,torch,onnx}]
+               [--export-target {,onnx}]
                [--export-weight-q-node | --no-export-weight-q-node]
                [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
                [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
@@ -93,15 +93,11 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
                [--weight-group-size WEIGHT_GROUP_SIZE]
                [--quantize-weight-zero-point | --no-quantize-weight-zero-point]
                [--quantize-input-zero-point | --no-quantize-input-zero-point]
-               [--export-cuda-float16 | --no-export-cuda-float16]
+               [--export-cpu-float32 | --no-export-cpu-float32]
                [--use-mlperf-inference | --no-use-mlperf-inference]
                [--use-ocp | --no-use-ocp]
                [--use-negative-prompts | --no-use-negative-prompts]
                [--dry-run | --no-dry-run]
-               [--quantize-time-emb | --no-quantize-time-emb]
-               [--quantize-conv-in | --no-quantize-conv-in]
-               [--quantize-input-time-emb | --no-quantize-input-time-emb]
-               [--quantize-input-conv-in | --no-quantize-input-conv-in]
 
 Stable Diffusion quantization
 
@@ -160,7 +156,7 @@ options:
   --attention-slicing  Enable Enable attention slicing. Default: Disabled
   --no-attention-slicing
                        Disable Enable attention slicing. Default: Disabled
-  --export-target {,torch,onnx}
+  --export-target {,onnx}
                        Target export flow.
   --export-weight-q-node
                        Enable Enable export of floating point weights + QDQ
@@ -224,10 +220,9 @@ options:
                        Enable Quantize input zero-point. Default: Enabled
   --no-quantize-input-zero-point
                        Disable Quantize input zero-point. Default: Enabled
-  --export-cuda-float16
-                       Enable Export FP16 on CUDA. Default: Disabled
-  --no-export-cuda-float16
-                       Disable Export FP16 on CUDA. Default: Disabled
+  --export-cpu-float32  Enable Export FP32 on CPU. Default: Disabled
+  --no-export-cpu-float32
+                       Disable Export FP32 on CPU. Default: Disabled
   --use-mlperf-inference
                        Enable Evaluate FID score with MLPerf pipeline.
                        Default: False
@@ -248,23 +243,5 @@ options:
                        calibration. Default: Disabled
   --no-dry-run         Disable Generate a quantized model without any
                        calibration. Default: Disabled
-  --quantize-time-emb  Enable Quantize time embedding layers. Default: True
-  --no-quantize-time-emb
-                       Disable Quantize time embedding layers. Default: True
-  --quantize-conv-in   Enable Quantize first conv layer. Default: True
-  --no-quantize-conv-in
-                       Disable Quantize first conv layer. Default: True
-  --quantize-input-time-emb
-                       Enable Quantize input to time embedding layers.
-                       Default: Disabled
-  --no-quantize-input-time-emb
-                       Disable Quantize input to time embedding layers.
-                       Default: Disabled
-  --quantize-input-conv-in
-                       Enable Quantize input to first conv layer. Default:
-                       Enabled
-  --no-quantize-input-conv-in
-                       Disable Quantize input to first conv layer. Default:
-                       Enabled
 ```
diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index 8aa960d0f..af5b9203d 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -22,17 +22,18 @@
 
 from brevitas.core.stats.stats_op import NegativeMinOrZero
 from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
-from brevitas.export.torch.qcdq.manager import TorchQCDQManager
 from brevitas.graph.calibrate import bias_correction_mode
 from brevitas.graph.calibrate import calibration_mode
 from brevitas.graph.calibrate import load_quant_model_mode
 from brevitas.graph.equalize import activation_equalization_mode
 from brevitas.graph.gptq import gptq_mode
-from brevitas.inject.enum import QuantType
+from brevitas.graph.quantize import layerwise_quantize
 from brevitas.inject.enum import StatsOp
 from brevitas.nn.equalized_layer import EqualizedModule
-from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer
+from brevitas.nn.quant_activation import QuantIdentity
 from brevitas.utils.torch_utils import KwargsForwardHook
+from brevitas_examples.common.generative.quantize import generate_quant_maps
+from brevitas_examples.common.generative.quantize import generate_quantizers
 from brevitas_examples.common.generative.quantize import quantize_model
 from brevitas_examples.common.parse_utils import add_bool_arg
 from brevitas_examples.common.parse_utils import quant_format_validator
@@ -41,8 +42,6 @@
 from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE
 from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
 from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx
-from brevitas_examples.stable_diffusion.sd_quant.export import export_torch_export
-from brevitas_examples.stable_diffusion.sd_quant.utils import brevitas_proxy_inference_mode
 from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents
 from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs
 from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs
@@ -141,12 +140,9 @@ def main(args):
     calibration_prompts = CALIBRATION_PROMPTS
     if args.calibration_prompt_path is not None:
         calibration_prompts = load_calib_prompts(args.calibration_prompt_path)
-    prompts = list()
-    for i, v in enumerate(calibration_prompts):
-        if i == args.calibration_prompt:
-            break
-        prompts.append(v)
-    calibration_prompts = prompts
+    print(args.calibration_prompt, len(calibration_prompts))
+    assert args.calibration_prompt <= len(calibration_prompts), f"Only {len(calibration_prompts)} prompts are available"
+    calibration_prompts = calibration_prompts[:args.calibration_prompt]
 
     latents = None
     if args.path_to_latents is not None:
@@ -176,15 +172,11 @@ def main(args):
 
     if args.prompt > 0 and not args.use_mlperf_inference:
         print(f"Running inference with prompt ...")
-        prompts = []
-        for i, v in enumerate(TESTING_PROMPTS):
-            if i == args.prompt:
-                break
-            prompts.append(v)
+        testing_prompts = TESTING_PROMPTS[:args.prompt]
         float_images = run_test_inference(
             pipe,
             args.resolution,
-            prompts,
+            testing_prompts,
             test_seeds,
             output_dir,
             args.device,
@@ -203,9 +195,7 @@ def main(args):
     # Extract list of layers to avoid
     blacklist = []
     for name, _ in pipe.unet.named_modules():
-        if 'time_emb' in name and not args.quantize_time_emb:
-            blacklist.append(name.split('.')[-1])
-        if 'conv_in' in name and not args.quantize_conv_in:
+        if 'time_emb' in name:
             blacklist.append(name.split('.')[-1])
     print(f"Blacklisted layers: {blacklist}")
 
@@ -263,23 +253,12 @@ def input_bit_width(module):
             return args.linear_input_bit_width
         elif isinstance(module, nn.Conv2d):
             return args.conv_input_bit_width
+        elif isinstance(module, QuantIdentity):
+            return args.quant_identity_bit_width
         else:
             raise RuntimeError(f"Module {module} not supported.")
 
     input_kwargs = dict()
-    if args.linear_input_bit_width is None or args.conv_input_bit_width is None:
-
-        @value
-        def input_quant_enabled(module):
-            if args.linear_input_bit_width is None and isinstance(module, nn.Linear):
-                return QuantType.FP
-            elif args.conv_input_bit_width is None and isinstance(module, nn.Conv2d):
-                return QuantType.FP
-            else:
-                return QuantType.INT
-
-        input_kwargs['quant_type'] = input_quant_enabled
-
     if args.input_scale_stats_op == 'minmax':
 
         @value
@@ -303,11 +282,9 @@ def input_zp_stats_type():
             input_kwargs['zero_point_stats_impl'] = input_zp_stats_type
 
     print("Applying model quantization...")
-    quantize_model(
-        pipe.unet,
+    quantizers = generate_quantizers(
         dtype=dtype,
         device=args.device,
-        name_blacklist=blacklist,
         weight_bit_width=weight_bit_width,
         weight_quant_format=args.weight_quant_format,
         weight_quant_type=args.weight_quant_type,
@@ -326,23 +303,30 @@ def input_zp_stats_type():
         input_quant_granularity=args.input_quant_granularity,
         use_ocp=args.use_ocp,
         input_kwargs=input_kwargs)
+
+    layer_map = generate_quant_maps(
+        *quantizers, dtype, args.device, args.input_quant_format, False)
+
+    linear_qkwargs = layer_map[torch.nn.Linear][1]
+    linear_qkwargs[
+        'input_quant'] = None if args.linear_input_bit_width is None else linear_qkwargs[
+            'input_quant']
+    linear_qkwargs[
+        'weight_quant'] = None if args.linear_weight_bit_width == 0 else linear_qkwargs[
+            'weight_quant']
+    layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], linear_qkwargs)
+
+    conv_qkwargs = layer_map[torch.nn.Conv2d][1]
+    conv_qkwargs['input_quant'] = None if args.conv_input_bit_width is None else conv_qkwargs[
+        'input_quant']
+    conv_qkwargs['weight_quant'] = None if args.conv_weight_bit_width == 0 else conv_qkwargs[
+        'weight_quant']
+    layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)
+
+    pipe.unet = layerwise_quantize(
+        model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist)
     print("Model quantization applied.")
-    skipped_layers = []
-    for name, module in pipe.unet.named_modules():
-        if 'time_emb' in name and not args.quantize_input_time_emb:
-            if hasattr(module, 'input_quant'):
-                module.input_quant.quant_injector = module.input_quant.quant_injector.let(
-                    **{'quant_type': QuantType.FP})
-                module.input_quant.init_tensor_quant()
-            skipped_layers.append(name)
-        if 'conv_in' in name and not args.quantize_input_conv_in:
-            if hasattr(module, 'input_quant'):
-                module.input_quant.quant_injector = module.input_quant.quant_injector.let(
-                    **{'quant_type': QuantType.FP})
-                module.input_quant.init_tensor_quant()
-            skipped_layers.append(name)
-    print(f"Skipped input quantization for layers: {skipped_layers}")
 
     pipe.set_progress_bar_config(disable=True)
 
     if args.dry_run:
@@ -427,15 +411,13 @@ def input_zp_stats_type():
             compute_mlperf_fid(args.model, args.path_to_coco, pipe, args.prompt, output_dir)
         else:
             print(f"Computing accuracy on default prompt")
-            prompts = list()
-            for i, v in enumerate(TESTING_PROMPTS):
-                if i == args.prompt:
-                    break
-                prompts.append(v)
+            testing_prompts = TESTING_PROMPTS[:args.prompt]
+            assert args.prompt <= len(TESTING_PROMPTS), f"Only {len(TESTING_PROMPTS)} prompts are available"
+
             quant_images = run_test_inference(
                 pipe,
                 args.resolution,
-                prompts,
+                testing_prompts,
                 test_seeds,
                 output_dir,
                 args.device,
@@ -461,8 +443,8 @@ def input_zp_stats_type():
 
     if args.export_target:
         # Move to cpu and to float32 to enable CPU export
-        if not (dtype == torch.float16 and args.export_cuda_float16):
-            pipe.unet.to('cpu').to(dtype)
+        if args.export_cpu_float32:
+            pipe.unet.to('cpu').to(torch.float32)
         pipe.unet.eval()
         device = next(iter(pipe.unet.parameters())).device
         dtype = next(iter(pipe.unet.parameters())).dtype
@@ -487,13 +469,6 @@ def input_zp_stats_type():
             export_manager = StdQCDQONNXManager
             export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node)
             export_onnx(pipe, trace_inputs, output_dir, export_manager)
-        if args.export_target == 'torch':
-            if args.weight_quant_granularity == 'per_group':
-                export_manager = BlockQuantProxyLevelManager
-            else:
-                export_manager = TorchQCDQManager
-            export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node)
-            export_torch_export(pipe, trace_inputs, output_dir, export_manager)
 
 
 if __name__ == "__main__":
@@ -583,11 +558,7 @@ def input_zp_stats_type():
         default=False,
         help='Enable attention slicing. Default: Disabled')
     parser.add_argument(
-        '--export-target',
-        type=str,
-        default='',
-        choices=['', 'torch', 'onnx'],
-        help='Target export flow.')
+        '--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.')
     add_bool_arg(
         parser,
         'export-weight-q-node',
@@ -708,7 +679,7 @@ def input_zp_stats_type():
         default=False,
         help='Quantize input zero-point. Default: Enabled')
     add_bool_arg(
-        parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA. Default: Disabled')
+        parser, 'export-cpu-float32', default=False, help='Export FP32 on CPU. Default: Disabled')
     add_bool_arg(
         parser,
         'use-mlperf-inference',
@@ -729,23 +700,6 @@ def input_zp_stats_type():
         'dry-run',
         default=False,
         help='Generate a quantized model without any calibration. Default: Disabled')
-    add_bool_arg(
-        parser,
-        'quantize-time-emb',
-        default=True,
-        help='Quantize time embedding layers. Default: True')
-    add_bool_arg(
-        parser, 'quantize-conv-in', default=True, help='Quantize first conv layer. Default: True')
-    add_bool_arg(
-        parser,
-        'quantize-input-time-emb',
-        default=False,
-        help='Quantize input to time embedding layers. Default: Disabled')
-    add_bool_arg(
-        parser,
-        'quantize-input-conv-in',
-        default=True,
-        help='Quantize input to first conv layer. Default: Enabled')
     args = parser.parse_args()
     print("Args: " + str(vars(args)))
     main(args)
diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py
index 7ce70e783..70c9dda75 100644
--- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py
+++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py
@@ -11,27 +11,8 @@
 from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode
 
 
-class UnetExportWrapper(nn.Module):
-
-    def __init__(self, unet):
-        super().__init__()
-        self.unet = unet
-
-    def forward(self, *args, **kwargs):
-        return self.unet(*args, **kwargs, return_dict=False)
-
-
 def export_onnx(pipe, trace_inputs, output_dir, export_manager):
     output_path = os.path.join(output_dir, 'unet.onnx')
     print(f"Saving unet to {output_path} ...")
     with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
         torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path)
-
-
-def export_torch_export(pipe, trace_inputs, output_dir, export_manager):
-    output_path = os.path.join(output_dir, 'unet.onnx')
-    print(trace_inputs[1])
-    print(f"Saving unet to {output_path} ...")
-    with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager):
-        torch.export.export(
-            UnetExportWrapper(pipe.unet), args=(trace_inputs[0],), kwargs=trace_inputs[1])
diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py
index 2700dd032..b2c30176f 100644
--- a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py
+++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py
@@ -3,106 +3,8 @@
 SPDX-License-Identifier: MIT
 """
 
-from contextlib import contextmanager
-
 import torch
 
-from brevitas.export.common.handler.base import BaseHandler
-from brevitas.export.manager import _set_proxy_export_handler
-from brevitas.export.manager import _set_proxy_export_mode
-from brevitas.export.manager import BaseManager
-from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer
-from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector
-
-
-class InferenceWeightProxyHandler(BaseHandler):
-    handled_layer = WeightQuantProxyFromInjector
-
-    def __init__(self):
-        super(InferenceWeightProxyHandler, self).__init__()
-        self.scale = None
-        self.zero_point = None
-        self.bit_width = None
-        self.dtype = None
-        self.float_weight = None
-
-    def scaling_impl(self, proxy_module):
-        return proxy_module.tensor_quant.scaling_impl
-
-    def zero_point_impl(self, proxy_module):
-        return proxy_module.tensor_quant.zero_point_impl
-
-    def bit_width_impl(self, proxy_module):
-        return proxy_module.tensor_quant.msb_clamp_bit_width_impl
-
-    def export_scale(self, proxy_module, bit_width):
-        scaling_impl = self.scaling_impl(proxy_module)
-        int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl
-        int_threshold = int_scaling_impl(bit_width)
-        threshold = scaling_impl.stats_scaling_impl(scaling_impl.parameter_list_stats())
-        return threshold / int_threshold
-
-    def export_zero_point(self, proxy_module, weight, scale, bit_width):
-        zero_point_impl = self.zero_point_impl(proxy_module)
-        return zero_point_impl(weight, scale, bit_width)
-
-    def prepare_for_export(self, module):
-        assert len(module.tracked_module_list) == 1, "Shared quantizers not supported."
-        self.bit_width = self.bit_width_impl(module)()
-        assert self.bit_width <= 8., "Only 8b or lower is supported."
-        quant_layer = module.tracked_module_list[0]
-        self.float_weight = quant_layer.quant_weight()
-        self.dtype = self.float_weight.value.dtype
-        # if (self.float_weight.zero_point != 0.).any():
-        #     self.zero_point = self.export_zero_point(module, quant_layer.weight, self.scale, self.bit_width).detach().cpu()
-        # self.scale = self.export_scale(module, self.bit_width).detach().cpu()
-        # quant_layer.weight.data = quant_layer.weight.data.cpu()
-
-    def forward(self, x):
-
-        return self.float_weight.value, self.float_weight.scale, self.float_weight.zero_point, self.bit_width
-
-
-class InferenceWeightProxyManager(BaseManager):
-    handlers = [InferenceWeightProxyHandler]
-
-    @classmethod
-    def set_export_handler(cls, module):
-        if hasattr(module,
-                   'requires_export_handler') and module.requires_export_handler and not isinstance(
-                       module, (WeightQuantProxyFromInjector)):
-            return
-        _set_proxy_export_handler(cls, module)
-
-
-def store_mapping_tensor_state_dict(model):
-    mapping = dict()
-    for module in model.modules():
-        if isinstance(module, QuantWeightBiasInputOutputLayer):
-            mapping[module.weight.data_ptr()] = module.weight.device
-    return mapping
-
-
-def restore_mapping(model, mapping):
-    for module in model.modules():
-        if isinstance(module, QuantWeightBiasInputOutputLayer):
-            module.weight.data = module.weight.data.to(mapping[module.weight.data_ptr()])
-
-
-@contextmanager
-def brevitas_proxy_inference_mode(model):
-    mapping = store_mapping_tensor_state_dict(model)
-    is_training = model.training
-    model.eval()
-    model.apply(InferenceWeightProxyManager.set_export_handler)
-    _set_proxy_export_mode(model, enabled=True, proxy_class=WeightQuantProxyFromInjector)
-    try:
-        yield model
-    finally:
-        restore_mapping(model, mapping)
-        _set_proxy_export_mode(model, enabled=False)
-        model.train(is_training)
-
 
 def unet_input_shape(resolution):
     return (4, resolution // 8, resolution // 8)
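
The refactor above splits the old `quantize_model` entry point into `generate_quantizers`, `generate_quant_maps`, and a final `layerwise_quantize` call, which is what lets `main.py` edit the layer map before applying it. A minimal sketch of the new flow follows; the argument values are illustrative placeholders, not a recommended configuration:

```python
import torch

from brevitas.graph.quantize import layerwise_quantize
from brevitas_examples.common.generative.quantize import generate_quant_maps
from brevitas_examples.common.generative.quantize import generate_quantizers


def quantize_unet(unet, device='cuda'):
    # Step 1: build the quantizer injectors (weight-only here, no input quantization).
    quantizers = generate_quantizers(
        dtype=torch.float16,
        device=device,
        weight_bit_width=8,
        weight_param_method='stats',
        weight_scale_precision='float_scale',
        weight_quant_type='sym',
        weight_quant_granularity='per_channel',
        weight_group_size=16,
        quantize_weight_zero_point=False)
    # Step 2: turn them into a compute layer map. The unpacking order matches the return of
    # generate_quantizers: linear_input_quant, weight_quant, input_quant, q_scaled_quant,
    # k_transposed_quant, v_quant, attn_output_weights_quant.
    layer_map = generate_quant_maps(
        *quantizers,
        dtype=torch.float16,
        device=device,
        input_quant_format='int',
        quantize_embedding=False)
    # Step 3: apply the (possibly edited) map, exactly as quantize_model does internally.
    return layerwise_quantize(model=unet, compute_layer_map=layer_map, name_blacklist=None)
```
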