diff --git a/onnx2tf/ops/Unsqueeze.py b/onnx2tf/ops/Unsqueeze.py index efa0bfc6..8d58a5da 100644 --- a/onnx2tf/ops/Unsqueeze.py +++ b/onnx2tf/ops/Unsqueeze.py @@ -65,14 +65,14 @@ def make_node( axes = [ convert_axis( axis=idx, - tensor_rank=tensor_rank, + tensor_rank=tensor_rank+len(axes), before_op_output_shape_trans=before_op_output_shape_trans, ) for idx in axes ] elif axes is not None and isinstance(axes, np.ndarray) and len(axes.shape) == 0: axes = convert_axis( axis=axes, - tensor_rank=tensor_rank, + tensor_rank=tensor_rank+1, before_op_output_shape_trans=before_op_output_shape_trans, ) axes = list(axes[np.newaxis])