diff --git a/source/pytorch.js b/source/pytorch.js index bf668fe9c9..09405c6ad5 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -151,24 +151,8 @@ pytorch.Graph = class { const name = `CONSTANTS.${key}`; if (pytorch.Utility.isTensor(value)) { initializers.set(value, new pytorch.Tensor(name, value)); - } else if (value && value.__class__ && value.__class__.__module__ && value.__class__.__name__) { - const type = `${value.__class__.__module__}.${value.__class__.__name__}`; - switch (type) { - case '__torch__.torch.classes.xnnpack.LinearOpContext': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': { - for (const [key, tensor] of Object.entries(value)) { - if (pytorch.Utility.isTensor(tensor)) { - initializers.set(tensor, new pytorch.Tensor(`${name}.${key}`, tensor)); - } - } - break; - } - default: { - throw new pytorch.Error(`Unsupported constant context '${type}'.`); - } - } + } else if (pytorch.Utility.isObject(value)) { + initializers.set(value, value); } else { throw new pytorch.Error('Unsupported constant.'); } @@ -191,6 +175,11 @@ pytorch.Graph = class { initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter)); } } + } else if (pytorch.Utility.isObject(obj)) { + if (obj.__count__ === undefined || obj.__count__ === 1) { + initializers.set(obj, obj); + } + queue.push(obj); } else if (obj && obj.__class__) { obj.__parent__ = module; obj.__name__ = obj.__name__ || key; @@ -352,7 +341,7 @@ pytorch.Node = class { stack = stack || new Set(); if (obj) { for (const [name, value] of Object.entries(obj)) { - if (name === '__class__') { + if (name === '__class__' || name === '__hide__') { continue; } else if (name === '_parameters' && value instanceof Map) { for (const [name, parameter] of Array.from(value)) { @@ -472,28 +461,16 @@ pytorch.Node = class { let count = 0; for (const input of node.inputs()) { const value = input.value; - const name = value && value.__class__ && value.__class__.__module__ && value.__class__.__name__ ? 
`${value.__class__.__module__}.${value.__class__.__name__}` : ''; let values = []; - switch (name) { - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.LinearOpContext': { - values = Object.values(value); - break; - } - default: { - if (pytorch.Utility.isTensor(value)) { - values = [value]; - } - if (input.node() && - input.node().kind() === 'prim::ListConstruct' && - input.uses().length === 1 && - input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) { - values = input.node().inputs().map((input) => input.value); - } - break; + if (pytorch.Utility.isObject(value)) { + values = Object.values(value); + } else { + if (pytorch.Utility.isTensor(value)) { + values = [value]; + } + if (input.node() && + input.node().kind() === 'prim::ListConstruct' && + input.uses().length === 1 && + input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) { + values = input.node().inputs().map((input) => input.value); } } for (const value of values) { @@ -526,64 +503,47 @@ const inputs = node.inputs(); for (let i = 0; i < inputs.length; i++) { const input = inputs[i]; - const metadata = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null; - const name = metadata && metadata.name ? metadata.name : i.toString(); - const type = metadata && metadata.type ? metadata.type : null; - switch (type) { - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.LinearOpContext': { - for (const [key, value] of Object.entries(input.value)) { - if (key.startsWith('__') && key.endsWith('__')) { - continue; - } - if (pytorch.Utility.isTensor(value)) { - const initializer = initializers.get(value); - const identifier = initializer ? initializer.name : input.unique().toString(); - const argument = new pytorch.Argument(key, [values.map(identifier, null, initializer)]); - this.inputs.push(argument); - } else { - const attribute = createAttribute(null, key, value); - this.attributes.push(attribute); - } - } - break; + const schema = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null; + const name = schema && schema.name ? schema.name : i.toString(); + const type = schema && schema.type ? 
schema.type : null; + let argument = null; + if (pytorch.Utility.isObjectType(type)) { + const obj = input.value; + if (initializers.has(obj)) { + const node = new pytorch.Node(metadata, group, { name, type, obj }, initializers, values); + argument = new pytorch.Argument(name, node, 'object'); + } else { + const identifier = input.unique().toString(); + const value = values.map(identifier); + argument = new pytorch.Argument(name, [value]); } - default: { - if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) { - let list = [input]; - if (input.node() && - input.node().kind() === 'prim::ListConstruct' && - input.uses().length === 1 && - input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) { - list = input.node().inputs(); - } - const args = list.map((input) => { - let initializer = null; - let identifier = input.unique().toString(); - if (input.value) { - const value = input.value; - const hide = value.__parent__ ? value.__parent__.__hide__ : true; - initializer = hide ? initializers.get(value) : null; - identifier = initializer ? initializer.name : identifier; - } - if (initializer) { - return new pytorch.Value(identifier, null, null, initializer); - } - return values.map(identifier); - }); - const argument = new pytorch.Argument(name, args); - this.inputs.push(argument); - } else { - const attribute = createAttribute(metadata, metadata.name, input.value); - // this.attributes.push(attribute); - this.inputs.push(attribute); - } - break; + } else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) { + let list = [input]; + if (input.node() && + input.node().kind() === 'prim::ListConstruct' && + input.uses().length === 1 && + input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) { + list = input.node().inputs(); } + const args = list.map((input) => { + let initializer = null; + let identifier = input.unique().toString(); + if (input.value) { + const value = input.value; + const hide = value.__parent__ ? value.__parent__.__hide__ : true; + initializer = hide ? initializers.get(value) : null; + identifier = initializer ? 
initializer.name : identifier; + } + if (initializer) { + return new pytorch.Value(identifier, null, null, initializer); + } + return values.map(identifier); + }); + argument = new pytorch.Argument(name, args); + } else { + argument = createAttribute(schema, schema.name, input.value); } + this.inputs.push(argument); } const outputs = node.outputs(); for (let i = 0; i < outputs.length; i++) { @@ -1116,6 +1076,7 @@ pytorch.Container.Zip = class extends pytorch.Container { } if (torchscript) { const module = torch.jit.load(reader); + execution.trace = true; if (module.data && module.data.forward) { this._modules = new Map([['', module]]); } else { @@ -2078,393 +2039,400 @@ pytorch.jit.Execution = class extends pytorch.Execution { } call(target, name, args, context) { - const overload = this._overload(target, name, args, context); - if (overload) { - const [schema, args, evalArgs] = overload; - const copyArgs = Array.prototype.slice.call(args); - const copyEvalArgs = Array.prototype.slice.call(evalArgs); - const node = this._graph.create(schema.name); - node.schema = schema; - const referencedParameters = []; - const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); - while (copyEvalArgs.length > 0) { - if (parameters.length <= 0) { - if (schema.name.startsWith('_caffe2::')) { - break; - } - throw new pytorch.Error(); - } - if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && - parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { - const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); - while (copyArgs.length > 0) { - const argument = copyArgs.shift(); - const arg = copyEvalArgs.shift(); - const parameter = map.get(argument.target.value); - if (!parameter) { - throw new pytorch.Error(); - } - if (!pytorch.Utility.isType(arg, parameter.type)) { - if (parameter.optional) { - continue; - } - throw new pytorch.Error(); + if (this.trace) { + const overload = this._overload(target, name, args, context); + if (overload) { + const [schema, args, evalArgs] = overload; + const copyArgs = Array.prototype.slice.call(args); + const copyEvalArgs = Array.prototype.slice.call(evalArgs); + const node = this._graph.create(schema.name); + node.schema = schema; + const referencedParameters = []; + const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || [])); + while (copyEvalArgs.length > 0) { + if (parameters.length <= 0) { + if (schema.name.startsWith('_caffe2::')) { + break; } - const value = this.variable(arg); - value.value = arg; - node.addInput(value); + throw new pytorch.Error(); } - continue; - } - const parameter = parameters.shift(); - const [argument] = copyEvalArgs; - if (parameter.type === 'Tensor' || (parameter.type === 'Scalar' && pytorch.Utility.isTensor(argument))) { - if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) { - if (parameter.optional) { - continue; + if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') && + parameters.every((parameter) => parameter.type !== 'Tensor' && parameter.type !== 'Tensor[]')) { + const map = new Map(parameters.map((parameter) => [parameter.name, parameter])); + while (copyArgs.length > 0) { + const argument = copyArgs.shift(); + const arg = copyEvalArgs.shift(); + const parameter = map.get(argument.target.value); + if 
(!parameter) { + throw new pytorch.Error(); + } + if (!pytorch.Utility.isType(arg, parameter.type)) { + if (parameter.optional) { + continue; + } + throw new pytorch.Error(); + } + const value = this.variable(arg); + value.value = arg; + node.addInput(value); } - throw new pytorch.Error(); - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - const tensor = (argument === null || argument === undefined) ? {} : argument; - const value = this.variable(tensor); - referencedParameters.push(tensor); - node.addInput(value); + continue; } - } else if (parameter.type === 'Tensor[]') { + const parameter = parameters.shift(); const [argument] = copyEvalArgs; - if (!Array.isArray(argument) || !argument.every((item) => pytorch.Utility.isTensor(item) || item === null)) { - if (parameter.optional) { - continue; - } - throw new pytorch.Error(); - } else { - copyArgs.shift(); - copyEvalArgs.shift(); - - const list = this._graph.create('prim::ListConstruct'); - for (const arg of argument) { - const tensor = arg; - if (tensor) { - tensor.__count__ = (tensor.__count__ || 0) + 1; + if (parameter.type === 'Tensor' || (parameter.type === 'Scalar' && pytorch.Utility.isTensor(argument))) { + if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) { + if (parameter.optional) { + continue; } + throw new pytorch.Error(); + } else { + copyArgs.shift(); + copyEvalArgs.shift(); + const tensor = (argument === null || argument === undefined) ? {} : argument; const value = this.variable(tensor); - list.addInput(value); + referencedParameters.push(tensor); + node.addInput(value); } + } else if (parameter.type === 'Tensor[]') { + const [argument] = copyEvalArgs; + if (!Array.isArray(argument) || !argument.every((item) => pytorch.Utility.isTensor(item) || item === null)) { + if (parameter.optional) { + continue; + } + throw new pytorch.Error(); + } else { + copyArgs.shift(); + copyEvalArgs.shift(); + + const list = this._graph.create('prim::ListConstruct'); + for (const arg of argument) { + const tensor = arg; + if (tensor) { + tensor.__count__ = (tensor.__count__ || 0) + 1; + } + const value = this.variable(tensor); + list.addInput(value); + } - const value = list.addOutput(); - node.addInput(value); - } - } else { - const [arg] = copyArgs; - if (!pytorch.Utility.isType(argument, parameter.type) && argument !== null) { - if (parameter.optional) { - continue; + const value = list.addOutput(); + node.addInput(value); } - throw new pytorch.Error(); - } else if (arg.type === '=') { - throw new pytorch.Error('Expected named argument.'); } else { - copyArgs.shift(); - copyEvalArgs.shift(); - switch (parameter.type) { - case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': - case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': - case '__torch__.torch.classes.quantized.LinearPackedParamsBase': - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.LinearOpContext': { - const value = this.variable(argument); - value.value = argument; - node.addInput(value); - for (const [, value] of Object.entries(argument)) { - if (pytorch.Utility.isTensor(value)) { - const tensor = value; - referencedParameters.push(tensor); + const [arg] = copyArgs; + if (!pytorch.Utility.isType(argument, parameter.type) && argument !== null) { + if (parameter.optional) { + continue; + } + throw new pytorch.Error(); + } else if (arg.type === '=') { + throw new pytorch.Error('Expected named argument.'); + } else { + copyArgs.shift(); + 
copyEvalArgs.shift(); + switch (parameter.type) { + case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': + case '__torch__.torch.classes.quantized.LinearPackedParamsBase': + case '__torch__.torch.classes.xnnpack.Conv2dOpContext': + case '__torch__.torch.classes.xnnpack.LinearOpContext': { + const value = this.variable(argument); + value.value = argument; + node.addInput(value); + /* + for (const [, value] of Object.entries(argument)) { + if (pytorch.Utility.isTensor(value)) { + const tensor = value; + referencedParameters.push(tensor); + } } + */ + break; + } + default: { + const value = this.variable(argument); + node.addInput(value); + value.value = argument; + break; } - break; - } - default: { - const value = this.variable(argument); - node.addInput(value); - value.value = argument; - break; } } } } - } - const result = []; - for (let i = 0; i < schema.outputs.length; i++) { - const parameter = schema.outputs[i]; - switch (parameter.type) { - case 'Scalar': - case 'Tensor': { - const output = this.invoke('torch.Tensor', []); - output.__origin__ = schema.name; - if (i === 0) { - switch (schema.name) { - case 'aten::conv1d': - case 'aten::embedding': { - output.resize_([NaN, NaN, NaN]); - break; - } - case 'aten::cat': - case 'aten::conv2d': - case 'aten::dropout': - case 'aten::flatten': - case 'aten::flatten.named_out_dim': - case 'aten::max_pool2d': - case 'aten::adaptive_avg_pool2d': - case 'aten::avg_pool2d': - case 'aten::quantize_per_tensor': - case 'aten::relu_': - case 'aten::prelu': - case 'aten::hardtanh_': - case 'aten::upsample_bilinear2d': - case 'prepacked::conv2d_clamp_run': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && input.size() === undefined) { - input.resize_([NaN, NaN, NaN, NaN]); - } - output.resize_([NaN, NaN, NaN, NaN]); - break; - } - case 'aten::slice': - case 'aten::slice.Tensor': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const size = input.size(); - output.resize_(size); + const result = []; + for (let i = 0; i < schema.outputs.length; i++) { + const parameter = schema.outputs[i]; + switch (parameter.type) { + case 'Scalar': + case 'Tensor': { + const output = this.invoke('torch.Tensor', []); + output.__origin__ = schema.name; + if (i === 0) { + switch (schema.name) { + case 'aten::conv1d': + case 'aten::embedding': { + output.resize_([NaN, NaN, NaN]); + break; } - break; - } - case 'aten::to': - case 'aten::to.device': - case 'aten::to.dtype': - case 'aten::to.dtype_layout': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const size = input.size(); - output.resize_(size); + case 'aten::cat': + case 'aten::conv2d': + case 'aten::dropout': + case 'aten::flatten': + case 'aten::flatten.named_out_dim': + case 'aten::max_pool2d': + case 'aten::adaptive_avg_pool2d': + case 'aten::avg_pool2d': + case 'aten::quantize_per_tensor': + case 'aten::relu_': + case 'aten::prelu': + case 'aten::hardtanh_': + case 'aten::upsample_bilinear2d': + case 'prepacked::conv2d_clamp_run': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && input.size() === undefined) { + input.resize_([NaN, NaN, NaN, NaN]); + } + output.resize_([NaN, NaN, NaN, NaN]); + break; } - break; - } - case 'aten::conv3d': { - output.resize_([NaN, NaN, NaN, NaN, NaN]); - break; - } - case 'aten::roll': - case 'aten::detach': - case 'aten::mean': - case 'aten::mul': - case 
'aten::mul.Scalar': - case 'aten::div': - case 'aten::div.Scalar': - case 'aten::batch_norm': - case 'aten::gelu': - case 'aten::relu': - case 'aten::clamp': - case 'aten::clamp_': - case 'aten::_add_relu_': - case 'aten::hardswish_': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - output.resize_(input.size()); + case 'aten::slice': + case 'aten::slice.Tensor': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const size = input.size(); + output.resize_(size); + } + break; } - break; - } - case 'aten::add': - case 'aten::add.Scalar': - case 'aten::sub': - case 'aten::sub.Scalar': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - output.resize_(input.size()); - } else { - const [, other] = evalArgs; - if (pytorch.Utility.isTensor(other) && Array.isArray(other.size())) { - output.resize_(other.size()); + case 'aten::to': + case 'aten::to.device': + case 'aten::to.dtype': + case 'aten::to.dtype_layout': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const size = input.size(); + output.resize_(size); } + break; } - break; - } - case 'aten::select': - case 'aten::select.int': { - const [input] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - output.resize_(Array(input.size().length - 1).fill(NaN)); + case 'aten::conv3d': { + output.resize_([NaN, NaN, NaN, NaN, NaN]); + break; } - break; - } - case 'aten::layer_norm': { - const [input, normalized_shape] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const shape = input.size(); - if (Array.isArray(normalized_shape) && normalized_shape.length === 1) { - const [value] = normalized_shape; - shape[shape.length - 1] = value; + case 'aten::roll': + case 'aten::detach': + case 'aten::mean': + case 'aten::mul': + case 'aten::mul.Scalar': + case 'aten::div': + case 'aten::div.Scalar': + case 'aten::batch_norm': + case 'aten::gelu': + case 'aten::relu': + case 'aten::clamp': + case 'aten::clamp_': + case 'aten::_add_relu_': + case 'aten::hardswish_': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + output.resize_(input.size()); } - output.resize_(shape); + break; } - break; - } - case 'aten::empty': - case 'aten::ones': - case 'aten::zeros': - case 'aten::zeros_like': { - output.resize_(evalArgs[0]); - break; - } - case 'aten::view': - case 'aten::reshape': - case 'aten::new_full': { - output.resize_(evalArgs[1]); - break; - } - case 'aten::squeeze': - case 'aten::squeeze.dim': { - const [input] = evalArgs; - const size = input.size(); - if (Array.isArray(size)) { - switch (evalArgs.length) { - case 1: { - output.resize_(size.filter((value) => value !== 1)); - break; + case 'aten::add': + case 'aten::add.Scalar': + case 'aten::sub': + case 'aten::sub.Scalar': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + output.resize_(input.size()); + } else { + const [, other] = evalArgs; + if (pytorch.Utility.isTensor(other) && Array.isArray(other.size())) { + output.resize_(other.size()); } - case 2: { - const [, dim] = evalArgs; - output.resize_(size.filter((value, index) => (value !== 1 && !isNaN(value)) || index !== dim)); - break; + } + break; + } + case 'aten::select': + case 'aten::select.int': { + const [input] = evalArgs; + if (pytorch.Utility.isTensor(input) 
&& Array.isArray(input.size())) { + output.resize_(Array(input.size().length - 1).fill(NaN)); + } + break; + } + case 'aten::layer_norm': { + const [input, normalized_shape] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const shape = input.size(); + if (Array.isArray(normalized_shape) && normalized_shape.length === 1) { + const [value] = normalized_shape; + shape[shape.length - 1] = value; } - default: { - break; + output.resize_(shape); + } + break; + } + case 'aten::empty': + case 'aten::ones': + case 'aten::zeros': + case 'aten::zeros_like': { + output.resize_(evalArgs[0]); + break; + } + case 'aten::view': + case 'aten::reshape': + case 'aten::new_full': { + output.resize_(evalArgs[1]); + break; + } + case 'aten::squeeze': + case 'aten::squeeze.dim': { + const [input] = evalArgs; + const size = input.size(); + if (Array.isArray(size)) { + switch (evalArgs.length) { + case 1: { + output.resize_(size.filter((value) => value !== 1)); + break; + } + case 2: { + const [, dim] = evalArgs; + output.resize_(size.filter((value, index) => (value !== 1 && !isNaN(value)) || index !== dim)); + break; + } + default: { + break; + } } } + break; } - break; - } - case 'aten::unsqueeze': { - const [input, dim] = evalArgs; - const size = input.size(); - if (Array.isArray(size) && dim !== undefined) { - const shape = size.slice(); - shape.splice(dim, 0, 1); - output.resize_(shape); - } else { - output.resize_([NaN, NaN, NaN, NaN]); + case 'aten::unsqueeze': { + const [input, dim] = evalArgs; + const size = input.size(); + if (Array.isArray(size) && dim !== undefined) { + const shape = size.slice(); + shape.splice(dim, 0, 1); + output.resize_(shape); + } else { + output.resize_([NaN, NaN, NaN, NaN]); + } + break; } - break; + case 'aten::transpose': + case 'aten::transpose.int': { + const [input, dim0, dim1] = evalArgs; + if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { + const size = input.size().slice(); + const d0 = dim0 >= 0 ? dim0 : size.length + dim0; + const d1 = dim1 >= 0 ? dim1 : size.length + dim1; + const value = size[dim0]; + /* eslint-disable prefer-destructuring */ + size[d0] = size[1]; + /* eslint-enable prefer-destructuring */ + size[d1] = value; + output.resize_(size); + } + break; + } + case 'aten::contiguous': { + const [source] = evalArgs; + output.__source__ = source; + break; + } + case 'quantized::cat': + case 'quantized::cat_relu': + case 'quantized::linear': + case 'quantized::conv2d': + case 'quantized::conv2d.new': + case 'quantized::conv2d_relu': + case 'quantized::conv2d_relu.new': + case 'quantized::add': + case 'quantized::add_relu': + output.resize_([NaN, NaN, NaN, NaN]); + output.__quantized__ = true; + break; + default: + break; } - case 'aten::transpose': - case 'aten::transpose.int': { - const [input, dim0, dim1] = evalArgs; - if (pytorch.Utility.isTensor(input) && Array.isArray(input.size())) { - const size = input.size().slice(); - const d0 = dim0 >= 0 ? dim0 : size.length + dim0; - const d1 = dim1 >= 0 ? 
dim1 : size.length + dim1; - const value = size[dim0]; - /* eslint-disable prefer-destructuring */ - size[d0] = size[1]; - /* eslint-enable prefer-destructuring */ - size[d1] = value; - output.resize_(size); + } + this.variable(output, node); + result.push(output); + break; + } + case 'Tensor[]': { + let count = 1; + switch (schema.name) { + case 'aten::chunk': + count = node.inputs()[1].value; + break; + case 'aten::meshgrid': { + const list = node.inputs()[0].node(); + if (list.kind() === 'prim::ListConstruct') { + count = list.inputs().length; } break; } - case 'aten::contiguous': { - const [source] = evalArgs; - output.__source__ = source; + case 'aten::unbind': + case 'aten::unbind.int': + count = args[0].__tuple__ || count; break; - } - case 'quantized::cat': - case 'quantized::cat_relu': - case 'quantized::linear': - case 'quantized::conv2d': - case 'quantized::conv2d.new': - case 'quantized::conv2d_relu': - case 'quantized::conv2d_relu.new': - case 'quantized::add': - case 'quantized::add_relu': - output.resize_([NaN, NaN, NaN, NaN]); - output.__quantized__ = true; + case 'aten::broadcast_tensors': + case 'aten::split': + case 'aten::split.Tensor': + case 'aten::split_with_sizes': + if (context.target.length > 0) { + count = context.target[context.target.length - 1].length; + } break; default: break; } - } - this.variable(output, node); - result.push(output); - break; - } - case 'Tensor[]': { - let count = 1; - switch (schema.name) { - case 'aten::chunk': - count = node.inputs()[1].value; - break; - case 'aten::meshgrid': { - const list = node.inputs()[0].node(); - if (list.kind() === 'prim::ListConstruct') { - count = list.inputs().length; - } - break; - } - case 'aten::unbind': - case 'aten::unbind.int': - count = args[0].__tuple__ || count; - break; - case 'aten::broadcast_tensors': - case 'aten::split': - case 'aten::split.Tensor': - case 'aten::split_with_sizes': - if (context.target.length > 0) { - count = context.target[context.target.length - 1].length; - } - break; - default: - break; - } - const value = node.addOutput(); - const list = this._graph.create('prim::ListUnpack'); - list.addInput(value); + const value = node.addOutput(); + const list = this._graph.create('prim::ListUnpack'); + list.addInput(value); - const tensors = []; - for (let i = 0; i < count; i ++) { - const tensor = this.invoke('torch.Tensor', []); - tensor.__origin__ = schema.name; - this.variable(tensor, list); - tensors.push(tensor); + const tensors = []; + for (let i = 0; i < count; i ++) { + const tensor = this.invoke('torch.Tensor', []); + tensor.__origin__ = schema.name; + this.variable(tensor, list); + tensors.push(tensor); + } + result.push(tensors); + break; + } + case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': + case '__torch__.torch.classes.quantized.LinearPackedParamsBase': + case '__torch__.torch.classes.xnnpack.Conv2dOpContext': + case '__torch__.torch.classes.xnnpack.LinearOpContext': { + const value = this.invoke(parameter.type, []); + this.variable(value, node); + result.push(value); + break; + } + default: { + const output = this.invoke('torch.Tensor', []); + output.resize_([]); + output.__origin__ = schema.name; + this.variable(output, node); + result.push(output); + break; } - result.push(tensors); - break; - } - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': - case '__torch__.torch.classes.xnnpack.LinearOpContext': { - const value = this.invoke(parameter.type, []); - this.variable(value, 
node); - result.push(value); - break; - } - default: { - const output = this.invoke('torch.Tensor', []); - output.resize_([]); - output.__origin__ = schema.name; - this.variable(output, node); - result.push(output); - break; } } + for (const referencedParameter of referencedParameters) { + referencedParameter.__count__ = (referencedParameter.__count__ || 0) + 1; + } + if (result.length > 1) { + return result; + } + return result[0]; } - for (const referencedParameter of referencedParameters) { - referencedParameter.__count__ = (referencedParameter.__count__ || 0) + 1; - } - if (result.length > 1) { - return result; - } - return result[0]; } return super.call(target, name, args, context); } @@ -2627,8 +2595,11 @@ pytorch.jit.Execution = class extends pytorch.Execution { break; // case 'int64': // break; - case '__torch__.torch.classes.xnnpack.Conv2dOpContext': case '__torch__.torch.classes.xnnpack.LinearOpContext': + case '__torch__.torch.classes.xnnpack.Conv2dOpContext': + case '__torch__.torch.classes.quantized.LinearPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': break; default: { if (!outputTypes || schema.outputs.length !== 1 || schema.outputs[0].type !== outputTypes[0]) { @@ -3404,6 +3375,24 @@ pytorch.Utility = class { } } + static isObjectType(type) { + switch (type) { + case '__torch__.torch.classes.xnnpack.LinearOpContext': + case '__torch__.torch.classes.xnnpack.Conv2dOpContext': + case '__torch__.torch.classes.quantized.LinearPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': + case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase': + return true; + default: + return false; + } + } + + static isObject(obj) { + const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? 
`${obj.__class__.__module__}.${obj.__class__.__name__}` : null; + return pytorch.Utility.isObjectType(type); + } + static getType(value) { if (value === null || value === undefined) { return undefined; diff --git a/source/view.js b/source/view.js index b09477a8d7..13788f39ce 100644 --- a/source/view.js +++ b/source/view.js @@ -2029,6 +2029,23 @@ view.Node = class extends grapher.Node { } attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase())); } + if (Array.isArray(node.inputs)) { + for (const input of node.inputs) { + switch (input.type) { + case 'graph': + case 'object': + case 'object[]': + case 'function': + case 'function[]': { + objects.push(input); + break; + } + default: { + break; + } + } + } + } if (initializers.length > 0 || hiddenInitializers || attributes.length > 0 || objects.length > 0) { const list = this.list(); list.on('click', () => this.context.activate(node)); diff --git a/test/models.json b/test/models.json index a66e8b1f2a..77e173e3ea 100644 --- a/test/models.json +++ b/test/models.json @@ -4930,6 +4930,7 @@ "target": "deeplabv3_scripted.ptl", "source": "https://github.com/lutzroeder/netron/files/9562007/deeplabv3_scripted.ptl.zip[deeplabv3_scripted.ptl]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes[0].inputs[1].value.type.name == '__torch__.torch.classes.xnnpack.Conv2dOpContext'", "link": "https://github.com/lutzroeder/netron/issues/842" }, { @@ -5252,7 +5253,7 @@ "target": "model.ptl", "source": "https://github.com/lutzroeder/netron/files/11149538/model.ptl.zip[model.ptl]", "format": "TorchScript v1.6", - "assert": "model.graphs[0].nodes[0].inputs[1].value[0].initializer.name == 'CONSTANTS.c0.weight'", + "assert": "model.graphs[0].nodes[0].inputs[1].type == 'object'", "link": "https://github.com/lutzroeder/netron/issues/1067" }, { @@ -5465,6 +5466,7 @@ "target": "quant_3d.pt", "source": "https://github.com/lutzroeder/netron/files/5877566/quant_3d.pt.zip[quant_3d.pt]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes[1].inputs[1].value.type.name == '__torch__.torch.classes.quantized.Conv3dPackedParamsBase'", "link": "https://github.com/lutzroeder/netron/issues/546" }, { @@ -5836,6 +5838,7 @@ "target": "test.8bit.pth", "source": "https://github.com/lutzroeder/netron/files/5238524/test.8bit.pth.zip[test.8bit.pth]", "format": "TorchScript v1.6", + "assert": "model.graphs[0].nodes[8].type.name == 'conv1d_prepack'", "link": "https://github.com/lutzroeder/netron/issues/546" }, {
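The heart of this change is replacing the per-call-site `switch` statements over packed-parameter class names (in `pytorch.Graph`, `pytorch.Node`, and `pytorch.jit.Execution`) with the shared `pytorch.Utility.isObjectType`/`isObject` predicates, so supporting an additional packed type (e.g. `Conv3dPackedParamsBase`, which the `quant_3d.pt` assertion now exercises) means extending a single list. A minimal standalone sketch of the predicate pair; `makeObject` is a hypothetical helper for illustration, not part of the diff:

```js
// Qualified class names treated as first-class "object" values (list copied from the diff).
const OBJECT_TYPES = new Set([
    '__torch__.torch.classes.xnnpack.LinearOpContext',
    '__torch__.torch.classes.xnnpack.Conv2dOpContext',
    '__torch__.torch.classes.quantized.LinearPackedParamsBase',
    '__torch__.torch.classes.quantized.Conv2dPackedParamsBase',
    '__torch__.torch.classes.quantized.Conv3dPackedParamsBase'
]);
const isObjectType = (type) => OBJECT_TYPES.has(type);
const isObject = (obj) => {
    // Mirrors pytorch.Utility.isObject: derive the qualified class name, then test it.
    const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ?
        `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
    return isObjectType(type);
};

// Hypothetical stand-in for a deserialized TorchScript object.
const makeObject = (module, name) => ({ __class__: { __module__: module, __name__: name } });

console.log(isObject(makeObject('__torch__.torch.classes.quantized', 'Conv2dPackedParamsBase'))); // true
console.log(isObject({ data: [] })); // false -> falls through to the tensor/attribute paths
```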
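`pytorch.jit.Execution.call` is now gated on `this.trace`, which `pytorch.Container.Zip` only switches on once `torch.jit.load(reader)` has succeeded, so overload resolution no longer materializes graph nodes during ordinary evaluation. A much-simplified sketch of that control flow (assumption: the `Execution` class is reduced here to the shape visible in the diff; the real implementation resolves op schemas and builds nodes in a `torch._C` graph):

```js
// Simplified model of the new trace gate in pytorch.jit.Execution.call.
class Execution {
    constructor() {
        this.trace = false; // off until a module has been loaded
        this.nodes = [];
    }
    call(name, args) {
        if (this.trace) {
            this.nodes.push({ kind: name, inputs: args }); // record a graph node while tracing
            return { __origin__: name };                   // placeholder output, as in the diff
        }
        return null; // outside tracing, defer to the base implementation (super.call upstream)
    }
}

const execution = new Execution();
execution.call('aten::relu', []);    // ignored: not tracing yet
execution.trace = true;              // set right after torch.jit.load in Container.Zip
execution.call('aten::relu', []);    // recorded
console.log(execution.nodes.length); // 1
```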
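One behavioral detail worth spelling out (my reading of the diff, not documented upstream): whether a value is folded into its consuming node as an initializer hinges on the `__count__` field that `call` increments for each referenced tensor; `pytorch.Graph` then treats `__count__ === undefined || __count__ === 1` as "referenced at most once" and inlines the object. Sketched standalone:

```js
// Assumed reference-counting convention, inferred from the diff.
const reference = (value) => { value.__count__ = (value.__count__ || 0) + 1; };
const isSingleUse = (value) => value.__count__ === undefined || value.__count__ === 1;

const shared = {}; // e.g. a tensor consumed by two different nodes
reference(shared);
reference(shared);
const packed = {}; // e.g. a Conv2dOpContext consumed by one prepacked::conv2d_clamp_run
reference(packed);

console.log(isSingleUse(shared)); // false -> stays a named graph edge
console.log(isSingleUse(packed)); // true  -> inlined as an initializer / nested object node
```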