diff --git a/source/dlc-metadata.json b/source/dlc-metadata.json index e5cc3806c9..f601bb5c1e 100644 --- a/source/dlc-metadata.json +++ b/source/dlc-metadata.json @@ -42,6 +42,14 @@ "name": "Pooling:v3", "category": "Pool" }, + { + "name": "Prelu:v4", + "category": "Activation", + "inputs": [ + { "name": "input" }, + { "name": "weight" } + ] + }, { "name": "SoftMax:v3", "category": "Activation" @@ -113,6 +121,15 @@ "name": "Transpose:v4", "category": "Transform" }, + { + "name": "TransposeConv2d:v4", + "category": "Layer", + "inputs": [ + { "name": "input" }, + { "name": "weight" }, + { "name": "bias" } + ] + }, { "name": "StridedSlice:v4", "category": "Tensor" diff --git a/source/dlc.js b/source/dlc.js index ebb76d2995..a3e25fdbf6 100644 --- a/source/dlc.js +++ b/source/dlc.js @@ -41,6 +41,10 @@ dlc.Model = class { this.source = version ? `${source} v${version}` : source; } } + const license = target.metadata.get('model-copyright'); + if (license && license !== 'N/A') { + this.metadata.push(new dlc.Argument('license', license)); + } } for (const graph of target.graphs) { this.graphs = [new dlc.Graph(metadata, target.version, graph)]; @@ -98,7 +102,7 @@ dlc.Graph = class { } for (const [name, tensor] of values) { const type = tensor.shape ? new dlc.TensorType(tensor.dtype, tensor.shape) : null; - const initializer = tensor.data && tensor.data ? new dlc.Tensor(type, tensor.data) : null; + const initializer = tensor.data && tensor.data ? new dlc.Tensor(tensor.name, type, tensor.data) : null; const value = new dlc.Value(name, type, initializer); values.set(name, value); } @@ -179,8 +183,9 @@ dlc.Node = class { let type = attr.type; switch (type) { case 'tensor': { - const type = new dlc.TensorType(attr.data.dtype, attr.data.shape); - value = new dlc.Tensor(type, attr.data.data); + const tensor = attr.data; + const type = new dlc.TensorType(tensor.dtype, tensor.shape); + value = new dlc.Tensor(tensor.name, type, tensor.data); break; } default: { @@ -198,7 +203,7 @@ dlc.Node = class { if (obj.weights) { for (const tensor of obj.weights) { const type = new dlc.TensorType(tensor.data.dtype, tensor.shape); - const value = new dlc.Value('', type, new dlc.Tensor(type, tensor.data)); + const value = new dlc.Value('', type, new dlc.Tensor(tensor.name, type, tensor.data)); this.inputs.push(new dlc.Argument(tensor.name, [value])); } } @@ -233,7 +238,8 @@ dlc.TensorShape = class { dlc.Tensor = class { - constructor(type, data) { + constructor(name, type, data) { + this.name = name; this.type = type; if (data instanceof Uint8Array) { this.encoding = '<'; @@ -582,7 +588,13 @@ dlc.Container = class { graph.tensors.sort((a, b) => a.name.localeCompare(b.name)); for (const tensor of graph.tensors) { if (tensor.location === 4) { - tensor.data = buffers ? buffers[index++].bytes : tensors.get(tensor.name).bytes; + if (buffers && index < buffers.length) { + tensor.data = buffers[index++].bytes; + } else if (tensors.has(tensor.name)) { + tensor.data = tensors.get(tensor.name).bytes; + } else { + throw new dlc.Error(`Unknown tensor `); + } } } for (let i = 0; i < graph.nodes.length; i++) { @@ -591,7 +603,13 @@ dlc.Container = class { for (const attribute of node.attributes) { const tensor = attribute.tensor; if (tensor) { - tensor.data = buffers ? buffers[index++].bytes : tensors.get(tensor.name).bytes; + if (buffers && index < buffers.length) { + tensor.data = buffers[index++].bytes; + } else if (tensors.has(tensor.name)) { + tensor.data = tensors.get(tensor.name).bytes; + } else { + throw new dlc.Error(`Unknown tensor `); + } } } }