From 2beafa17ac59d8a7b6e282591cfc3bf516e452a4 Mon Sep 17 00:00:00 2001 From: costigt-dev <156176839+costigt-dev@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:25:38 +0000 Subject: [PATCH] removed original_cat --- src/brevitas/__init__.py | 23 ---------------------- src/brevitas/graph/quantize_impl.py | 13 ++++++------ src/brevitas/quant_tensor/torch_handler.py | 2 +- 3 files changed, 7 insertions(+), 31 deletions(-) diff --git a/src/brevitas/__init__.py b/src/brevitas/__init__.py index 0617fd82b..eddc35a02 100644 --- a/src/brevitas/__init__.py +++ b/src/brevitas/__init__.py @@ -23,29 +23,6 @@ else: torch_version = version.parse(torch.__version__) -original_cat = torch.cat -if torch_version < version.parse('1.7.0'): - from torch._overrides import handle_torch_function - from torch._overrides import has_torch_function - - @torch.jit.ignore - def unsupported_jit_cat(tensors, dim): - if not isinstance(tensors, (tuple, list)): - tensors = tuple(tensors) - return unsupported_jit_cat(tensors, dim) - if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): - return handle_torch_function( - original_cat, relevant_args=tensors, tensors=tensors, dim=dim) - else: - return original_cat(tensors=tensors, dim=dim) - - def cat(tensors: List[Tensor], dim: int = 0) -> Tensor: - if not torch.jit.is_scripting(): - return unsupported_jit_cat(tensors, dim) - return original_cat(tensors, dim=dim) - - torch.cat = cat - try: __version__ = get_distribution(__name__).version except DistributionNotFound: diff --git a/src/brevitas/graph/quantize_impl.py b/src/brevitas/graph/quantize_impl.py index 4fc8e5c66..e76ba2dea 100644 --- a/src/brevitas/graph/quantize_impl.py +++ b/src/brevitas/graph/quantize_impl.py @@ -18,8 +18,6 @@ ADD_METHODS = ['add', 'add_'] -CAT = brevitas.original_cat - SIGN_PRESERVING_MODULES = ( nn.Dropout, nn.Dropout2d, @@ -87,7 +85,8 @@ def are_inputs_unsigned(model, node, is_unsigned_list, quant_act_map, unsigned_a else: is_unsigned_list.append(False) elif inp_node.op == 'call_function': - if inp_node.target in [torch.reshape, torch.flatten, torch.transpose, CAT] + ADD_FNS: + if inp_node.target in [torch.reshape, torch.flatten, torch.transpose, torch.cat + ] + ADD_FNS: are_inputs_unsigned( model, inp_node, is_unsigned_list, quant_act_map, unsigned_act_tuple) else: @@ -141,7 +140,7 @@ def are_inputs_quantized_and_aligned(model, node, quantized_modules_list, quant_ if inp_node.target in [torch.reshape, torch.flatten, torch.transpose]: are_inputs_quantized_and_aligned( model, inp_node, quantized_modules_list, quant_act_map, same_sign) - elif inp_node.target is CAT: + elif inp_node.target is torch.cat: are_inputs_quantized_and_aligned( model, inp_node, quantized_modules_list, quant_act_map, True) elif inp_node.target in ADD_FNS: @@ -281,7 +280,7 @@ def recursive_input_handler( quant_identity_map, align_input_quant_fn, align_sign) - elif inp_node.op == 'call_function' and inp_node.target is CAT: + elif inp_node.op == 'call_function' and inp_node.target is torch.cat: recursive_input_handler( model, inp_node, @@ -329,12 +328,12 @@ def residual_handler( def is_converged(model): for node in model.graph.nodes: - if (node.op == 'call_function' and node.target in ADD_FNS + [CAT] or + if (node.op == 'call_function' and node.target in ADD_FNS + [torch.cat] or node.op == 'call_method' and node.target in ADD_METHODS): rewriters = [] # If the op is CAT, check that inputs have same sign, and in recursive_input_handler # force that the sign is aligned - same_sign = node.target is CAT + same_sign = node.target is torch.cat # If input to the CAT or ADD node are quantized and aligned correctly, continue to # the next node diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py index 3b64bca89..860a9e0fa 100644 --- a/src/brevitas/quant_tensor/torch_handler.py +++ b/src/brevitas/quant_tensor/torch_handler.py @@ -44,7 +44,7 @@ def transpose_handler(inp, *args, **kwargs): return inp.transpose(*args, **kwargs) -@implements(brevitas.original_cat) +@implements(torch.cat) def cat_handler(*args, **kwargs): from brevitas.quant_tensor import QuantTensor return QuantTensor.cat(*args, **kwargs)