diff --git a/src/brevitas/export/onnx/standard/qcdq/handler.py b/src/brevitas/export/onnx/standard/qcdq/handler.py index 87f0184ea..642ae9174 100644 --- a/src/brevitas/export/onnx/standard/qcdq/handler.py +++ b/src/brevitas/export/onnx/standard/qcdq/handler.py @@ -74,42 +74,42 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis): return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis) -class StdQCDQONNXWeightQuantProxyHandler(StdCDQCastONNXMixin, - QCDQWeightQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin, + QCDQWeightQuantProxyHandlerMixin, + ONNXBaseHandler): pass -class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin, + QCDQDecoupledWeightQuantProxyHandlerMixin, + ONNXBaseHandler): pass -class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler( +class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler( StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler): pass -class StdQCDQONNXActQuantProxyHandler(StdQCDQCastONNXMixin, - QCDQActQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin, + QCDQActQuantProxyHandlerMixin, + ONNXBaseHandler): pass -class StdQCDQONNXBiasQuantProxyHandler(StdDQCastONNXMixin, - QCDQBiasQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin, + QCDQBiasQuantProxyHandlerMixin, + ONNXBaseHandler): pass -class StdQCDQONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin, - QCDQTruncQuantProxyHandlerMixin, - ONNXBaseHandler): +class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin, + QCDQTruncQuantProxyHandlerMixin, + ONNXBaseHandler): pass -class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler): +class StdQCDQCastONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler): def quantized_cell_symbolic_execution( self, diff --git a/src/brevitas/export/onnx/standard/qcdq/manager.py b/src/brevitas/export/onnx/standard/qcdq/manager.py index ad2e58dff..ec712672a 100644 --- a/src/brevitas/export/onnx/standard/qcdq/manager.py +++ b/src/brevitas/export/onnx/standard/qcdq/manager.py @@ -14,13 +14,13 @@ from ..function import IntClipFn from ..function import QuantizeLinearFn from ..manager import StdONNXBaseManager -from .handler import StdQCDQONNXActQuantProxyHandler -from .handler import StdQCDQONNXBiasQuantProxyHandler -from .handler import StdQCDQONNXDecoupledWeightQuantProxyHandler -from .handler import StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler -from .handler import StdQCDQONNXQuantLSTMLayerHandler -from .handler import StdQCDQONNXTruncQuantProxyHandler -from .handler import StdQCDQONNXWeightQuantProxyHandler +from .handler import StdQCDQCastONNXActQuantProxyHandler +from .handler import StdQCDQCastONNXBiasQuantProxyHandler +from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler +from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler +from .handler import StdQCDQCastONNXQuantLSTMLayerHandler +from .handler import StdQCDQCastONNXTruncQuantProxyHandler +from .handler import StdQCDQCastONNXWeightQuantProxyHandler class StdQCDQONNXManager(StdONNXBaseManager): @@ -33,13 +33,13 @@ class StdQCDQONNXManager(StdONNXBaseManager): "eliminate_unused_initializer"] handlers = [ - StdQCDQONNXWeightQuantProxyHandler, - StdQCDQONNXBiasQuantProxyHandler, - StdQCDQONNXActQuantProxyHandler, - StdQCDQONNXDecoupledWeightQuantProxyHandler, - StdQCDQONNXTruncQuantProxyHandler, - StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler, - StdQCDQONNXQuantLSTMLayerHandler] + StdQCDQCastONNXWeightQuantProxyHandler, + StdQCDQCastONNXBiasQuantProxyHandler, + StdQCDQCastONNXActQuantProxyHandler, + StdQCDQCastONNXDecoupledWeightQuantProxyHandler, + StdQCDQCastONNXTruncQuantProxyHandler, + StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler, + StdQCDQCastONNXQuantLSTMLayerHandler] custom_fns = [ DebugMarkerFunction, diff --git a/src/brevitas/export/torch/qcdq/handler.py b/src/brevitas/export/torch/qcdq/handler.py index 988daf4a5..b3474a246 100644 --- a/src/brevitas/export/torch/qcdq/handler.py +++ b/src/brevitas/export/torch/qcdq/handler.py @@ -93,9 +93,9 @@ def forward(self, *args, **kwargs): return self.symbolic_execution(*args, **kwargs) -class TorchQCDQWeightQuantProxyHandler(TorchCDQCastMixin, - QCDQWeightQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchQCDQCastWeightQuantProxyHandler(TorchCDQCastMixin, + QCDQWeightQuantProxyHandlerMixin, + TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -103,9 +103,9 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): return _itemize_clip_bounds(clip_args) -class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQCastMixin, - QCDQDecoupledWeightQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin, + QCDQDecoupledWeightQuantProxyHandlerMixin, + TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -113,7 +113,7 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): return _itemize_clip_bounds(clip_args) -class TorchQCDQDecoupledWeightQuantWithInputProxyHandler( +class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler( TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler): @classmethod @@ -122,9 +122,9 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): return _itemize_clip_bounds(clip_args) -class TorchQCDQActQuantProxyHandler(TorchQCDQCastMixin, - QCDQActQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchQCDQCastActQuantProxyHandler(TorchQCDQCastMixin, + QCDQActQuantProxyHandlerMixin, + TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): @@ -132,15 +132,15 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): return _itemize_clip_bounds(clip_args) -class TorchQCDQBiasQuantProxyHandler(TorchDQCastMixin, - QCDQBiasQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin, + QCDQBiasQuantProxyHandlerMixin, + TorchQCDQHandler): pass -class TorchQCDQTruncQuantProxyHandler(TorchQCDQCastMixin, - QCDQTruncQuantProxyHandlerMixin, - TorchQCDQHandler): +class TorchQCDQCastTruncQuantProxyHandler(TorchQCDQCastMixin, + QCDQTruncQuantProxyHandlerMixin, + TorchQCDQHandler): @classmethod def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width): diff --git a/src/brevitas/export/torch/qcdq/manager.py b/src/brevitas/export/torch/qcdq/manager.py index 62f3bedb9..1da072a2d 100644 --- a/src/brevitas/export/torch/qcdq/manager.py +++ b/src/brevitas/export/torch/qcdq/manager.py @@ -11,24 +11,24 @@ from brevitas.export.manager import BaseManager from brevitas.export.manager import ExportContext -from .handler import TorchQCDQActQuantProxyHandler -from .handler import TorchQCDQBiasQuantProxyHandler -from .handler import TorchQCDQDecoupledWeightQuantProxyHandler -from .handler import TorchQCDQDecoupledWeightQuantWithInputProxyHandler -from .handler import TorchQCDQTruncQuantProxyHandler -from .handler import TorchQCDQWeightQuantProxyHandler +from .handler import TorchQCDQCastActQuantProxyHandler +from .handler import TorchQCDQCastBiasQuantProxyHandler +from .handler import TorchQCDQCastDecoupledWeightQuantProxyHandler +from .handler import TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler +from .handler import TorchQCDQCastTruncQuantProxyHandler +from .handler import TorchQCDQCastWeightQuantProxyHandler class TorchQCDQManager(BaseManager): target_name = 'torch' handlers = [ - TorchQCDQWeightQuantProxyHandler, - TorchQCDQDecoupledWeightQuantProxyHandler, - TorchQCDQDecoupledWeightQuantWithInputProxyHandler, - TorchQCDQActQuantProxyHandler, - TorchQCDQBiasQuantProxyHandler, - TorchQCDQTruncQuantProxyHandler] + TorchQCDQCastWeightQuantProxyHandler, + TorchQCDQCastDecoupledWeightQuantProxyHandler, + TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler, + TorchQCDQCastActQuantProxyHandler, + TorchQCDQCastBiasQuantProxyHandler, + TorchQCDQCastTruncQuantProxyHandler] @classmethod def set_export_mode(cls, model: Module, enabled: bool):