diff --git a/README.md b/README.md
index 68ec933b..04ddeb66 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
 $ docker run --rm -it \
 -v `pwd`:/workdir \
 -w /workdir \
-ghcr.io/pinto0309/onnx2tf:1.0.20
+ghcr.io/pinto0309/onnx2tf:1.0.21
 
 or
 
@@ -512,7 +512,7 @@ Please don't post such low level questions as issues.
   |Mish|:heavy_check_mark:|
   |Mod|:heavy_check_mark:|
   |Mul|:heavy_check_mark:|
-  |Multinomial|**Help wanted**|
+  |Multinomial|:heavy_check_mark:|
   |Neg|:heavy_check_mark:|
   |NonMaxSuppression|:heavy_check_mark:|
   |NonZero|:heavy_check_mark:|
diff --git a/onnx2tf/__init__.py b/onnx2tf/__init__.py
index 2706adb3..e68e1227 100644
--- a/onnx2tf/__init__.py
+++ b/onnx2tf/__init__.py
@@ -1,3 +1,3 @@
 from onnx2tf.onnx2tf import convert, main
 
-__version__ = '1.0.20'
+__version__ = '1.0.21'
diff --git a/onnx2tf/ops/Multinomial.py b/onnx2tf/ops/Multinomial.py
new file mode 100644
index 00000000..715e0e27
--- /dev/null
+++ b/onnx2tf/ops/Multinomial.py
@@ -0,0 +1,86 @@
+import random
+random.seed(0)
+import numpy as np
+np.random.seed(0)
+import tensorflow as tf
+import onnx_graphsurgeon as gs
+from onnx2tf.utils.common_functions import (
+    get_constant_or_variable,
+    print_node_info,
+    inverted_operation_enable_disable,
+    make_tf_node_info,
+)
+from onnx2tf.utils.enums import ONNX_DTYPES_TO_TF_DTYPES
+
+
+@print_node_info
+@inverted_operation_enable_disable
+def make_node(
+    *,
+    graph_node: gs.Node,
+    tf_layers_dict: dict,
+    **kwargs: dict,
+):
+    """Multinomial
+
+    Parameters
+    ----------
+    graph_node: gs.Node
+        graph_surgeon Node
+
+    tf_layers_dict: dict
+        optype, shape, dtype, tensorflow graph
+    """
+    before_op_output_shape_trans_1 = \
+        tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
+    before_op_output_shape_trans = \
+        before_op_output_shape_trans_1
+
+    graph_node_input_1 = get_constant_or_variable(
+        graph_node.inputs[0],
+        before_op_output_shape_trans,
+    )
+    input_tensor = tf_layers_dict[graph_node_input_1.name]['tf_node'] \
+        if isinstance(graph_node_input_1, gs.Variable) else graph_node_input_1
+
+    graph_node_output: gs.Variable = graph_node.outputs[0]
+    shape = graph_node_output.shape
+    dtype = graph_node_output.dtype
+
+    input_tensor_dtype = ONNX_DTYPES_TO_TF_DTYPES[graph_node.attrs.get('dtype', 6)] # int32
+    sample_size = graph_node.attrs.get('sample_size', 1)
+    seed = graph_node.attrs.get('seed', 0)
+
+    # Preserving Graph Structure (Dict)
+    tf_layers_dict[graph_node_output.name] = {
+        'optype': graph_node.op,
+        'shape': shape,
+        'dtype': dtype,
+    }
+
+    # Generation of TF OP
+    tf_layers_dict[graph_node_output.name]['tf_node'] = \
+        tf.random.categorical(
+            logits=input_tensor,
+            num_samples=sample_size,
+            dtype=input_tensor_dtype,
+            seed=seed,
+            name=graph_node.name,
+        )
+
+    # Generation of Debug Info
+    tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
+        make_tf_node_info(
+            node_info={
+                'tf_op_type': tf.random.categorical,
+                'tf_inputs': {
+                    'logits': input_tensor,
+                    'num_samples': sample_size,
+                    'dtype': input_tensor_dtype,
+                    'seed': seed,
+                },
+                'tf_outputs': {
+                    'output': tf_layers_dict[graph_node_output.name]['tf_node'],
+                },
+            }
+        )
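
For reference, a minimal standalone sketch of the `tf.random.categorical` call that the new `Multinomial.py` op is built on. ONNX `Multinomial` takes a `[batch_size, class_size]` tensor of unnormalized log-probabilities, which is passed straight through as the `logits` argument; the concrete logits, `sample_size`, and seed values below are illustrative assumptions, not taken from the diff.

```python
import tensorflow as tf

# Illustrative unnormalized log-probabilities for a batch of 2 distributions
# over 4 classes, matching the [batch_size, class_size] layout of the
# ONNX Multinomial input.
logits = tf.math.log(tf.constant([[0.1, 0.2, 0.3, 0.4],
                                  [0.7, 0.1, 0.1, 0.1]]))

# Rough equivalent of Multinomial(sample_size=5, dtype=6 (int32), seed=0),
# i.e. the call that Multinomial.py emits via tf.random.categorical.
samples = tf.random.categorical(logits=logits, num_samples=5, dtype=tf.int32, seed=0)

print(samples.shape)  # (2, 5): one row of sampled class indices per batch entry
```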