Skip to content

Commit

Permalink
Update dlc.js
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jun 21, 2024
1 parent 3e49fcf commit dc85d98
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
17 changes: 17 additions & 0 deletions source/dlc-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@
"name": "Pooling:v3",
"category": "Pool"
},
{
"name": "Prelu:v4",
"category": "Activation",
"inputs": [
{ "name": "input" },
{ "name": "weight" }
]
},
{
"name": "SoftMax:v3",
"category": "Activation"
Expand Down Expand Up @@ -113,6 +121,15 @@
"name": "Transpose:v4",
"category": "Transform"
},
{
"name": "TransposeConv2d:v4",
"category": "Layer",
"inputs": [
{ "name": "input" },
{ "name": "weight" },
{ "name": "bias" }
]
},
{
"name": "StridedSlice:v4",
"category": "Tensor"
Expand Down
32 changes: 25 additions & 7 deletions source/dlc.js
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ dlc.Model = class {
this.source = version ? `${source} v${version}` : source;
}
}
const license = target.metadata.get('model-copyright');
if (license && license !== 'N/A') {
this.metadata.push(new dlc.Argument('license', license));
}
}
for (const graph of target.graphs) {
this.graphs = [new dlc.Graph(metadata, target.version, graph)];
Expand Down Expand Up @@ -98,7 +102,7 @@ dlc.Graph = class {
}
for (const [name, tensor] of values) {
const type = tensor.shape ? new dlc.TensorType(tensor.dtype, tensor.shape) : null;
const initializer = tensor.data && tensor.data ? new dlc.Tensor(type, tensor.data) : null;
const initializer = tensor.data && tensor.data ? new dlc.Tensor(tensor.name, type, tensor.data) : null;
const value = new dlc.Value(name, type, initializer);
values.set(name, value);
}
Expand Down Expand Up @@ -179,8 +183,9 @@ dlc.Node = class {
let type = attr.type;
switch (type) {
case 'tensor': {
const type = new dlc.TensorType(attr.data.dtype, attr.data.shape);
value = new dlc.Tensor(type, attr.data.data);
const tensor = attr.data;
const type = new dlc.TensorType(tensor.dtype, tensor.shape);
value = new dlc.Tensor(tensor.name, type, tensor.data);
break;
}
default: {
Expand All @@ -198,7 +203,7 @@ dlc.Node = class {
if (obj.weights) {
for (const tensor of obj.weights) {
const type = new dlc.TensorType(tensor.data.dtype, tensor.shape);
const value = new dlc.Value('', type, new dlc.Tensor(type, tensor.data));
const value = new dlc.Value('', type, new dlc.Tensor(tensor.name, type, tensor.data));
this.inputs.push(new dlc.Argument(tensor.name, [value]));
}
}
Expand Down Expand Up @@ -233,7 +238,8 @@ dlc.TensorShape = class {

dlc.Tensor = class {

constructor(type, data) {
constructor(name, type, data) {
this.name = name;
this.type = type;
if (data instanceof Uint8Array) {
this.encoding = '<';
Expand Down Expand Up @@ -582,7 +588,13 @@ dlc.Container = class {
graph.tensors.sort((a, b) => a.name.localeCompare(b.name));
for (const tensor of graph.tensors) {
if (tensor.location === 4) {
tensor.data = buffers ? buffers[index++].bytes : tensors.get(tensor.name).bytes;
if (buffers && index < buffers.length) {
tensor.data = buffers[index++].bytes;
} else if (tensors.has(tensor.name)) {
tensor.data = tensors.get(tensor.name).bytes;
} else {
throw new dlc.Error(`Unknown tensor `);
}
}
}
for (let i = 0; i < graph.nodes.length; i++) {
Expand All @@ -591,7 +603,13 @@ dlc.Container = class {
for (const attribute of node.attributes) {
const tensor = attribute.tensor;
if (tensor) {
tensor.data = buffers ? buffers[index++].bytes : tensors.get(tensor.name).bytes;
if (buffers && index < buffers.length) {
tensor.data = buffers[index++].bytes;
} else if (tensors.has(tensor.name)) {
tensor.data = tensors.get(tensor.name).bytes;
} else {
throw new dlc.Error(`Unknown tensor `);
}
}
}
}
Expand Down

0 comments on commit dc85d98

Please sign in to comment.