Skip to content

Commit

Permalink
Reshape bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 7, 2022
1 parent 3a765c9 commit 982d908
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions onnx2tf/ops/Reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,23 +90,26 @@ def make_node(
a=input_tensor,
perm=list(perm) if perm is not None else None,
)
perm_shape = [
convert_axis(
axis=idx,
tensor_rank=len(reshape_shape),
before_op_output_shape_trans=before_op_output_shape_trans,
) for idx in range(len(reshape_shape))
]
transposed_reshape_shape = list(reshape_shape)
if before_op_output_shape_trans:
transposed_reshape_shape = [
transposed_reshape_shape[perm_shape_dim] for perm_shape_dim in list(perm_shape)
if isinstance(reshape_shape, np.ndarray):
perm_shape = [
convert_axis(
axis=idx,
tensor_rank=len(reshape_shape),
before_op_output_shape_trans=before_op_output_shape_trans,
) for idx in range(len(reshape_shape))
]
transposed_reshape_shape = list(reshape_shape)
if before_op_output_shape_trans:
transposed_reshape_shape = [
transposed_reshape_shape[perm_shape_dim] for perm_shape_dim in list(perm_shape)
]
else:
transposed_reshape_shape = reshape_shape

# Reshape
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.reshape(
tensor=transposed_tensor,
shape=transposed_reshape_shape,
name=graph_node.name,
)
)

0 comments on commit 982d908

Please sign in to comment.