Skip to content

Commit

Permalink
Merge pull request #692 from PINTO0309/fix_flatten_no_axis
Browse files Browse the repository at this point in the history
Improved handling when `axis` attribute is not defined and the batch size of the first dimension is undefined.
  • Loading branch information
PINTO0309 authored Sep 15, 2024
2 parents 8d0541a + 4a1570f commit 5e1364f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.25.11
ghcr.io/pinto0309/onnx2tf:1.25.12

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.25.11
docker.io/pinto0309/onnx2tf:1.25.12

or

Expand Down Expand Up @@ -404,7 +404,7 @@ The given SavedModel SignatureDef contains the following input(s):
The given SavedModel SignatureDef contains the following output(s):
outputs['output_0'] tensor_info:
dtype: DT_FLOAT
shape: (1, 1000) # <-- Model design bug in resnet18-v1-7.onnx
shape: (-1, 1000)
name: PartitionedCall:0
Method name is: tensorflow/serving/predict

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.25.11'
__version__ = '1.25.12'
19 changes: 11 additions & 8 deletions onnx2tf/ops/Flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,17 @@ def make_node(
output_shape = graph_node_output.shape
dtype = graph_node_output.dtype

axis = graph_node.attrs.get("axis", 0)
if graph_node_input.shape is not None \
and axis < input_tensor_rank:
axis = convert_axis(
axis=axis,
tensor_rank=len(graph_node_input.shape),
before_op_output_shape_trans=before_op_output_shape_trans,
)
axis = graph_node.attrs.get("axis", None)
if axis is not None:
if graph_node_input.shape is not None \
and axis < input_tensor_rank:
axis = convert_axis(
axis=axis,
tensor_rank=len(graph_node_input.shape),
before_op_output_shape_trans=before_op_output_shape_trans,
)
else:
axis = input_tensor_rank - 1

# Preserving Graph Structure (Dict)
tf_layers_dict[graph_node_output.name] = {
Expand Down

0 comments on commit 5e1364f

Please sign in to comment.