From 924464de08864369ce0aec09d80643082a313ffe Mon Sep 17 00:00:00 2001
From: pinto0309
Date: Thu, 12 Sep 2024 08:40:22 +0900
Subject: [PATCH] Improved the conversion stability of BatchNormalization.

---
 README.md                         |   4 +-
 onnx2tf/__init__.py               |   2 +-
 onnx2tf/ops/BatchNormalization.py | 141 ++++++++++++++++--------------
 3 files changed, 77 insertions(+), 70 deletions(-)

diff --git a/README.md b/README.md
index 1fec40e7..4123b849 100644
--- a/README.md
+++ b/README.md
@@ -299,7 +299,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
   docker run --rm -it \
   -v `pwd`:/workdir \
   -w /workdir \
-  ghcr.io/pinto0309/onnx2tf:1.25.10
+  ghcr.io/pinto0309/onnx2tf:1.25.11

   or

@@ -307,7 +307,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
   docker run --rm -it \
   -v `pwd`:/workdir \
   -w /workdir \
-  docker.io/pinto0309/onnx2tf:1.25.10
+  docker.io/pinto0309/onnx2tf:1.25.11

   or

diff --git a/onnx2tf/__init__.py b/onnx2tf/__init__.py
index 3eb301fb..b1fc0d33 100644
--- a/onnx2tf/__init__.py
+++ b/onnx2tf/__init__.py
@@ -1,3 +1,3 @@
 from onnx2tf.onnx2tf import convert, main

-__version__ = '1.25.10'
+__version__ = '1.25.11'
diff --git a/onnx2tf/ops/BatchNormalization.py b/onnx2tf/ops/BatchNormalization.py
index 69d4aed1..dbb96881 100644
--- a/onnx2tf/ops/BatchNormalization.py
+++ b/onnx2tf/ops/BatchNormalization.py
@@ -338,11 +338,17 @@ def make_node(

         # Automatic correction of accuracy degradation
         min_abs_err = sys.maxsize
-        min_abs_err_perm_1: List[int] = [idx for idx in range(len(mean.shape))]
+        min_abs_err_perm_1: List[int] = []
+        check_length = 0
+        if input_tensor.shape is not None and mean.shape is not None and len(input_tensor.shape) >= len(mean.shape):
+            check_length = len(input_tensor.shape)
+        else:
+            check_length = len(mean.shape)
+        min_abs_err_perm_1: List[int] = [idx for idx in range(check_length)]

         if not disable_strict_mode:
             if onnx_tensor_infos is not None and validation_data is not None:
-                tensor_1_candidate_for_transpositions = list(itertools.permutations(range(len(mean.shape))))
+                tensor_1_candidate_for_transpositions = list(itertools.permutations(range(check_length)))
                 # Search for the axis with the smallest error
                 for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
                     try:
@@ -470,71 +476,72 @@ def make_node(
                     except Exception as ex:
                         pass

-        tf_layers_dict[Y.name]['tf_node'] = \
-            tf.nn.batch_normalization(
-                x=input_tensor,
-                mean=\
-                    transpose_with_flexing_deterrence(
-                        input_tensor=mean,
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ) if not isinstance(mean, np.ndarray) else \
-                    transpose_with_flexing_deterrence(
-                        input_tensor=tf.convert_to_tensor(mean),
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ),
-                variance=\
-                    transpose_with_flexing_deterrence(
-                        input_tensor=var,
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ) if not isinstance(var, np.ndarray) else \
-                    transpose_with_flexing_deterrence(
-                        input_tensor=tf.convert_to_tensor(var),
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ),
-                offset=\
-                    transpose_with_flexing_deterrence(
-                        input_tensor=offset,
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ) if not isinstance(offset, np.ndarray) else \
-                    transpose_with_flexing_deterrence(
-                        input_tensor=tf.convert_to_tensor(offset),
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ),
-                scale=\
-                    transpose_with_flexing_deterrence(
-                        input_tensor=scale,
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ) if not isinstance(scale, np.ndarray) else \
-                    transpose_with_flexing_deterrence(
-                        input_tensor=tf.convert_to_tensor(scale),
-                        perm=min_abs_err_perm_1,
-                        output_shape=Y.shape \
-                            if None not in Y.shape and Y.shape != [] else None,
-                        **kwargs,
-                    ),
-                variance_epsilon=epsilon,
-            )
+        if min_abs_err_perm_1 != [idx for idx in range(check_length)]:
+            tf_layers_dict[Y.name]['tf_node'] = \
+                tf.nn.batch_normalization(
+                    x=input_tensor,
+                    mean=\
+                        transpose_with_flexing_deterrence(
+                            input_tensor=mean,
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ) if not isinstance(mean, np.ndarray) else \
+                        transpose_with_flexing_deterrence(
+                            input_tensor=tf.convert_to_tensor(mean),
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ),
+                    variance=\
+                        transpose_with_flexing_deterrence(
+                            input_tensor=var,
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ) if not isinstance(var, np.ndarray) else \
+                        transpose_with_flexing_deterrence(
+                            input_tensor=tf.convert_to_tensor(var),
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ),
+                    offset=\
+                        transpose_with_flexing_deterrence(
+                            input_tensor=offset,
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ) if not isinstance(offset, np.ndarray) else \
+                        transpose_with_flexing_deterrence(
+                            input_tensor=tf.convert_to_tensor(offset),
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ),
+                    scale=\
+                        transpose_with_flexing_deterrence(
+                            input_tensor=scale,
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ) if not isinstance(scale, np.ndarray) else \
+                        transpose_with_flexing_deterrence(
+                            input_tensor=tf.convert_to_tensor(scale),
+                            perm=min_abs_err_perm_1,
+                            output_shape=Y.shape \
+                                if None not in Y.shape and Y.shape != [] else None,
+                            **kwargs,
+                        ),
+                    variance_epsilon=epsilon,
+                )
         tf_type = tf.nn.batch_normalization

     # Post-process transpose