Skip to content

Commit

Permalink
Renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 17, 2023
1 parent 4f2a20b commit e3e9287
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 59 deletions.
34 changes: 17 additions & 17 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,42 +74,42 @@ def quantize_fn(self, x, scale, zero_point, dtype, axis):
return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)


class StdQCDQONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXDecoupledWeightQuantProxyHandler(StdCDQCastONNXMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler(
class StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler(
StdCDQCastONNXMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, ONNXBaseHandler):
pass


class StdQCDQONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXActQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQActQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXBiasQuantProxyHandler(StdDQCastONNXMixin,
QCDQBiasQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
class StdQCDQCastONNXTruncQuantProxyHandler(StdQCDQCastONNXMixin,
QCDQTruncQuantProxyHandlerMixin,
ONNXBaseHandler):
pass


class StdQCDQONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):
class StdQCDQCastONNXQuantLSTMLayerHandler(QuantLSTMLayerHandler):

def quantized_cell_symbolic_execution(
self,
Expand Down
28 changes: 14 additions & 14 deletions src/brevitas/export/onnx/standard/qcdq/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from ..function import IntClipFn
from ..function import QuantizeLinearFn
from ..manager import StdONNXBaseManager
from .handler import StdQCDQONNXActQuantProxyHandler
from .handler import StdQCDQONNXBiasQuantProxyHandler
from .handler import StdQCDQONNXDecoupledWeightQuantProxyHandler
from .handler import StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdQCDQONNXQuantLSTMLayerHandler
from .handler import StdQCDQONNXTruncQuantProxyHandler
from .handler import StdQCDQONNXWeightQuantProxyHandler
from .handler import StdQCDQCastONNXActQuantProxyHandler
from .handler import StdQCDQCastONNXBiasQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantProxyHandler
from .handler import StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler
from .handler import StdQCDQCastONNXQuantLSTMLayerHandler
from .handler import StdQCDQCastONNXTruncQuantProxyHandler
from .handler import StdQCDQCastONNXWeightQuantProxyHandler


class StdQCDQONNXManager(StdONNXBaseManager):
Expand All @@ -33,13 +33,13 @@ class StdQCDQONNXManager(StdONNXBaseManager):
"eliminate_unused_initializer"]

handlers = [
StdQCDQONNXWeightQuantProxyHandler,
StdQCDQONNXBiasQuantProxyHandler,
StdQCDQONNXActQuantProxyHandler,
StdQCDQONNXDecoupledWeightQuantProxyHandler,
StdQCDQONNXTruncQuantProxyHandler,
StdQCDQONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQONNXQuantLSTMLayerHandler]
StdQCDQCastONNXWeightQuantProxyHandler,
StdQCDQCastONNXBiasQuantProxyHandler,
StdQCDQCastONNXActQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantProxyHandler,
StdQCDQCastONNXTruncQuantProxyHandler,
StdQCDQCastONNXDecoupledWeightQuantWithInputProxyHandler,
StdQCDQCastONNXQuantLSTMLayerHandler]

custom_fns = [
DebugMarkerFunction,
Expand Down
32 changes: 16 additions & 16 deletions src/brevitas/export/torch/qcdq/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,27 +93,27 @@ def forward(self, *args, **kwargs):
return self.symbolic_execution(*args, **kwargs)


class TorchQCDQWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchQCDQDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastDecoupledWeightQuantProxyHandler(TorchCDQCastMixin,
QCDQDecoupledWeightQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchQCDQDecoupledWeightQuantWithInputProxyHandler(
class TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler(
TorchCDQCastMixin, QCDQDecoupledWeightQuantWithInputProxyHandlerMixin, TorchQCDQHandler):

@classmethod
Expand All @@ -122,25 +122,25 @@ def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
return _itemize_clip_bounds(clip_args)


class TorchQCDQActQuantProxyHandler(TorchQCDQCastMixin,
QCDQActQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastActQuantProxyHandler(TorchQCDQCastMixin,
QCDQActQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
clip_args = super().int_clip_symbolic_kwargs(narrow, signed, bit_width)
return _itemize_clip_bounds(clip_args)


class TorchQCDQBiasQuantProxyHandler(TorchDQCastMixin,
QCDQBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastBiasQuantProxyHandler(TorchDQCastMixin,
QCDQBiasQuantProxyHandlerMixin,
TorchQCDQHandler):
pass


class TorchQCDQTruncQuantProxyHandler(TorchQCDQCastMixin,
QCDQTruncQuantProxyHandlerMixin,
TorchQCDQHandler):
class TorchQCDQCastTruncQuantProxyHandler(TorchQCDQCastMixin,
QCDQTruncQuantProxyHandlerMixin,
TorchQCDQHandler):

@classmethod
def int_clip_symbolic_kwargs(cls, narrow, signed, bit_width):
Expand Down
24 changes: 12 additions & 12 deletions src/brevitas/export/torch/qcdq/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,24 @@
from brevitas.export.manager import BaseManager
from brevitas.export.manager import ExportContext

from .handler import TorchQCDQActQuantProxyHandler
from .handler import TorchQCDQBiasQuantProxyHandler
from .handler import TorchQCDQDecoupledWeightQuantProxyHandler
from .handler import TorchQCDQDecoupledWeightQuantWithInputProxyHandler
from .handler import TorchQCDQTruncQuantProxyHandler
from .handler import TorchQCDQWeightQuantProxyHandler
from .handler import TorchQCDQCastActQuantProxyHandler
from .handler import TorchQCDQCastBiasQuantProxyHandler
from .handler import TorchQCDQCastDecoupledWeightQuantProxyHandler
from .handler import TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler
from .handler import TorchQCDQCastTruncQuantProxyHandler
from .handler import TorchQCDQCastWeightQuantProxyHandler


class TorchQCDQManager(BaseManager):
target_name = 'torch'

handlers = [
TorchQCDQWeightQuantProxyHandler,
TorchQCDQDecoupledWeightQuantProxyHandler,
TorchQCDQDecoupledWeightQuantWithInputProxyHandler,
TorchQCDQActQuantProxyHandler,
TorchQCDQBiasQuantProxyHandler,
TorchQCDQTruncQuantProxyHandler]
TorchQCDQCastWeightQuantProxyHandler,
TorchQCDQCastDecoupledWeightQuantProxyHandler,
TorchQCDQCastDecoupledWeightQuantWithInputProxyHandler,
TorchQCDQCastActQuantProxyHandler,
TorchQCDQCastBiasQuantProxyHandler,
TorchQCDQCastTruncQuantProxyHandler]

@classmethod
def set_export_mode(cls, model: Module, enabled: bool):
Expand Down

0 comments on commit e3e9287

Please sign in to comment.