
Commit

fine tuning
PINTO0309 committed Oct 4, 2022
1 parent a0cb1ad commit 17b8f24
Showing 1 changed file with 50 additions and 142 deletions.
192 changes: 50 additions & 142 deletions onnx2tf/ops/Slice.py
@@ -61,17 +61,17 @@ def make_node(
)
starts = tf_layers_dict[starts.name]['tf_node'] \
if isinstance(starts, gs.Variable) else starts
if isinstance(starts, np.ndarray):
starts = tf.constant(starts, dtype=NUMPY_DTYPES_TO_TF_DTYPES[starts.dtype])
# if isinstance(starts, np.ndarray):
# starts = tf.constant(starts, dtype=NUMPY_DTYPES_TO_TF_DTYPES[starts.dtype])

ends = get_constant_or_variable(
graph_node.inputs[2],
before_op_output_shape_trans,
)
ends = tf_layers_dict[ends.name]['tf_node'] \
if isinstance(ends, gs.Variable) else ends
if isinstance(ends, np.ndarray):
ends = tf.constant(ends, dtype=NUMPY_DTYPES_TO_TF_DTYPES[ends.dtype])
# if isinstance(ends, np.ndarray):
# ends = tf.constant(ends, dtype=NUMPY_DTYPES_TO_TF_DTYPES[ends.dtype])

# input_tensor_shape = tf.shape(
# input=input_tensor,
@@ -117,141 +117,49 @@ def make_node(
'dtype': dtype,
}

# # Generation of TF OP
# is_axes_negative = tf.less(
# axes,
# tf.zeros_like(axes),
# )
# axes = tf.where(
# is_axes_negative,
# # axes + tf.cast(tf.rank(input_tensor), axes.dtype),
# axes + tf.cast(input_tensor_rank, axes.dtype),
# axes,
# )

# # expand a dimension of 1 at the end
# sparse_indices = tf.cast(
# tf.expand_dims(axes, -1),
# tf.int64,
# )

# # build the indexed dimension sizes as sparse_shape
# sparse_shape = tf.gather_nd(
# params=input_tensor_shape,
# indices=sparse_indices,
# )
# sparse_shape = tf.cast(
# sparse_shape,
# ends.dtype,
# )

# # take care of starts, ends that are larger than the dim size.
# starts_min = tf.minimum(
# starts,
# sparse_shape,
# )
# ends_min = tf.minimum(
# ends,
# sparse_shape,
# )

# # take care of starts, ends that are negative
# is_starts_negative = tf.less(
# starts_min,
# tf.zeros_like(starts_min),
# )
# starts_final = tf.where(
# is_starts_negative,
# starts_min + sparse_shape,
# starts_min,
# )
# is_ends_negative = tf.less(
# ends_min,
# tf.zeros_like(ends_min),
# )
# ends_final = tf.where(
# is_ends_negative,
# ends_min + sparse_shape,
# ends_min,
# )

# # need to densify everything for the inputs to slice
# # the output shape is the input_tensor rank
# output_shape = tf.reshape(
# input_tensor_rank,
# [1],
# )
# output_shape = tf.cast(
# output_shape,
# tf.int64,
# )

# # create dense tensor, pad 0 as default begins
# def create_sparse_tensor(sparse_indices, starts_final, output_shape):
# return tf.sparse.SparseTensor(sparse_indices, starts_final, output_shape)

# sparse_tensor = Lambda(
# create_sparse_tensor,
# output_shape=output_shape,
# arguments={
# 'starts_final': starts_final,
# 'output_shape': output_shape,

# }
# )(sparse_indices)

# dense_begins = tf.sparse.to_dense(
# tf.sparse.SparseTensor(
# sparse_indices,
# starts_final,
# output_shape,
# )
# )

# # create dense tensor, pad -1 for next step
# dense_ends = tf.sparse.SparseTensor(
# sparse_indices,
# ends_final,
# output_shape,
# )
# dense_ends = tf.sparse.to_dense(
# dense_ends,
# default_value=tf.constant(-1, dtype=dense_begins.dtype)
# )
# dense_ends = tf.where(
# tf.equal(dense_ends, tf.constant(-1, dtype=dense_begins.dtype)),
# input_tensor_shape,
# dense_ends,
# )

# # create dense tensor for steps if not already so
# if len(graph_node.inputs) >= 5:
# dense_steps = tf.sparse.SparseTensor(
# sparse_indices,
# steps,
# output_shape
# )
# dense_steps = tf.sparse.to_dense(
# dense_steps,
# default_value=tf.constant(1, dtype=steps.dtype)
# )
# else:
# dense_steps = tf.ones(tf.shape(input_tensor_shape), ends.dtype)

# tf_layers_dict[graph_node_output.name]['tf_node'] = \
# tf.strided_slice(
# input_=input_tensor,
# begin=dense_begins,
# end=dense_ends,
# strides=dense_steps,
# name=graph_node.name,
# )

tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.strided_slice(
input_=input_tensor,
begin=starts,
end=ends,
strides=steps,
name=graph_node.name,
)
# Generation of TF OP
slice_len = len(starts)

if slice_len == 1:
# Shape output as int64 since the spec implicitly allows int64
full_sizes = np.asarray(input_tensor.shape, dtype=np.int64)

updated_full_sizes = [0] * len(input_tensor.get_shape())
updated_full_begin = [0] * len(input_tensor.get_shape())
updated_starts = [0] * slice_len
updated_ends = [0] * slice_len

for axis in range(input_tensor.shape.rank):
if axis not in axes:
# Update the sizes for axes that are not in the axes attribute
# No need to change the default of 0 in begins
updated_full_sizes[axis] = full_sizes[axis]
else:
# Update the begins and sizes for each axis in the axes attribute
for i in range(slice_len):
if axis == axes[i]:
updated_starts[i] = full_sizes[axis] + starts[i] if starts[i] < 0 else starts[i]
updated_ends[i] = full_sizes[axis] + ends[i] if ends[i] < 0 else ends[i]
if full_sizes[axis] is not None:
updated_ends[i] = min(full_sizes[axis], updated_ends[i])
updated_starts[i] = min(full_sizes[axis], updated_starts[i])

updated_full_begin[axis] = updated_starts[i]
updated_full_sizes[axis] = updated_ends[i] - updated_starts[i]

tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.strided_slice(
input_=input_tensor,
begin=updated_full_begin,
end=updated_full_sizes,
name=graph_node.name,
)
else:
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.strided_slice(
input_=input_tensor,
begin=starts,
end=ends,
strides=steps,
name=graph_node.name,
)
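
For reference, below is a minimal standalone sketch of the normalization the new single-axis branch performs: negative starts/ends are wrapped by the dimension size, clamped to it, and the result is handed to tf.strided_slice. The toy tensor, the slice parameters, and the variable names are invented for illustration, and the sketch passes the clamped exclusive end index as end= (which is what tf.strided_slice expects) rather than a size.

import numpy as np
import tensorflow as tf

# Toy 4x6 input and ONNX-style Slice parameters (illustrative values only).
x = tf.reshape(tf.range(24, dtype=tf.float32), [4, 6])
starts, ends, axes = [-3], [5], [1]

full_sizes = np.asarray(x.shape.as_list(), dtype=np.int64)
begin = [0] * x.shape.rank
end = [int(s) for s in full_sizes]

for i, axis in enumerate(axes):
    # Wrap negative indices, then clamp to the dimension size.
    s = starts[i] + full_sizes[axis] if starts[i] < 0 else starts[i]
    e = ends[i] + full_sizes[axis] if ends[i] < 0 else ends[i]
    begin[axis] = int(min(s, full_sizes[axis]))
    end[axis] = int(min(e, full_sizes[axis]))

y = tf.strided_slice(x, begin=begin, end=end, strides=[1] * x.shape.rank)
print(y.shape)  # (4, 2): all rows, columns 3 and 4 of the 4x6 input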

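The multi-axis fallback hands starts, ends, and steps to tf.strided_slice as-is, which also works when they are plain numpy arrays, since TensorFlow converts these arguments to tensors internally; this is presumably why the explicit tf.constant wrapping in the first hunk could be commented out. A small sketch with invented shapes and values:

import numpy as np
import tensorflow as tf

# Illustrative only: starts/ends/steps given as numpy arrays, no tf.constant wrapping.
x = tf.reshape(tf.range(60, dtype=tf.float32), [3, 4, 5])
starts = np.array([0, 1, 0], dtype=np.int64)
ends = np.array([3, 3, 5], dtype=np.int64)
steps = np.array([1, 1, 2], dtype=np.int64)

# tf.strided_slice accepts numpy arrays for begin/end/strides and
# converts them to tensors itself.
y = tf.strided_slice(x, begin=starts, end=ends, strides=steps)
print(y.shape)  # (3, 2, 3)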