From 7fbe509c3ff429e2fec1b9e6f6f3cf908dcebb1f Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Sun, 6 Oct 2024 00:04:35 +0900 Subject: [PATCH] Fixed to force switch between X and Y when X: np.ndarray, Y: Tensor --- README.md | 4 ++-- onnx2tf/__init__.py | 2 +- onnx2tf/ops/Add.py | 4 ++++ onnx2tf/ops/Mul.py | 4 ++++ 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 85a7c53e..5d5c0c40 100644 --- a/README.md +++ b/README.md @@ -299,7 +299,7 @@ Video speed is adjusted approximately 50 times slower than actual speed. docker run --rm -it \ -v `pwd`:/workdir \ -w /workdir \ - ghcr.io/pinto0309/onnx2tf:1.25.14 + ghcr.io/pinto0309/onnx2tf:1.25.15 or @@ -307,7 +307,7 @@ Video speed is adjusted approximately 50 times slower than actual speed. docker run --rm -it \ -v `pwd`:/workdir \ -w /workdir \ - docker.io/pinto0309/onnx2tf:1.25.14 + docker.io/pinto0309/onnx2tf:1.25.15 or diff --git a/onnx2tf/__init__.py b/onnx2tf/__init__.py index f6db3dee..a365168d 100644 --- a/onnx2tf/__init__.py +++ b/onnx2tf/__init__.py @@ -1,3 +1,3 @@ from onnx2tf.onnx2tf import convert, main -__version__ = '1.25.14' +__version__ = '1.25.15' diff --git a/onnx2tf/ops/Add.py b/onnx2tf/ops/Add.py index a51fd527..8b8ebb19 100644 --- a/onnx2tf/ops/Add.py +++ b/onnx2tf/ops/Add.py @@ -92,6 +92,10 @@ def make_node( input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \ if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2 + # issue: https://github.com/PINTO0309/onnx2tf/issues/698 + if isinstance(input_tensor_1, np.ndarray) and not isinstance(input_tensor_2, np.ndarray): + input_tensor_1, input_tensor_2 = input_tensor_2, input_tensor_1 + disable_strict_mode: bool = kwargs['disable_strict_mode'] gelu_replace_op_names: dict = kwargs['gelu_replace_op_names'] diff --git a/onnx2tf/ops/Mul.py b/onnx2tf/ops/Mul.py index ca9447a7..b168ec05 100644 --- a/onnx2tf/ops/Mul.py +++ b/onnx2tf/ops/Mul.py @@ -84,6 +84,10 @@ def make_node( input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \ if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2 + # issue: https://github.com/PINTO0309/onnx2tf/issues/698 + if isinstance(input_tensor_1, np.ndarray) and not isinstance(input_tensor_2, np.ndarray): + input_tensor_1, input_tensor_2 = input_tensor_2, input_tensor_1 + disable_strict_mode: bool = kwargs['disable_strict_mode'] gelu_replace_op_names: dict = kwargs['gelu_replace_op_names']