Update TorchScript test file (#1067)
lutzroeder committed Jul 1, 2024
1 parent d31ea9a commit d4eaec9
Showing 3 changed files with 103 additions and 97 deletions.
178 changes: 82 additions & 96 deletions source/pytorch.js
@@ -151,24 +151,8 @@ pytorch.Graph = class {
const name = `CONSTANTS.${key}`;
if (pytorch.Utility.isTensor(value)) {
initializers.set(value, new pytorch.Tensor(name, value));
} else if (value && value.__class__ && value.__class__.__module__ && value.__class__.__name__) {
const type = `${value.__class__.__module__}.${value.__class__.__name__}`;
switch (type) {
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase': {
for (const [key, tensor] of Object.entries(value)) {
if (pytorch.Utility.isTensor(tensor)) {
initializers.set(tensor, new pytorch.Tensor(`${name}.${key}`, tensor));
}
}
break;
}
default: {
throw new pytorch.Error(`Unsupported constant context '${type}'.`);
}
}
} else if (pytorch.Utility.isObject(value)) {
initializers.set(value, value);
} else {
throw new pytorch.Error('Unsupported constant.');
}
@@ -191,6 +175,11 @@ pytorch.Graph = class {
initializers.set(parameter, new pytorch.Tensor(parameter.name, parameter));
}
}
} else if (pytorch.Utility.isObject(obj)) {
if (obj.__count__ === undefined || obj.__count__ === 1) {
initializers.set(obj, obj);
}
queue.push(obj);
} else if (obj && obj.__class__) {
obj.__parent__ = module;
obj.__name__ = obj.__name__ || key;
@@ -352,7 +341,7 @@ pytorch.Node = class {
stack = stack || new Set();
if (obj) {
for (const [name, value] of Object.entries(obj)) {
if (name === '__class__') {
if (name === '__class__' || name === '__hide__') {
continue;
} else if (name === '_parameters' && value instanceof Map) {
for (const [name, parameter] of Array.from(value)) {
@@ -472,28 +461,16 @@ pytorch.Node = class {
let count = 0;
for (const input of node.inputs()) {
const value = input.value;
const name = value && value.__class__ && value.__class__.__module__ && value.__class__.__name__ ? `${value.__class__.__module__}.${value.__class__.__name__}` : '';
let values = [];
switch (name) {
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext': {
values = Object.values(value);
break;
}
default: {
if (pytorch.Utility.isTensor(value)) {
values = [value];
}
if (input.node() &&
input.node().kind() === 'prim::ListConstruct' &&
input.uses().length === 1 &&
input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
values = input.node().inputs().map((input) => input.value);
}
break;
if (pytorch.Utility.isObject(value)) {
values = Object.values(value);
} else if (pytorch.Utility.isTensor(value)) {
values = [value];
if (input.node() &&
input.node().kind() === 'prim::ListConstruct' &&
input.uses().length === 1 &&
input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
values = input.node().inputs().map((input) => input.value);
}
}
for (const value of values) {
Expand Down Expand Up @@ -526,64 +503,47 @@ pytorch.Node = class {
const inputs = node.inputs();
for (let i = 0; i < inputs.length; i++) {
const input = inputs[i];
const metadata = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null;
const name = metadata && metadata.name ? metadata.name : i.toString();
const type = metadata && metadata.type ? metadata.type : null;
switch (type) {
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext': {
for (const [key, value] of Object.entries(input.value)) {
if (key.startsWith('__') && key.endsWith('__')) {
continue;
}
if (pytorch.Utility.isTensor(value)) {
const initializer = initializers.get(value);
const identifier = initializer ? initializer.name : input.unique().toString();
const argument = new pytorch.Argument(key, [values.map(identifier, null, initializer)]);
this.inputs.push(argument);
} else {
const attribute = createAttribute(null, key, value);
this.attributes.push(attribute);
}
}
break;
const schema = this.type && this.type.inputs && i < this.type.inputs.length ? this.type.inputs[i] : null;
const name = schema && schema.name ? schema.name : i.toString();
const type = schema && schema.type ? schema.type : null;
let argument = null;
if (pytorch.Utility.isObjectType(type)) {
const obj = input.value;
if (initializers.has(obj)) {
const node = new pytorch.Node(metadata, group, { name, type, obj }, initializers, values);
argument = new pytorch.Argument(name, node, 'object');
} else {
const identifier = input.unique().toString();
const value = values.map(identifier);
argument = new pytorch.Argument(name, [value]);
}
default: {
if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) {
let list = [input];
if (input.node() &&
input.node().kind() === 'prim::ListConstruct' &&
input.uses().length === 1 &&
input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
list = input.node().inputs();
}
const args = list.map((input) => {
let initializer = null;
let identifier = input.unique().toString();
if (input.value) {
const value = input.value;
const hide = value.__parent__ ? value.__parent__.__hide__ : true;
initializer = hide ? initializers.get(value) : null;
identifier = initializer ? initializer.name : identifier;
}
if (initializer) {
return new pytorch.Value(identifier, null, null, initializer);
}
return values.map(identifier);
});
const argument = new pytorch.Argument(name, args);
this.inputs.push(argument);
} else {
const attribute = createAttribute(metadata, metadata.name, input.value);
// this.attributes.push(attribute);
this.inputs.push(attribute);
}
break;
} else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) {
let list = [input];
if (input.node() &&
input.node().kind() === 'prim::ListConstruct' &&
input.uses().length === 1 &&
input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
list = input.node().inputs();
}
const args = list.map((input) => {
let initializer = null;
let identifier = input.unique().toString();
if (input.value) {
const value = input.value;
const hide = value.__parent__ ? value.__parent__.__hide__ : true;
initializer = hide ? initializers.get(value) : null;
identifier = initializer ? initializer.name : identifier;
}
if (initializer) {
return new pytorch.Value(identifier, null, null, initializer);
}
return values.map(identifier);
});
argument = new pytorch.Argument(name, args);
} else {
argument = createAttribute(schema, schema.name, input.value);
}
this.inputs.push(argument);
}
const outputs = node.outputs();
for (let i = 0; i < outputs.length; i++) {
@@ -2177,12 +2137,14 @@ pytorch.jit.Execution = class extends pytorch.Execution {
const value = this.variable(argument);
value.value = argument;
node.addInput(value);
/*
for (const [, value] of Object.entries(argument)) {
if (pytorch.Utility.isTensor(value)) {
const tensor = value;
referencedParameters.push(tensor);
}
}
*/
break;
}
default: {
@@ -2441,6 +2403,9 @@ pytorch.jit.Execution = class extends pytorch.Execution {
result.push(tensors);
break;
}
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext': {
const value = this.invoke(parameter.type, []);
@@ -2627,8 +2592,11 @@ pytorch.jit.Execution = class extends pytorch.Execution {
break;
// case 'int64':
// break;
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
break;
default: {
if (!outputTypes || schema.outputs.length !== 1 || schema.outputs[0].type !== outputTypes[0]) {
@@ -3404,6 +3372,24 @@ pytorch.Utility = class {
}
}

static isObjectType(type) {
switch (type) {
case '__torch__.torch.classes.xnnpack.LinearOpContext':
case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
return true;
default:
return false;
}
}

static isObject(obj) {
const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
return pytorch.Utility.isObjectType(type);
}

static getType(value) {
if (value === null || value === undefined) {
return undefined;
17 changes: 17 additions & 0 deletions source/view.js
@@ -2029,6 +2029,23 @@ view.Node = class extends grapher.Node {
}
attributes.sort((a, b) => a.name.toUpperCase().localeCompare(b.name.toUpperCase()));
}
if (Array.isArray(node.inputs)) {
for (const input of node.inputs) {
switch (input.type) {
case 'graph':
case 'object':
case 'object[]':
case 'function':
case 'function[]': {
objects.push(input);
break;
}
default: {
break;
}
}
}
}
if (initializers.length > 0 || hiddenInitializers || attributes.length > 0 || objects.length > 0) {
const list = this.list();
list.on('click', () => this.context.activate(node));
5 changes: 4 additions & 1 deletion test/models.json
@@ -4930,6 +4930,7 @@
"target": "deeplabv3_scripted.ptl",
"source": "https://github.com/lutzroeder/netron/files/9562007/deeplabv3_scripted.ptl.zip[deeplabv3_scripted.ptl]",
"format": "TorchScript v1.6",
"assert": "model.graphs[0].nodes[0].inputs[1].value.type.name == '__torch__.torch.classes.xnnpack.Conv2dOpContext'",
"link": "https://github.com/lutzroeder/netron/issues/842"
},
{
@@ -5252,7 +5253,7 @@
"target": "model.ptl",
"source": "https://github.com/lutzroeder/netron/files/11149538/model.ptl.zip[model.ptl]",
"format": "TorchScript v1.6",
"assert": "model.graphs[0].nodes[0].inputs[1].value[0].initializer.name == 'CONSTANTS.c0.weight'",
"assert": "model.graphs[0].nodes[0].inputs[1].type == 'object'",
"link": "https://github.com/lutzroeder/netron/issues/1067"
},
{
@@ -5465,6 +5466,7 @@
"target": "quant_3d.pt",
"source": "https://github.com/lutzroeder/netron/files/5877566/quant_3d.pt.zip[quant_3d.pt]",
"format": "TorchScript v1.6",
"assert": "model.graphs[0].nodes[1].inputs[1].value.type.name == '__torch__.torch.classes.quantized.Conv3dPackedParamsBase'",
"link": "https://github.com/lutzroeder/netron/issues/546"
},
{
@@ -5836,6 +5838,7 @@
"target": "test.8bit.pth",
"source": "https://github.com/lutzroeder/netron/files/5238524/test.8bit.pth.zip[test.8bit.pth]",
"format": "TorchScript v1.6",
"assert": "model.graphs[0].nodes[8].type.name == 'conv1d_prepack'",
"link": "https://github.com/lutzroeder/netron/issues/546"
},
{
