Skip to content

Commit

Permalink
Add PyTorch test file (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 19, 2024
1 parent efdbd24 commit a89820d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 33 deletions.
64 changes: 31 additions & 33 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -350,32 +350,34 @@ pytorch.Node = class {
const entries = [];
const attributes = new Map();
stack = stack || new Set();
for (const [name, value] of Object.entries(obj)) {
if (name === '__class__') {
continue;
} else if (name === '_parameters' && value instanceof Map) {
for (const [name, parameter] of Array.from(value)) {
parameters.set(name, parameter);
}
} else if (name === '_buffers' && value instanceof Map) {
for (const [name, buffer] of Array.from(value)) {
parameters.set(name, buffer);
if (obj) {
for (const [name, value] of Object.entries(obj)) {
if (name === '__class__') {
continue;
} else if (name === '_parameters' && value instanceof Map) {
for (const [name, parameter] of Array.from(value)) {
parameters.set(name, parameter);
}
} else if (name === '_buffers' && value instanceof Map) {
for (const [name, buffer] of Array.from(value)) {
parameters.set(name, buffer);
}
} else if (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor))) {
parameters.set(name, value);
} else if (pytorch.Utility.isTensor(value)) {
parameters.set(name, value);
} else if (value && value.__class__ && value.__class__.__module__ === 'collections' && value.__class__.__name__ === 'OrderedDict' &&
value instanceof Map && value.size === 0) {
continue;
} else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && value.__class__.__name__ === 'set' &&
value instanceof Set && value.size === 0) {
continue;
} else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && value.__class__.__name__ === 'list' &&
Array.isArray(value) && value.length === 0) {
continue;
} else {
entries.push([name, value]);
}
} else if (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor))) {
parameters.set(name, value);
} else if (pytorch.Utility.isTensor(value)) {
parameters.set(name, value);
} else if (value && value.__class__ && value.__class__.__module__ === 'collections' && value.__class__.__name__ === 'OrderedDict' &&
value instanceof Map && value.size === 0) {
continue;
} else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && value.__class__.__name__ === 'set' &&
value instanceof Set && value.size === 0) {
continue;
} else if (value && value.__class__ && value.__class__.__module__ === 'builtins' && value.__class__.__name__ === 'list' &&
Array.isArray(value) && value.length === 0) {
continue;
} else {
entries.push([name, value]);
}
}
for (const [name, value] of entries) {
Expand Down Expand Up @@ -408,15 +410,11 @@ pytorch.Node = class {
} else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) {
return new pytorch.Argument(name, value);
} else if (name === '_modules' && value && value.__class__ && value.__class__.__module__ === 'collections' && value.__class__.__name__ === 'OrderedDict' &&
value instanceof Map && Array.from(value).every(([, value]) => value.__class__)) {
const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, value]) => {
value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
const values = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
stack.add(value);
const item = {
name,
type: `${value.__class__.__module__}.${value.__class__.__name__}`,
obj: value
};
const node = new pytorch.Node(metadata, group, item);
const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
const node = new pytorch.Node(metadata, group, { name, type, obj });
stack.delete(value);
return node;
});
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5068,6 +5068,13 @@
"error": "'torch.export' not supported.",
"link": "https://github.com/lutzroeder/netron/issues/1211"
},
{
"type": "pytorch",
"target": "hrnet_posenet_FP32.pth",
"source": "https://github.com/user-attachments/files/15894705/hrnet_posenet_FP32.pth.zip[hrnet_posenet_FP32.pth]",
"format": "PyTorch v1.6",
"link": "https://github.com/lutzroeder/netron/issues/720"
},
{
"type": "pytorch",
"target": "inception_v3_traced.pt",
Expand Down

0 comments on commit a89820d

Please sign in to comment.