Skip to content

Commit

Permalink
Improved channel transposition process
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 29, 2022
1 parent c53052a commit 578cdd0
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.0.30
ghcr.io/pinto0309/onnx2tf:1.0.31
or
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.0.30'
__version__ = '1.0.31'
10 changes: 10 additions & 0 deletions onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,16 @@ def channel_transpose(
and list(const_or_var_2.shape).count(1) == len(const_or_var_2.shape)-1 \
and const_or_var_1.shape[-1] == max(const_or_var_2.shape):
const_or_var_2 = const_or_var_2.squeeze()
elif isinstance(const_or_var_2, tf.Tensor) \
and len(const_or_var_2.shape) > 1 \
and list(const_or_var_2.shape).count(1) == len(const_or_var_2.shape)-1 \
and const_or_var_1.shape[-1] == max(const_or_var_2.shape):
const_or_var_2 = tf.squeeze(const_or_var_2)
elif tf.keras.backend.is_keras_tensor(const_or_var_2) \
and len(const_or_var_2.shape) > 1 \
and list(const_or_var_2.shape).count(1) == len(const_or_var_2.shape)-1 \
and const_or_var_1.shape[-1] == max(const_or_var_2.shape):
const_or_var_2 = tf.squeeze(const_or_var_2)
return const_or_var_2


Expand Down

0 comments on commit 578cdd0

Please sign in to comment.