Skip to content

Commit

Permalink
Update PyTorch script (#637)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 29, 2024
1 parent d24b6c1 commit ab7083a
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 4 deletions.
166 changes: 166 additions & 0 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,142 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::Bool.Tensor",
"inputs": [
{ "name": "a", "type": "Tensor" }
],
"outputs": [
{ "type": "boolean" }
]
},
{
"name": "aten::Bool.float",
"inputs": [
{ "name": "a", "type": "float32" }
],
"outputs": [
{ "type": "boolean" }
]
},
{
"name": "aten::Bool.int",
"inputs": [
{ "name": "a", "type": "int64" }
],
"outputs": [
{ "type": "boolean" }
]
},
{
"name": "aten::Complex.Scalar",
"inputs": [
{ "name": "a", "type": "Scalar" }
],
"outputs": [
{ "type": "complex" }
]
},
{
"name": "aten::Complex.Tensor_Tensor",
"inputs": [
{ "name": "a", "type": "Tensor" },
{ "name": "b", "type": "Tensor" }
],
"outputs": [
{ "type": "complex" }
]
},
{
"name": "aten::Float.Scalar",
"inputs": [
{ "name": "a", "type": "Scalar" }
],
"outputs": [
{ "type": "float32" }
]
},
{
"name": "aten::Float.Tensor",
"inputs": [
{ "name": "a", "type": "Tensor" }
],
"outputs": [
{ "type": "float32" }
]
},
{
"name": "aten::Float.bool",
"inputs": [
{ "name": "a", "type": "boolean" }
],
"outputs": [
{ "type": "float32" }
]
},
{
"name": "aten::Float.int",
"inputs": [
{ "name": "a", "type": "int64" }
],
"outputs": [
{ "type": "float32" }
]
},
{
"name": "aten::Float.str",
"inputs": [
{ "name": "a", "type": "string" }
],
"outputs": [
{ "type": "float32" }
]
},
{
"name": "aten::Int.Scalar",
"inputs": [
{ "name": "a", "type": "Scalar" }
],
"outputs": [
{ "type": "int64" }
]
},
{
"name": "aten::Int.Tensor",
"inputs": [
{ "name": "a", "type": "Tensor" }
],
"outputs": [
{ "type": "int64" }
]
},
{
"name": "aten::Int.bool",
"inputs": [
{ "name": "a", "type": "boolean" }
],
"outputs": [
{ "type": "int64" }
]
},
{
"name": "aten::Int.float",
"inputs": [
{ "name": "a", "type": "float32" }
],
"outputs": [
{ "type": "int64" }
]
},
{
"name": "aten::Int.str",
"inputs": [
{ "name": "a", "type": "string" }
],
"outputs": [
{ "type": "int64" }
]
},
{
"name": "aten::__and__.Scalar",
"inputs": [
Expand Down Expand Up @@ -12674,6 +12810,15 @@
{ "type": "Tensor" }
]
},
{
"name": "aten::str",
"inputs": [
{ "name": "elem", "type": "t" }
],
"outputs": [
{ "type": "string" }
]
},
{
"name": "aten::stride",
"inputs": [
Expand Down Expand Up @@ -14523,6 +14668,27 @@
{ "name": "Y", "type": "Tensor" }
]
},
{
"name": "prim::NumToTensor"
},
{
"name": "prim::NumToTensor.Scalar",
"inputs": [
{ "name": "a", "type": "Scalar" }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "prim::NumToTensor.bool",
"inputs": [
{ "name": "a", "type": "boolean" }
],
"outputs": [
{ "type": "Tensor" }
]
},
{
"name": "quantized::add",
"inputs": [
Expand Down
19 changes: 17 additions & 2 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -2471,9 +2471,9 @@ pytorch.jit.Execution = class extends pytorch.Execution {

_overload(target, name, args, context) {
let moduleName = pytorch.Utility.target(target);
if (moduleName && name) {
if (moduleName) {
let outputTypes = null;
let type = `${moduleName}.${name}`;
let type = name ? `${moduleName}.${name}` : moduleName;
if (type === 'ops.prim.NumToTensor' && args.length === 1 && args[0].type === 'call' && args[0].target.member.type === 'id') {
const [arg] = args;
moduleName = pytorch.Utility.target(arg.target.target);
Expand All @@ -2486,6 +2486,19 @@ pytorch.jit.Execution = class extends pytorch.Execution {
let overloads = null;
if (type.startsWith('torch.')) {
overloads = this._types.get(`aten::${type.substring(6)}`);
/* } else if (type.startsWith('ops.prim.')) {
overloads = this._types.get(`prim::${type.substring(9)}`);
} else if (type === 'int') {
overloads = this._types.get(`aten::Int`);
// "bool": "aten::Bool"
// "int": "aten::Int"
// "float": "aten::Float"
// "complex": "aten::Complex"
// "abs": "prim::abs"
// "max": "prim::max"
// "min": "prim::min"
// "range": "fake::does_not_exist"
*/
} else if (type.startsWith('ops.') && !type.startsWith('ops.prim.')) {
const path = type.split('.');
if (path.length === 3) {
Expand Down Expand Up @@ -2612,6 +2625,8 @@ pytorch.jit.Execution = class extends pytorch.Execution {
case 'Tensor':
case 'Tensor[]':
break;
// case 'int64':
// break;
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext':
break;
Expand Down
7 changes: 5 additions & 2 deletions tools/pytorch_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def _write_metadata(value):
re.compile(r'(aten::.*->\s*.*)"', re.MULTILINE)),
('torch/csrc/jit/runtime/register_special_ops.cpp',
re.compile(r'(aten::.*->\s*.*)"', re.MULTILINE)),
('torch/jit/_shape_functions.py',
re.compile(r'(prim::.*->\s*.*)"', re.MULTILINE))
]

known_schema_definitions = [
Expand All @@ -70,7 +72,7 @@ def _parse_schemas():
definition = entry[2] + value if len(entry) > 2 else value
schema = pytorch.Schema(definition)
if schema.name in schemas:
raise KeyError()
raise KeyError(schema.name)
schemas[schema.name] = schema
for definition in known_schema_definitions:
schema = pytorch.Schema(definition)
Expand Down Expand Up @@ -124,7 +126,8 @@ def _check_types(types, schemas):
for key in list(types.keys()):
if key.startswith('torch.nn'):
types.pop(key)
if key.startswith('torchvision::') or \
if key.startswith('prim::') or \
key.startswith('torchvision::') or \
key.startswith('torchaudio::') or \
key.startswith('neuron::'):
types.pop(key)
Expand Down

0 comments on commit ab7083a

Please sign in to comment.