From 23fb286b22b35f3ba0b66fa67b14ac6a58be9c57 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Thu, 13 Oct 2022 18:38:59 +0900 Subject: [PATCH] [WIP] Parameter replacement. Div,Gemm,Mul,Reshape,Resize,Sub,Transpose --- README.md | 18 ++++++++-- onnx2tf/__init__.py | 2 +- onnx2tf/onnx2tf.py | 4 +-- onnx2tf/ops/Div.py | 33 +++++++++++++----- onnx2tf/ops/Gemm.py | 57 ++++++++++++++++++++++++++++--- onnx2tf/ops/Mul.py | 19 ++++++++++- onnx2tf/ops/Reshape.py | 17 +++++++++ onnx2tf/ops/Resize.py | 49 +++++++++++++++++++++++++- onnx2tf/ops/Sub.py | 19 ++++++++++- onnx2tf/ops/Transpose.py | 16 +++++++-- onnx2tf/utils/common_functions.py | 33 +++++++++++++++++- onnx2tf/utils/enums.py | 20 ++++++++++- 12 files changed, 262 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 9b018cc8..0bb02aef 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ Self-Created Tools to convert ONNX files (NCHW) to TensorFlow format (NHWC). The $ docker run --rm -it \ -v `pwd`:/workdir \ -w /workdir \ -ghcr.io/pinto0309/onnx2tf:0.0.25 +ghcr.io/pinto0309/onnx2tf:0.0.26 or @@ -72,6 +72,7 @@ usage: onnx2tf [-rlr] [-rpw] [-me] +[-prf PARAM_REPLACEMENT_FILE] [-n] optional arguments: @@ -130,6 +131,9 @@ optional arguments: (input_tensor - mean) / tf.sqrt(variance + mvn_epsilon) Default: 0.0000000001 + -prf PARAM_REPLACEMENT_FILE, --param_replacement_file PARAM_REPLACEMENT_FILE + Parameter replacement file path. (.json) + -n, --non_verbose Do not show all information logs. Only error logs are displayed. ``` @@ -155,6 +159,7 @@ convert( replace_leakyrelu_to_pseudo_leakyrelu: Union[bool, NoneType] = False, replace_power_to_pseudo_power: Optional[bool] = False, mvn_epsilon: Union[float, NoneType] = 0.0000000001, + param_replacement_file: Optional[str] = '', non_verbose: Union[bool, NoneType] = False ) -> keras.engine.training.Model @@ -222,6 +227,9 @@ convert( (input_tensor - mean) / tf.sqrt(variance + mvn_epsilon) Default: 0.0000000001 + param_replacement_file: Optional[str] + Parameter replacement file path. (.json) + non_verbose: Optional[bool] Do not show all information logs. Only error logs are displayed. Only one of replace_argmax_to_reducemax_and_indicies_is_int64 and @@ -235,7 +243,7 @@ convert( ``` ## [WIP] Parameter replacement -This tool is used to convert `NCW` to `NWC`, `NCHW` to `NHWC`, `NCDHW` to `NDHWC`, `NCDDHW` to `NDDHWC`, `NCDDDDDDHW` to `NDDDDDDHWC`. Therefore, as stated in the Key Concepts, the conversion will inevitably break down at some point in the model. You need to look at the entire conversion log to see which OP transpositions are failing and correct them yourself. I dare to explain very little because I know that no matter how much detail I put in the README, you guys will not read it at all. +This tool is used to convert `NCW` to `NWC`, `NCHW` to `NHWC`, `NCDHW` to `NDHWC`, `NCDDHW` to `NDDHWC`, `NCDDDDDDHW` to `NDDDDDDHWC`. Therefore, as stated in the Key Concepts, the conversion will inevitably break down at some point in the model. You need to look at the entire conversion log to see which OP transpositions are failing and correct them yourself. I dare to explain very little because I know that no matter how much detail I put in the README, you guys will not read it at all. `attribute` or `INPUT constant` or `INPUT Initializer` can be replaced with the specified value. "A conversion error occurs." Please don't post such low level questions as issues. 
@@ -255,6 +263,12 @@ This tool is used to convert `NCW` to `NWC`, `NCHW` to `NHWC`, `NCDHW` to `NDHWC "param_target": "attributes", # attributes or inputs "param_name": "axes", "values": [2] # Disable parameter transposition or overwrite parameters + }, + { + "op_name": "Resize__697", + "param_target": "inputs", + "param_name": "Concat__696:0", + "values": [26,26] # Replacement of unk__x (Resize OP, sizes height/width parameter) } ] } diff --git a/onnx2tf/__init__.py b/onnx2tf/__init__.py index 70121acc..b94c855e 100644 --- a/onnx2tf/__init__.py +++ b/onnx2tf/__init__.py @@ -1,3 +1,3 @@ from onnx2tf.onnx2tf import convert, main -__version__ = '0.0.25' +__version__ = '0.0.26' diff --git a/onnx2tf/onnx2tf.py b/onnx2tf/onnx2tf.py index 476f1af9..708d30d4 100644 --- a/onnx2tf/onnx2tf.py +++ b/onnx2tf/onnx2tf.py @@ -114,7 +114,7 @@ def convert( Default: 0.0000000001 param_replacement_file: Optional[str] - Parameter replacement path. (.json) + Parameter replacement file path. (.json) non_verbose: Optional[bool] Do not show all information logs. Only error logs are displayed.\n @@ -453,7 +453,7 @@ def main(): '--param_replacement_file', type=str, default='', - help='Parameter replacement path. (.json)' + help='Parameter replacement file path. (.json)' ) parser.add_argument( '-n', diff --git a/onnx2tf/ops/Div.py b/onnx2tf/ops/Div.py index 3e1fc595..5fa250db 100644 --- a/onnx2tf/ops/Div.py +++ b/onnx2tf/ops/Div.py @@ -5,6 +5,8 @@ import tensorflow as tf import onnx_graphsurgeon as gs from onnx2tf.utils.common_functions import ( + get_replacement_parameter, + replace_parameter, get_constant_or_variable, print_node_info, inverted_operation_enable_disable, @@ -14,6 +16,7 @@ @print_node_info @inverted_operation_enable_disable +@get_replacement_parameter def make_node( *, graph_node: gs.Node, @@ -57,6 +60,25 @@ def make_node( 'dtype': dtype, } + input_tensor_1 = tf_layers_dict[graph_node_input_1.name]['tf_node'] \ + if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1 + input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \ + if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2 + + # Param replacement + input_tensor_1 = replace_parameter( + value_before_replacement=input_tensor_1, + param_target='inputs', + param_name=graph_node.inputs[0].name, + **kwargs, + ) + input_tensor_2 = replace_parameter( + value_before_replacement=input_tensor_2, + param_target='inputs', + param_name=graph_node.inputs[1].name, + **kwargs, + ) + # Generation of TF OP target_cast_dtype = [ np.int8, @@ -65,16 +87,9 @@ def make_node( np.int64, ] - input_tensor_1 = tf_layers_dict[graph_node_input_1.name]['tf_node'] \ - if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1 - input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \ - if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2 - divided_tensor = tf.math.divide( - x=tf_layers_dict[graph_node_input_1.name]['tf_node'] \ - if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1, - y=tf_layers_dict[graph_node_input_2.name]['tf_node'] \ - if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2, + x=input_tensor_1, + y=input_tensor_2, name=graph_node.name, ) diff --git a/onnx2tf/ops/Gemm.py b/onnx2tf/ops/Gemm.py index 20ee2ef5..3c8444e9 100644 --- a/onnx2tf/ops/Gemm.py +++ b/onnx2tf/ops/Gemm.py @@ -5,6 +5,8 @@ import tensorflow as tf import onnx_graphsurgeon as gs from onnx2tf.utils.common_functions import ( + get_replacement_parameter, + 
replace_parameter, get_constant_or_variable, print_node_info, inverted_operation_enable_disable, @@ -14,6 +16,7 @@ @print_node_info @inverted_operation_enable_disable +@get_replacement_parameter def make_node( *, graph_node: gs.Node, @@ -71,10 +74,8 @@ def make_node( shape = graph_node_output.shape dtype = graph_node_output.dtype - if bool(graph_node.attrs.get('transA', 0)): - x = tf.transpose(x) - if bool(graph_node.attrs.get('transB', 0)): - y = tf.transpose(y) + transA = bool(graph_node.attrs.get('transA', 0)) + transB = bool(graph_node.attrs.get('transB', 0)) alpha = graph_node.attrs.get('alpha', 1.0) beta = graph_node.attrs.get('beta', 1.0) @@ -85,7 +86,55 @@ def make_node( 'dtype': dtype, } + # Param replacement + x = replace_parameter( + value_before_replacement=x, + param_target='inputs', + param_name=graph_node.inputs[0].name, + **kwargs, + ) + y = replace_parameter( + value_before_replacement=y, + param_target='inputs', + param_name=graph_node.inputs[1].name, + **kwargs, + ) + z = replace_parameter( + value_before_replacement=z, + param_target='inputs', + param_name=graph_node.inputs[2].name, + **kwargs, + ) + alpha = replace_parameter( + value_before_replacement=alpha, + param_target='attributes', + param_name='alpha', + **kwargs, + ) + beta = replace_parameter( + value_before_replacement=beta, + param_target='attributes', + param_name='beta', + **kwargs, + ) + transA = replace_parameter( + value_before_replacement=transA, + param_target='attributes', + param_name='transA', + **kwargs, + ) + transB = replace_parameter( + value_before_replacement=transB, + param_target='attributes', + param_name='transB', + **kwargs, + ) + # Generation of TF OP + if transA == True: + x = tf.transpose(x) + if transB == True: + y = tf.transpose(y) # We cast to either input or attribute data type to preserve precision if input_tensor_x_dtype in [tf.float64]: # cast to input data type diff --git a/onnx2tf/ops/Mul.py b/onnx2tf/ops/Mul.py index d3989149..83fb555a 100644 --- a/onnx2tf/ops/Mul.py +++ b/onnx2tf/ops/Mul.py @@ -5,6 +5,8 @@ import tensorflow as tf import onnx_graphsurgeon as gs from onnx2tf.utils.common_functions import ( + get_replacement_parameter, + replace_parameter, get_constant_or_variable, print_node_info, inverted_operation_enable_disable, @@ -14,6 +16,7 @@ @print_node_info @inverted_operation_enable_disable +@get_replacement_parameter def make_node( *, graph_node: gs.Node, @@ -57,12 +60,26 @@ def make_node( 'dtype': dtype, } - # Generation of TF OP input_tensor_1 = tf_layers_dict[graph_node_input_1.name]['tf_node'] \ if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1 input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \ if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2 + # Param replacement + input_tensor_1 = replace_parameter( + value_before_replacement=input_tensor_1, + param_target='inputs', + param_name=graph_node.inputs[0].name, + **kwargs, + ) + input_tensor_2 = replace_parameter( + value_before_replacement=input_tensor_2, + param_target='inputs', + param_name=graph_node.inputs[1].name, + **kwargs, + ) + + # Generation of TF OP tf_layers_dict[graph_node_output.name]['tf_node'] = \ tf.math.multiply( x=input_tensor_1, diff --git a/onnx2tf/ops/Reshape.py b/onnx2tf/ops/Reshape.py index c41b5dec..5a4d49b7 100644 --- a/onnx2tf/ops/Reshape.py +++ b/onnx2tf/ops/Reshape.py @@ -5,6 +5,8 @@ import tensorflow as tf import onnx_graphsurgeon as gs from onnx2tf.utils.common_functions import ( + get_replacement_parameter, + 
replace_parameter, get_constant_or_variable, convert_axis, print_node_info, @@ -16,6 +18,7 @@ @print_node_info @inverted_operation_enable_disable +@get_replacement_parameter def make_node( *, graph_node: gs.Node, @@ -106,6 +109,20 @@ def make_node( else: transposed_reshape_shape = reshape_shape + # Param replacement + transposed_tensor = replace_parameter( + value_before_replacement=transposed_tensor, + param_target='inputs', + param_name=graph_node.inputs[0].name, + **kwargs, + ) + transposed_reshape_shape = replace_parameter( + value_before_replacement=transposed_reshape_shape, + param_target='inputs', + param_name=graph_node.inputs[1].name, + **kwargs, + ) + # Reshape tf_layers_dict[graph_node_output.name]['tf_node'] = \ tf.reshape( diff --git a/onnx2tf/ops/Resize.py b/onnx2tf/ops/Resize.py index c481f6af..dcbe945a 100644 --- a/onnx2tf/ops/Resize.py +++ b/onnx2tf/ops/Resize.py @@ -6,6 +6,8 @@ from tensorflow.keras.layers import Lambda # type: ignore import onnx_graphsurgeon as gs from onnx2tf.utils.common_functions import ( + get_replacement_parameter, + replace_parameter, get_constant_or_variable, print_node_info, inverted_operation_enable_disable, @@ -15,6 +17,7 @@ @print_node_info @inverted_operation_enable_disable +@get_replacement_parameter def make_node( *, graph_node: gs.Node, @@ -84,7 +87,6 @@ def make_node( } - def upsampling2d_bilinear(input_tensor, new_size, align_corners, half_pixel_centers, name): return tf.compat.v1.image.resize_bilinear( images=input_tensor, @@ -151,6 +153,51 @@ def upsampling2d_nearest(input_tensor, new_size, align_corners, half_pixel_cente new_values[new_idx] = graph_node_output.shape[idx] new_size = new_values[-3:-1] + # Param replacement + input_tensor = replace_parameter( + value_before_replacement=input_tensor, + param_target='inputs', + param_name=graph_node.inputs[0].name, + **kwargs, + ) + roi = replace_parameter( + value_before_replacement=roi, + param_target='inputs', + param_name=graph_node.inputs[1].name, + **kwargs, + ) + scales = replace_parameter( + value_before_replacement=scales, + param_target='inputs', + param_name=graph_node.inputs[2].name, + **kwargs, + ) + new_size = replace_parameter( + value_before_replacement=new_size, + param_target='inputs', + param_name=graph_node.inputs[3].name, + **kwargs, + ) + + coordinate_transformation_mode = replace_parameter( + value_before_replacement=coordinate_transformation_mode, + param_target='attributes', + param_name='coordinate_transformation_mode', + **kwargs, + ) + extrapolation_value = replace_parameter( + value_before_replacement=extrapolation_value, + param_target='attributes', + param_name='extrapolation_value', + **kwargs, + ) + mode = replace_parameter( + value_before_replacement=mode, + param_target='attributes', + param_name='mode', + **kwargs, + ) + # TODO: upsampling2d_bilinear_5d # TODO: upsampling2d_nearest_5d resized_tensor = None diff --git a/onnx2tf/ops/Sub.py b/onnx2tf/ops/Sub.py index 40284c04..0826ee3b 100644 --- a/onnx2tf/ops/Sub.py +++ b/onnx2tf/ops/Sub.py @@ -5,6 +5,8 @@ import tensorflow as tf import onnx_graphsurgeon as gs from onnx2tf.utils.common_functions import ( + get_replacement_parameter, + replace_parameter, get_constant_or_variable, print_node_info, inverted_operation_enable_disable, @@ -14,6 +16,7 @@ @print_node_info @inverted_operation_enable_disable +@get_replacement_parameter def make_node( *, graph_node: gs.Node, @@ -57,12 +60,26 @@ def make_node( 'dtype': dtype, } - # Generation of TF OP input_tensor_1 = tf_layers_dict[graph_node_input_1.name]['tf_node'] 
\ if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1 input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \ if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2 + # Param replacement + input_tensor_1 = replace_parameter( + value_before_replacement=input_tensor_1, + param_target='inputs', + param_name=graph_node.inputs[0].name, + **kwargs, + ) + input_tensor_2 = replace_parameter( + value_before_replacement=input_tensor_2, + param_target='inputs', + param_name=graph_node.inputs[1].name, + **kwargs, + ) + + # Generation of TF OP tf_layers_dict[graph_node_output.name]['tf_node'] = \ tf.math.subtract( x=input_tensor_1, diff --git a/onnx2tf/ops/Transpose.py b/onnx2tf/ops/Transpose.py index f830ae93..9f0d0122 100644 --- a/onnx2tf/ops/Transpose.py +++ b/onnx2tf/ops/Transpose.py @@ -120,9 +120,21 @@ def make_node( 'dtype': dtype, } - # Param replacement perm = list(perm) if perm is not None else None - perm = replace_parameter(perm, 'attributes', 'perm', **kwargs) + + # Param replacement + input_tensor = replace_parameter( + value_before_replacement=input_tensor, + param_target='inputs', + param_name=graph_node.inputs[0].name, + **kwargs, + ) + perm = replace_parameter( + value_before_replacement=perm, + param_target='attributes', + param_name='perm', + **kwargs, + ) # Generation of TF OP tf_layers_dict[graph_node_output.name]['tf_node'] = \ diff --git a/onnx2tf/utils/common_functions.py b/onnx2tf/utils/common_functions.py index e99692d2..83f3c55e 100644 --- a/onnx2tf/utils/common_functions.py +++ b/onnx2tf/utils/common_functions.py @@ -12,6 +12,10 @@ from functools import wraps from collections import namedtuple +from onnx2tf.utils.enums import ( + TF_DTYPES_TO_NUMPY_DTYPES, +) + def get_replacement_parameter(func): @wraps(func) @@ -30,16 +34,43 @@ def get_replacement_parameter_wrapper_func(*args, **kwargs): def replace_parameter( + *, value_before_replacement, param_target, param_name, **kwargs, ): replace_value = value_before_replacement - for op_rep_param in kwargs['op_rep_params']: + op_rep_params = kwargs.get('op_rep_params', []) + for op_rep_param in op_rep_params: if op_rep_param['param_target'] == param_target \ and op_rep_param['param_name'] == param_name: replace_value = op_rep_param.get('values', value_before_replacement) + if isinstance(value_before_replacement, np.ndarray): + replace_value = np.asarray( + replace_value, + dtype=value_before_replacement.dtype, + ) + elif isinstance(value_before_replacement, list): + replace_value = list(replace_value) + elif isinstance(value_before_replacement, bool): + replace_value = \ + bool(replace_value) if isinstance(replace_value, int) and replace_value in [0, 1] else \ + bool(int(replace_value)) if isinstance(replace_value, str) and replace_value in ["0", "1"] else \ + False if isinstance(replace_value, str) and replace_value.lower() == "false" else \ + True if isinstance(replace_value, str) and replace_value.lower() == "true" else \ + replace_value + elif isinstance(value_before_replacement, int): + replace_value = int(replace_value) + elif isinstance(value_before_replacement, float): + replace_value = float(replace_value) + elif isinstance(value_before_replacement, str): + replace_value = str(replace_value) + elif tf.keras.backend.is_keras_tensor(value_before_replacement): + replace_value = np.asarray( + replace_value, + dtype=TF_DTYPES_TO_NUMPY_DTYPES[value_before_replacement.dtype], + ) break return replace_value diff --git a/onnx2tf/utils/enums.py b/onnx2tf/utils/enums.py index
78c6d838..643b80dc 100644 --- a/onnx2tf/utils/enums.py +++ b/onnx2tf/utils/enums.py @@ -48,4 +48,22 @@ np.dtype('int64'): tf.int64, np.dtype('bool_'): tf.bool, -} \ No newline at end of file +} + +TF_DTYPES_TO_NUMPY_DTYPES = { + tf.float16: np.dtype('float16'), + tf.float32: np.dtype('float32'), + tf.float64: np.dtype('float64'), + + tf.uint8: np.dtype('uint8'), + tf.uint16: np.dtype('uint16'), + tf.uint32: np.dtype('uint32'), + tf.uint64: np.dtype('uint64'), + + tf.int8: np.dtype('int8'), + tf.int16: np.dtype('int16'), + tf.int32: np.dtype('int32'), + tf.int64: np.dtype('int64'), + + tf.bool: np.dtype('bool_'), +}
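
The `replace_parameter()` helper added to `common_functions.py` above looks up a rule by `param_target` and `param_name` from the per-op rules supplied via `op_rep_params` and coerces the matched `values` entry to the type of the value it replaces. Below is a minimal, runnable sketch of that lookup-and-coerce behaviour, assuming the rule entries shown in the README example; the helper name `lookup_replacement`, the reduced set of coercion branches, and the sample `sizes` value are illustrative only and not part of the patch.

```python
import numpy as np


def lookup_replacement(value_before_replacement, param_target, param_name, op_rep_params):
    """Return the replacement for (param_target, param_name), coerced to the
    type of the value being replaced, or the original value if no rule matches."""
    replace_value = value_before_replacement
    for op_rep_param in op_rep_params:
        if op_rep_param['param_target'] == param_target \
                and op_rep_param['param_name'] == param_name:
            replace_value = op_rep_param.get('values', value_before_replacement)
            # Coerce to the original value's type, mirroring the patched helper
            if isinstance(value_before_replacement, np.ndarray):
                replace_value = np.asarray(replace_value, dtype=value_before_replacement.dtype)
            elif isinstance(value_before_replacement, list):
                replace_value = list(replace_value)
            elif isinstance(value_before_replacement, int):
                replace_value = int(replace_value)
            break
    return replace_value


# Rules as they might appear for single ops in the replacement JSON
# (op/tensor names are taken from the README example above).
op_rep_params = [
    {'param_target': 'inputs', 'param_name': 'Concat__696:0', 'values': [26, 26]},
    {'param_target': 'attributes', 'param_name': 'axes', 'values': [2]},
]

# e.g. a constant Resize `sizes` input resolved during conversion
original_sizes = np.asarray([13, 13], dtype=np.int64)
print(lookup_replacement(original_sizes, 'inputs', 'Concat__696:0', op_rep_params))
# -> [26 26], keeping the original int64 dtype
```

Because coercion keys off the type of the value being replaced, a plain JSON list such as `[26,26]` arrives as an `int64` ndarray when it overwrites a constant Resize `sizes` input, which is what the `Resize__697` entry in the README example relies on.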