Skip to content

Commit

Permalink
Merge pull request #699 from PINTO0309/fix_mul_ab
Browse files Browse the repository at this point in the history
Fixed to force switch between `X` and `Y` when X: `np.ndarray`, Y: `Tensor`
  • Loading branch information
PINTO0309 authored Oct 5, 2024
2 parents 769874d + 7fbe509 commit dda694f
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.25.14
ghcr.io/pinto0309/onnx2tf:1.25.15

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.25.14
docker.io/pinto0309/onnx2tf:1.25.15

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.25.14'
__version__ = '1.25.15'
4 changes: 4 additions & 0 deletions onnx2tf/ops/Add.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def make_node(
input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2

# issue: https://github.com/PINTO0309/onnx2tf/issues/698
if isinstance(input_tensor_1, np.ndarray) and not isinstance(input_tensor_2, np.ndarray):
input_tensor_1, input_tensor_2 = input_tensor_2, input_tensor_1

disable_strict_mode: bool = kwargs['disable_strict_mode']
gelu_replace_op_names: dict = kwargs['gelu_replace_op_names']

Expand Down
4 changes: 4 additions & 0 deletions onnx2tf/ops/Mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def make_node(
input_tensor_2 = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2

# issue: https://github.com/PINTO0309/onnx2tf/issues/698
if isinstance(input_tensor_1, np.ndarray) and not isinstance(input_tensor_2, np.ndarray):
input_tensor_1, input_tensor_2 = input_tensor_2, input_tensor_1

disable_strict_mode: bool = kwargs['disable_strict_mode']
gelu_replace_op_names: dict = kwargs['gelu_replace_op_names']

Expand Down

0 comments on commit dda694f

Please sign in to comment.