Skip to content

Commit

Permalink
Add TorchScript test file (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 29, 2024
1 parent 214a237 commit 4cc6962
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 8 deletions.
14 changes: 7 additions & 7 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -5088,7 +5088,7 @@ python.Execution = class {
}
const dtype = dtypes.get(obj.dtype.str);
const strides = obj.strides.map((stride) => stride / obj.itemsize);
const storage = execution.invoke('torch.storage._TypedStorage', [obj.size, dtype]);
const storage = execution.invoke('torch.storage.TypedStorage', [obj.size, dtype]);
storage._set_cdata(obj.data);
const tensor = execution.invoke('torch.Tensor', []);
tensor.__setstate__([storage, 0, obj.shape, strides]);
Expand Down Expand Up @@ -5262,7 +5262,7 @@ python.Execution = class {
const shape = size;
dtype = torch._prims_common.dtype_or_default(dtype);
size = shape.reduce((a, b) => a * b, 1);
const storage = execution.invoke('torch.storage._TypedStorage', [size, dtype]);
const storage = execution.invoke('torch.storage.TypedStorage', [size, dtype]);
const tensor = execution.invoke('torch.Tensor', []);
tensor.__setstate__([storage, 0, shape, stride]);
return tensor;
Expand Down Expand Up @@ -6348,13 +6348,13 @@ python.Execution = class {
return storage;
}
});
this.registerType('torch.storage._UntypedStorage', class extends torch.storage._StorageBase {
this.registerType('torch.storage.UntypedStorage', class extends torch.storage._StorageBase {
constructor() {
super();
throw new python.Error('_UntypedStorage not implemented.');
throw new python.Error('UntypedStorage not implemented.');
}
});
this.registerType('torch.storage._TypedStorage', class {
this.registerType('torch.storage.TypedStorage', class {
constructor(...args) {
if (args.length >= 2 && Number.isInteger(args[0]) && args[1] instanceof torch.dtype) {
if (args[3] instanceof torch.device) {
Expand All @@ -6363,7 +6363,7 @@ python.Execution = class {
[this._size, this._dtype] = args;
}
} else {
throw new python.Error(`Unsupported _TypedStorage arguments '${JSON.stringify(args)}'.`);
throw new python.Error(`Unsupported TypedStorage arguments '${JSON.stringify(args)}'.`);
}
}
get device() {
Expand Down Expand Up @@ -6408,7 +6408,7 @@ python.Execution = class {
return storage;
}
});
this.registerType('torch.storage._LegacyStorage', class extends torch.storage._TypedStorage {
this.registerType('torch.storage._LegacyStorage', class extends torch.storage.TypedStorage {
constructor() {
super();
throw new python.Error('_LegacyStorage not implemented.');
Expand Down
1 change: 1 addition & 0 deletions source/pytorch-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -8064,6 +8064,7 @@
},
{
"name": "aten::lstm_cell",
"category": "Layer",
"inputs": [
{ "name": "input", "type": "Tensor" },
{ "name": "hx", "type": "Tensor[]" },
Expand Down
2 changes: 1 addition & 1 deletion source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -3216,7 +3216,7 @@ pytorch.jit.FlatBuffersLoader = class {
const data = this._module.storage_data[index].data;
const dtype = this._dtypes.get(metadata.scalar_type);
const size = data.length / dtype.itemsize();
const storage = this._cu.execution.invoke('torch.storage._TypedStorage', [size, dtype]);
const storage = this._cu.execution.invoke('torch.storage.TypedStorage', [size, dtype]);
storage._set_cdata(data);
const tensor = this._cu.execution.invoke('torch.Tensor', []);
const shape = Array.from(metadata.sizes);
Expand Down
8 changes: 8 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -5978,6 +5978,14 @@
"format": "TorchScript v1.6",
"link": "https://github.com/lutzroeder/netron/issues/637"
},
{
"type": "pytorch",
"target": "v1_lj_8000.jit",
"source": "https://github.com/user-attachments/files/16041474/v1_lj_8000.jit.zip[v1_lj_8000.jit]",
"format": "TorchScript v1.6",
"error": "Cannot read properties of undefined (reading 'position')",
"link": "https://github.com/lutzroeder/netron/issues/1061"
},
{
"type": "pytorch",
"target": "v3_1_ru.pt",
Expand Down

0 comments on commit 4cc6962

Please sign in to comment.