Skip to content

Commit

Permalink
Multinomial
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 26, 2022
1 parent 54c9ada commit 0fd1944
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.0.20
ghcr.io/pinto0309/onnx2tf:1.0.21
or
Expand Down Expand Up @@ -512,7 +512,7 @@ Please don't post such low level questions as issues.
|Mish|:heavy_check_mark:|
|Mod|:heavy_check_mark:|
|Mul|:heavy_check_mark:|
|Multinomial|**Help wanted**|
|Multinomial|:heavy_check_mark:|
|Neg|:heavy_check_mark:|
|NonMaxSuppression|:heavy_check_mark:|
|NonZero|:heavy_check_mark:|
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.0.20'
__version__ = '1.0.21'
86 changes: 86 additions & 0 deletions onnx2tf/ops/Multinomial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
get_constant_or_variable,
print_node_info,
inverted_operation_enable_disable,
make_tf_node_info,
)
from onnx2tf.utils.enums import ONNX_DTYPES_TO_TF_DTYPES


@print_node_info
@inverted_operation_enable_disable
def make_node(
*,
graph_node: gs.Node,
tf_layers_dict: dict,
**kwargs: dict,
):
"""Multinomial
Parameters
----------
graph_node: gs.Node
graph_surgeon Node
tf_layers_dict: dict
optype, shape, dtype, tensorflow graph
"""
before_op_output_shape_trans_1 = \
tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
before_op_output_shape_trans = \
before_op_output_shape_trans_1

graph_node_input_1 = get_constant_or_variable(
graph_node.inputs[0],
before_op_output_shape_trans,
)
input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1

graph_node_output: gs.Variable = graph_node.outputs[0]
shape = graph_node_output.shape
dtype = graph_node_output.dtype

input_tensor_dtype = ONNX_DTYPES_TO_TF_DTYPES[graph_node.attrs.get('dtype', 6)] # int32
sample_size = graph_node.attrs.get('sample_size', 1)
seed = graph_node.attrs.get('seed', 0)

# Preserving Graph Structure (Dict)
tf_layers_dict[graph_node_output.name] = {
'optype': graph_node.op,
'shape': shape,
'dtype': dtype,
}

# Generation of TF OP
tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.random.categorical(
logits=input_tensor,
num_samples=sample_size,
dtype=input_tensor_dtype,
seed=seed,
name=graph_node.name,
)

# Generation of Debug Info
tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
make_tf_node_info(
node_info={
'tf_op_type': tf.random.categorical,
'tf_inputs': {
'logits': input_tensor,
'num_samples': sample_size,
'dtype': input_tensor_dtype,
'seed': seed,
},
'tf_outputs': {
'output': tf_layers_dict[graph_node_output.name]['tf_node'],
},
}
)

0 comments on commit 0fd1944

Please sign in to comment.