Skip to content

Commit

Permalink
Significant modifications to "Slice" processing
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 26, 2022
1 parent 4baf3f3 commit 85d0456
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 55 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.0.23
ghcr.io/pinto0309/onnx2tf:1.0.24
or
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.0.23'
__version__ = '1.0.24'
151 changes: 98 additions & 53 deletions onnx2tf/ops/Slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def make_node(
)
input_tensor = tf_layers_dict[input_tensor.name]['tf_node'] \
if isinstance(input_tensor, gs.Variable) else input_tensor
input_tensor_shape = input_tensor.shape
input_tensor_rank = len(input_tensor_shape)

starts = None
if len(graph_node.inputs) >= 2:
Expand Down Expand Up @@ -123,61 +125,104 @@ def make_node(
}

# Generation of TF OP
slice_len = len(starts)

if slice_len == 1:
# Shape output as int64 since the spec implicitly allows int64
full_sizes = None
if None not in input_tensor.shape:
full_sizes = np.asarray(input_tensor.shape, dtype=np.int64)
else:
full_sizes = [s if s is not None else -1 for s in input_tensor.shape]

updated_full_sizes = [0] * len(input_tensor.get_shape())
updated_full_begin = [0] * len(input_tensor.get_shape())
updated_starts = [0] * slice_len
updated_ends = [0] * slice_len

if axes is None:
axes = [i for i in range(input_tensor.shape.rank)]

for axis in range(input_tensor.shape.rank):
if axis not in axes:
# Update the sizes for axes that are not in the axes attribute
# No need to change the default of 0 in begins
updated_full_sizes[axis] = full_sizes[axis]
else:
# Update the begins and sizes for each axis in the axes attribute
for i in range(slice_len):
if axis == axes[i]:
updated_starts[i] = full_sizes[axis] + starts[i] if starts[i] < 0 else starts[i]
updated_ends[i] = full_sizes[axis] + ends[i] if ends[i] < 0 else ends[i]
if full_sizes[axis] is not None:
updated_ends[i] = min(full_sizes[axis], updated_ends[i])
updated_starts[i] = min(full_sizes[axis], updated_starts[i])

updated_full_begin[axis] = updated_starts[i]
updated_full_sizes[axis] = updated_ends[i]

starts = updated_full_begin
ends = updated_full_sizes

tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.strided_slice(
input_=input_tensor,
begin=starts,
end=ends,
name=graph_node.name,

# first of all, get the input tensor shape
is_axes_negative = tf.less(axes, tf.zeros_like(axes))
axes = tf.where(is_axes_negative, axes + tf.cast(input_tensor_rank, axes.dtype), axes)

# expand a dimension of 1 at the end
sparse_indices = tf.cast(tf.expand_dims(axes, -1), tf.int64)

# build the indexed dimension sizes as sparse_shape
sparse_shape = tf.gather_nd(params=input_tensor_shape, indices=sparse_indices)
sparse_shape = tf.cast(sparse_shape, tf.int64)

# take care of starts, ends that are larger than the dim size.
starts_min = tf.minimum(starts, sparse_shape)
ends_min = tf.minimum(ends, sparse_shape)

# need to densify everything for the inputs to slice
# the output shape is the input_tensor rank
output_shape = tf.reshape(input_tensor_rank, [1])
output_shape = tf.cast(output_shape, tf.int64)

def tf_SparseTensor(sparse_indices, pos, sparse_shape, start_or_end):
dense = None
# take care of starts, ends that are negative
if start_or_end == 'start':
is_starts_negative = tf.less(pos, tf.zeros_like(pos))
final = tf.where(is_starts_negative, pos + sparse_shape, pos)
sparse_tensor = tf.sparse.reorder(
sp_input=tf.sparse.SparseTensor(
indices=sparse_indices,
values=final,
dense_shape=sparse_shape,
)
)
else:
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.strided_slice(
input_=input_tensor,
begin=starts,
end=ends,
strides=steps,
name=graph_node.name,
dense = tf.sparse.to_dense(sparse_tensor)
elif start_or_end == 'end':
is_ends_negative = tf.less(pos, tf.zeros_like(pos))
final = tf.where(is_ends_negative, pos + sparse_shape, pos)
sparse_tensor = tf.sparse.reorder(
sp_input=tf.sparse.SparseTensor(
indices=sparse_indices,
values=final,
dense_shape=sparse_shape,
)
)
dense_ends = tf.sparse.to_dense(
sparse_tensor,
default_value=tf.constant(-1, dtype=dense_begins.dtype)
)
dense = tf.where(
tf.equal(dense_ends, tf.constant(-1, dtype=dense_begins.dtype)),
input_tensor_shape,
dense_ends,
)
return dense

# create dense tensor, pad 0 as default begins
dense_begins = Lambda(
tf_SparseTensor,
arguments={
'pos': starts_min,
'sparse_shape': output_shape,
'start_or_end': 'start',
}
)(sparse_indices)
# create dense tensor, pad -1 for next step
dense_ends = Lambda(
tf_SparseTensor,
arguments={
'pos': ends_min,
'sparse_shape': output_shape,
'start_or_end': 'end'
}
)(sparse_indices)

# create dense tensor for steps if not already so
if steps is not None:
dense_steps = tf.sparse.reorder(
sp_input=tf.sparse.SparseTensor(
sparse_indices,
steps,
output_shape,
)
)
dense_steps = tf.sparse.to_dense(
dense_steps,
default_value=tf.constant(1, dtype=steps.dtype)
)
else:
dense_steps = tf.ones(input_tensor_shape, ends.dtype)

tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.strided_slice(
input_=input_tensor,
begin=dense_begins,
end=dense_ends,
strides=tf.cast(dense_steps, dtype=dense_begins.dtype),
)

# Generation of Debug Info
tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
Expand Down

0 comments on commit 85d0456

Please sign in to comment.