Skip to content

Commit

Permalink
Dealing with garbage-like broken structures in ONNX (ArgMin) #695
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Sep 26, 2024
1 parent 69efa5b commit bdbf2f4
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.25.13
ghcr.io/pinto0309/onnx2tf:1.25.14

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.25.13
docker.io/pinto0309/onnx2tf:1.25.14

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.25.13'
__version__ = '1.25.14'
10 changes: 5 additions & 5 deletions onnx2tf/ops/ArgMin.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,15 @@ def make_node(
shape = graph_node_output.shape
dtype = graph_node_output.dtype

# Generation of TF OP
input_tensor = tf_layers_dict[graph_node_input.name]['tf_node'] \
if isinstance(graph_node_input, gs.Variable) else graph_node_input

axis = graph_node.attrs.get('axis', 0)
# NCHW->NHWC, NCDHW->NDHWC
axis = convert_axis(
axis=axis,
tensor_rank=len(graph_node_input.shape),
tensor_rank=len(graph_node_input.shape) if graph_node_input.shape is not None else len(input_tensor.shape),
before_op_output_shape_trans=before_op_output_shape_trans,
)

Expand All @@ -73,10 +77,6 @@ def make_node(
and 'nhwc' in tf_layers_dict[graph_node_input.name].keys() else False
}

# Generation of TF OP
input_tensor = tf_layers_dict[graph_node_input.name]['tf_node'] \
if isinstance(graph_node_input, gs.Variable) else graph_node_input

# Pre-process transpose
input_tensor = pre_process_transpose(
value_before_transpose=input_tensor,
Expand Down

0 comments on commit bdbf2f4

Please sign in to comment.