Skip to content

Commit

Permalink
Trilu
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 26, 2022
1 parent 56324bb commit 4baf3f3
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.0.22
ghcr.io/pinto0309/onnx2tf:1.0.23
or
Expand Down Expand Up @@ -588,7 +588,7 @@ Please don't post such low level questions as issues.
|Tile|:heavy_check_mark:|
|TopK|:heavy_check_mark:|
|Transpose|:heavy_check_mark:|
|Trilu|**Help wanted**|
|Trilu|:heavy_check_mark:|
|Unique|**Help wanted**|
|Unsqueeze|:heavy_check_mark:|
|Where|:heavy_check_mark:|
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.0.22'
__version__ = '1.0.23'
132 changes: 132 additions & 0 deletions onnx2tf/ops/Trilu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
get_constant_or_variable,
print_node_info,
inverted_operation_enable_disable,
make_tf_node_info,
)


@print_node_info
@inverted_operation_enable_disable
def make_node(
*,
graph_node: gs.Node,
tf_layers_dict: dict,
**kwargs: dict,
):
"""Trilu
Parameters
----------
graph_node: gs.Node
graph_surgeon Node
tf_layers_dict: dict
optype, shape, dtype, tensorflow graph
"""
before_op_output_shape_trans_1 = \
tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
before_op_output_shape_trans = \
before_op_output_shape_trans_1

graph_node_input_1 = get_constant_or_variable(
graph_node.inputs[0],
before_op_output_shape_trans,
)
input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
tensor_shape = input_tensor.shape

graph_node_input_2 = None
k = None
if len(graph_node.inputs) >= 2:
graph_node_input_2 = get_constant_or_variable(
graph_node.inputs[1],
before_op_output_shape_trans,
)
k = tf_layers_dict[graph_node_input_2.name]['tf_node'] \
if isinstance(graph_node_input_2, gs.Variable) else graph_node_input_2
if k is not None:
if k > tensor_shape[-1]:
k = tensor_shape[-1]
elif k < 0 - tensor_shape[-2]:
k = 0 - tensor_shape[-2]
else:
k = tf.constant(0, dtype=tf.int64)
keep_triangle = tf.constant(-1, dtype=k.dtype)

upper = bool(graph_node.attrs.get('upper', 1))

graph_node_output: gs.Variable = graph_node.outputs[0]
shape = graph_node_output.shape
dtype = graph_node_output.dtype

# Preserving Graph Structure (Dict)
tf_layers_dict[graph_node_output.name] = {
'optype': graph_node.op,
'shape': shape,
'dtype': dtype,
}

# Generation of TF OP
if upper:
if k > 0:
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.subtract(
x=input_tensor,
y=tf.linalg.band_part(
input=input_tensor,
num_lower=keep_triangle,
num_upper=k - 1,
),
name=graph_node.name,
)
else:
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.linalg.band_part(
input=input_tensor,
num_lower=-k,
num_upper=keep_triangle,
name=graph_node.name,
)
else:
if k >= 0:
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.linalg.band_part(
input=input_tensor,
num_lower=keep_triangle,
num_upper=k,
name=graph_node.name,
)
else:
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.subtract(
x=input_tensor,
y=tf.linalg.band_part(
input=input_tensor,
num_lower=-1 - k,
num_upper=keep_triangle,
),
name=graph_node.name,
)

# Generation of Debug Info
tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
make_tf_node_info(
node_info={
'tf_op_type': 'Trilu',
'tf_inputs': {
'input': input_tensor,
'k': k,
},
'tf_outputs': {
'output': tf_layers_dict[graph_node_output.name]['tf_node'],
},
}
)

0 comments on commit 4baf3f3

Please sign in to comment.