Skip to content

Commit

Permalink
Add QNN test file (#1283)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 1, 2024
1 parent 6cc0eec commit edcc750
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 36 deletions.
116 changes: 80 additions & 36 deletions source/qnn.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ qnn.Model = class {

constructor(metadata, obj, weights) {
this.format = 'QNN';
if (obj.converter_command) {
this.producer = obj.converter_command.split(' ').shift();
}
this.metadata = [];
if (obj.copyright_str) {
this.metadata.push(new qnn.Argument('License', obj.copyright_str));
}
this.graphs = [new qnn.Graph(metadata, obj.graph, weights)];
}
};
Expand All @@ -68,45 +74,32 @@ qnn.Graph = class {
this.outputs = [];
this.nodes = [];
const values = new Map();
values.map = (name, type, tensor) => {
values.map = (name, type, tensor, quantization) => {
type = type || null;
tensor = tensor || null;
if (!values.has(name)) {
const value = new qnn.Value(name, type, tensor);
const value = new qnn.Value(name, type, tensor, quantization);
values.set(name, value);
} else if ((type && !type.equals(values.get(name).type)) || tensor) {
throw new qnn.Error(`Duplicate value '${name}'.`);
}
return values.get(name);
};
const dataType = (value) => {
switch (value) {
case 0x0008: return 'int8';
case 0x0016: return 'int16';
case 0x0032: return 'int32';
case 0x0108: return 'int8';
case 0x0132: return 'int32';
case 0x0216: return 'float16';
case 0x0232: return 'float32';
case 0x0308: return 'qint8';
case 0x0316: return 'qint16';
case 0x0332: return 'qint32';
case 0x0408: return 'uint8';
case 0x0416: return 'uint16';
case 0x0432: return 'uint32';
case 0x0508: return 'boolean';
default: throw new qnn.Error(`Unsupported data type '${JSON.stringify(value)}'.`);
}
};
const tensors = Object.entries(obj.tensors);
for (const [name, obj] of tensors) {
const shape = new qnn.TensorShape(obj.dims);
const type = new qnn.TensorType(dataType(obj.data_type), shape);
const dataType = qnn.Utility.dataType(obj.data_type);
const denotation = obj.axis_format ? obj.axis_format : '';
const type = new qnn.TensorType(dataType, shape, denotation);
switch (obj.type) {
case 0: {
const value = values.map(name, type);
const value = values.map(name, type, null, obj.quant_params);
const argument = new qnn.Argument(name, [value]);
this.inputs.push(argument);
break;
}
case 1: {
const value = values.map(name, type);
const value = values.map(name, type, null, obj.quant_params);
const argument = new qnn.Argument(name, [value]);
this.outputs.push(argument);
break;
Expand All @@ -117,8 +110,8 @@ qnn.Graph = class {
}
case 4: {
const reader = weights.get(`${name}.raw`);
const initializer = new qnn.Tensor(obj, type, reader);
values.map(name, type, initializer);
const tensor = new qnn.Tensor(name, obj, reader);
values.map(name, type, tensor, obj.quant_params);
break;
}
default: {
Expand All @@ -128,7 +121,7 @@ qnn.Graph = class {
}
const nodes = Object.entries(obj.nodes);
for (const [name, obj] of nodes) {
const node = new qnn.Node(metadata, name, obj, values);
const node = new qnn.Node(metadata, name, obj, values, weights);
this.nodes.push(node);
}
}
Expand All @@ -146,13 +139,20 @@ qnn.Argument = class {

qnn.Value = class {

constructor(name, type, initializer) {
constructor(name, type, initializer, quantization) {
if (typeof name !== 'string') {
throw new qnn.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
}
this.name = name;
this.type = type;
this.initializer = initializer;
if (quantization && quantization.definition === 1 && quantization.scale_offset) {
this.quantization = {
type: 'linear',
scale: [quantization.scale_offset.scale],
offset: [quantization.scale_offset.offset]
};
}
}
};

Expand Down Expand Up @@ -185,30 +185,49 @@ qnn.Node = class {
const argument = new qnn.Argument(outputs.length === 1 ? 'output' : 'outputs', outputs);
this.outputs.push(argument);
}
for (const [name, value] of Object.entries(obj.scalar_params)) {
const entries = Object.entries(value);
if (entries.length === 1 && name !== 'packageName') {
const dataType = qnn.Utility.dataType(parseInt(entries[0][0], 10));
const argument = new qnn.Argument(name, entries[0][1], dataType);
this.attributes.push(argument);
}
}
for (const [name, value] of Object.entries(obj.tensor_params)) {
const entries = Object.entries(value);
if (entries.length === 1 && name !== 'packageName') {
const tensor = new qnn.Tensor(name, entries[0][1]);
const argument = new qnn.Argument(name, tensor, 'tensor');
this.attributes.push(argument);
}
}
}
};

qnn.Tensor = class {

constructor(obj, type, reader) {
this.type = type;
this.encoding = '<';
this._reader = reader;
constructor(name, obj, data) {
const shape = new qnn.TensorShape(obj.dims);
const dataType = qnn.Utility.dataType(obj.data_type);
this.type = new qnn.TensorType(dataType, shape);
this.data = obj.data ? obj.data.flat() : data;
this.encoding = Array.isArray(this.data) ? '|' : '<';
}

get values() {
if (this._reader) {
return this._reader.peek();
if (this.data && this.data.peak) {
return this.data.peek();
}
return null;
return this.data;
}
};

qnn.TensorType = class {

constructor(dataType, shape) {
constructor(dataType, shape, denotation) {
this.dataType = dataType;
this.shape = shape;
this.denotation = denotation;
}

toString() {
Expand All @@ -230,6 +249,31 @@ qnn.TensorShape = class {
}
};

qnn.Utility = class {

static dataType(value) {
switch (value) {
case 0x0008: return 'int8';
case 0x0016: return 'int16';
case 0x0032: return 'int32';
case 0x0064: return 'int64';
case 0x0108: return 'int8';
case 0x0132: return 'int32';
case 0x0216: return 'float16';
case 0x0232: return 'float32';
case 0x0308: return 'qint8';
case 0x0316: return 'qint16';
case 0x0332: return 'qint32';
case 0x0408: return 'uint8';
case 0x0416: return 'uint16';
case 0x0432: return 'uint32';
case 0x0508: return 'boolean';
case 0x7fffffff: return 'string';
default: throw new qnn.Error(`Unsupported data type '${JSON.stringify(value)}'.`);
}
}
};

qnn.Error = class extends Error {

constructor(message) {
Expand Down
1 change: 1 addition & 0 deletions test/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -4658,6 +4658,7 @@
"target": "mobilenetv2-12_net.json",
"source": "https://github.com/user-attachments/files/15507409/mobilenetv2-12_net.json",
"format": "QNN",
"tags": "quantization",
"link": "https://github.com/lutzroeder/netron/issues/1283"
},
{
Expand Down

0 comments on commit edcc750

Please sign in to comment.