Skip to content

Commit

Permalink
Add NumPy test file (#711)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jul 2, 2024
1 parent eeae4d4 commit aa3894a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 2 deletions.
6 changes: 6 additions & 0 deletions source/numpy.js
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,17 @@ numpy.ModelFactory = class {
switch (context.type) {
case 'npy': {
format = 'NumPy Array';
const unresolved = new Set();
const execution = new python.Execution();
execution.on('resolve', (_, name) => unresolved.add(name));
const stream = context.stream;
const buffer = stream.peek();
const bytes = execution.invoke('io.BytesIO', [buffer]);
const array = execution.invoke('numpy.load', [bytes]);
if (unresolved.size > 0) {
const name = unresolved.values().next().value;
throw new numpy.Error(`Unknown type name '${name}'.`);
}
const layer = { type: 'numpy.ndarray', parameters: [{ name: 'value', tensor: { name: '', array } }] };
graphs.push({ layers: [layer] });
break;
Expand Down
23 changes: 21 additions & 2 deletions source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -1914,7 +1914,7 @@ python.Execution = class {
});
this.registerType('numpy.dtype', class {
constructor(obj, align, copy) {
if (typeof obj === 'string' && (obj.startsWith('<') || obj.startsWith('>'))) {
if (typeof obj === 'string' && (obj.startsWith('<') || obj.startsWith('>') || obj.startsWith('|'))) {
this.byteorder = obj.substring(0, 1);
obj = obj.substring(1);
} else {
Expand Down Expand Up @@ -1950,6 +1950,9 @@ python.Execution = class {
} else if (obj.startsWith('U')) { // Unicode string
this.kind = 'U';
this.itemsize = 4 * parseInt(obj.substring(1), 10);
} else if (obj.startsWith('T')) {
this.kind = 'T';
this.itemsize = parseInt(obj.substring(1), 10);
} else {
throw new python.Error(`Unsupported dtype '${obj}'.`);
}
Expand All @@ -1970,6 +1973,7 @@ python.Execution = class {
case 'V': return `void${this.itemsize === 0 ? '' : (this.itemsize * 8)}`;
case 'S': return `bytes${this.itemsize === 0 ? '' : (this.itemsize * 8)}`;
case 'U': return `str${this.itemsize === 0 ? '' : (this.itemsize * 8)}`;
case 'T': return `StringDType${this.itemsize === 0 ? '' : (this.itemsize * 8)}`;
case 'M': return 'datetime64';
case 'b': return 'bool';
default: return this.__name__;
Expand Down Expand Up @@ -2032,6 +2036,8 @@ python.Execution = class {
default: throw new python.Error(`Unsupported complex itemsize '${this.itemsize}'.`);
}
case 'S':
case 'T':
return 'string';
case 'U':
return 'string';
case 'M':
Expand Down Expand Up @@ -2065,6 +2071,11 @@ python.Execution = class {
this.registerType('numpy.uint32', class extends numpy.unsignedinteger {});
this.registerType('numpy.uint64', class extends numpy.unsignedinteger {});
this.registerType('numpy.datetime64', class extends numpy.generic {});
this.registerType('numpy.dtypes.StringDType', class extends numpy.dtype {
constructor() {
super('|T16');
}
});
this.registerType('gensim.models.doc2vec.Doctag', class {});
this.registerType('gensim.models.doc2vec.Doc2Vec', class {});
this.registerType('gensim.models.doc2vec.Doc2VecTrainables', class {});
Expand Down Expand Up @@ -2438,6 +2449,9 @@ python.Execution = class {
}
return list;
}
case 'T': {
return this.data;
}
case 'O': {
return this.data;
}
Expand Down Expand Up @@ -3821,6 +3835,12 @@ python.Execution = class {
this.registerFunction('numpy.core.multiarray._reconstruct', (subtype, shape, dtype) => {
return numpy.ndarray.__new__(subtype, shape, dtype);
});
this.registerFunction('numpy._core.multiarray._reconstruct', (subtype, shape, dtype) => {
return numpy.ndarray.__new__(subtype, shape, dtype);
});
this.registerFunction('numpy._core._internal._convert_to_stringdtype_kwargs', () => {
return new numpy.dtypes.StringDType();
});
numpy.core._multiarray_umath._reconstruct = numpy.core.multiarray._reconstruct;
this.registerFunction('numpy.core.multiarray.scalar', (dtype, rawData) => {
let data = rawData;
Expand Down Expand Up @@ -3911,7 +3931,6 @@ python.Execution = class {
throw new python.Error(`Unsupported scalar type '${dtype.__name__}'.`);
}
});
numpy._core = numpy.core;
this.registerFunction('numpy.load', (file) => {
// https://github.com/numpy/numpy/blob/main/numpy/lib/format.py
const signature = [0x93, 0x4E, 0x55, 0x4D, 0x50, 0x59];
Expand Down
7 changes: 7 additions & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -3631,6 +3631,13 @@
"format": "NumPy Array",
"link": "https://github.com/lutzroeder/netron/issues/711"
},
{
"type": "numpy",
"target": "StringDType.npy",
"source": "https://github.com/user-attachments/files/16061374/StringDType.npy.zip[StringDType.npy]",
"format": "NumPy Array",
"link": "https://github.com/lutzroeder/netron/issues/711"
},
{
"type": "numpy",
"target": "tensor.npy",
Expand Down

0 comments on commit aa3894a

Please sign in to comment.