diff --git a/source/coreml.js b/source/coreml.js index 2f0e2ba34f..5a4f3674f0 100644 --- a/source/coreml.js +++ b/source/coreml.js @@ -1,6 +1,5 @@ import * as base from './base.js'; -import * as text from './text.js'; const coreml = {}; @@ -53,8 +52,8 @@ coreml.ModelFactory = class { } if (identifier === 'model.mil') { try { - const reader = text.Reader.open(context.stream, 2048); - const signature = reader.read(); + const reader = context.peek('text'); + const signature = reader.peek('\n', 2048); if (signature !== undefined) { if (signature.trim().startsWith('program')) { context.type = 'coreml.mil'; diff --git a/source/darknet.js b/source/darknet.js index 4684abc1df..84e7227fe2 100644 --- a/source/darknet.js +++ b/source/darknet.js @@ -1,6 +1,4 @@ -import * as text from './text.js'; - const darknet = {}; darknet.ModelFactory = class { @@ -16,19 +14,24 @@ darknet.ModelFactory = class { } return; } - try { - const reader = text.Reader.open(context.stream, 65536); - for (let line = reader.read(); line !== undefined; line = reader.read()) { - const content = line.trim(); - if (content.length > 0 && !content.startsWith('#')) { - if (content.startsWith('[') && content.endsWith(']')) { - context.type = 'darknet.model'; - } - return; + const reader = context.peek('text'); + if (reader) { + try { + for (let line = reader.read('\n', 65536); line !== undefined; line = reader.read('\n', 65536)) { + const content = line.trim(); + if (content.length > 0 && !content.startsWith('#')) { + if (content.startsWith('[') && content.endsWith(']')) { + reader.seek(0); + context.type = 'darknet.model'; + context.target = reader; + } + return; + } } + } catch { + // continue regardless of error } - } catch { - // continue regardless of error + reader.seek(0); } } @@ -43,7 +46,7 @@ darknet.ModelFactory = class { const weights = context.target; const name = `${basename}.cfg`; const content = await context.fetch(name); - const reader = new darknet.Reader(content.stream, content.identifier); + const reader = new darknet.Reader(content.peek('text'), content.identifier); return new darknet.Model(metadata, reader, weights); } case 'darknet.model': { @@ -51,10 +54,10 @@ darknet.ModelFactory = class { const name = `${basename}.weights`; const content = await context.fetch(name); const weights = darknet.Weights.open(content); - const reader = new darknet.Reader(context.stream, context.identifier); + const reader = new darknet.Reader(context.target, context.identifier); return new darknet.Model(metadata, reader, weights); } catch { - const reader = new darknet.Reader(context.stream, context.identifier); + const reader = new darknet.Reader(context.target, context.identifier); return new darknet.Model(metadata, reader, null); } } @@ -874,8 +877,8 @@ darknet.TensorShape = class { darknet.Reader = class { - constructor(stream, identifier) { - this.stream = stream; + constructor(reader, identifier) { + this.reader = reader; this.identifier = identifier; } @@ -883,10 +886,10 @@ darknet.Reader = class { // read_cfg const sections = []; let section = null; - const reader = text.Reader.open(this.stream); + const reader = this.reader; let lineNumber = 0; const setup = /^setup.*\.cfg$/.test(this.identifier); - for (let content = reader.read(); content !== undefined; content = reader.read()) { + for (let content = reader.read('\n'); content !== undefined; content = reader.read('\n')) { lineNumber++; const line = content.replace(/\s/g, ''); if (line.length > 0) { diff --git a/source/dlc.js b/source/dlc.js index 4564aa2b55..cf665eb654 100644 --- a/source/dlc.js +++ b/source/dlc.js @@ -350,7 +350,7 @@ dlc.Container = class { delete this._metadata; const reader = text.Reader.open(stream); for (;;) { - const line = reader.read(); + const line = reader.read('\n'); if (line === undefined) { break; } diff --git a/source/ncnn.js b/source/ncnn.js index abd90aa2df..56d7ffd1eb 100644 --- a/source/ncnn.js +++ b/source/ncnn.js @@ -1,6 +1,5 @@ import * as base from './base.js'; -import * as text from './text.js'; const ncnn = {}; @@ -22,21 +21,28 @@ ncnn.ModelFactory = class { } } } else if (identifier.endsWith('.param') || identifier.endsWith('.cfg.ncnn')) { - try { - const reader = text.Reader.open(context.stream, 2048); - const signature = reader.read(); - if (signature !== undefined) { - if (signature.trim() === '7767517') { - context.type = 'ncnn.model'; - return; - } - const header = signature.trim().split(' '); - if (header.length === 2 && header.every((value) => value >>> 0 === parseFloat(value))) { - context.type = 'ncnn.model'; + const reader = context.peek('text'); + if (reader) { + try { + const signature = reader.read('\n', 2048); + if (signature !== undefined) { + if (signature.trim() === '7767517') { + reader.seek(0); + context.type = 'ncnn.model'; + context.target = reader; + return; + } + const header = signature.trim().split(' '); + if (header.length === 2 && header.every((value) => value >>> 0 === parseFloat(value))) { + reader.seek(0); + context.type = 'ncnn.model'; + context.target = reader; + } } + } catch { + // continue regardless of error } - } catch { - // continue regardless of error + reader.seek(0); } } else if (identifier.endsWith('.bin') || identifier.endsWith('.weights.ncnn')) { const stream = context.stream; @@ -80,8 +86,8 @@ ncnn.ModelFactory = class { const reader = new ncnn.BinaryParamReader(param); return new ncnn.Model(metadata, reader, bin); }; - const openText = (param, bin) => { - const reader = new ncnn.TextParamReader(param); + const openText = (reader, bin) => { + reader = new ncnn.TextParamReader(reader); return new ncnn.Model(metadata, reader, bin); }; const identifier = context.identifier.toLowerCase(); @@ -96,9 +102,9 @@ ncnn.ModelFactory = class { try { const content = await context.fetch(bin); const buffer = content.stream.peek(); - return openText(context.stream.peek(), buffer); + return openText(context.target, buffer); } catch { - return openText(context.stream.peek(), null); + return openText(context.target, null); } } case 'ncnn.model.bin': { @@ -120,8 +126,8 @@ ncnn.ModelFactory = class { } try { const content = await context.fetch(file); - const buffer = content.stream.peek(); - return openText(buffer, context.stream.peek()); + const reader = content.peek('text'); + return openText(reader, context.stream.peek()); } catch { const content = await context.fetch(`${file}.bin`); const buffer = content.stream.peek(); @@ -634,11 +640,10 @@ ncnn.Utility = class { ncnn.TextParamReader = class { - constructor(buffer) { - const reader = text.Reader.open(buffer); + constructor(reader) { const lines = []; for (;;) { - const line = reader.read(); + const line = reader.read('\n'); if (line === undefined) { break; } diff --git a/source/nnabla.js b/source/nnabla.js index 89a3c4a8ae..09fd0b6586 100644 --- a/source/nnabla.js +++ b/source/nnabla.js @@ -1,6 +1,4 @@ -import * as text from './text.js'; - const nnabla = {}; nnabla.ModelFactory = class { @@ -29,9 +27,8 @@ nnabla.ModelFactory = class { let version = ''; if (contexts.has('nnp_version.txt')) { const context = contexts.get('nnp_version.txt'); - const stream = context.stream; - const reader = text.Reader.open(stream); - version = reader.read(); + const reader = context.read('text'); + version = reader.read('\n'); version = version.split('\r').shift(); } if (contexts.has('parameter.protobuf')) { diff --git a/source/nnef.js b/source/nnef.js index da414ea369..6f9ffd71ce 100644 --- a/source/nnef.js +++ b/source/nnef.js @@ -54,7 +54,7 @@ nnef.TextReader = class { static open(stream) { const reader = text.Reader.open(stream); for (let i = 0; i < 32; i++) { - const line = reader.read(); + const line = reader.read('\n'); const match = /version\s*(\d+\.\d+);/.exec(line); if (match) { return new nnef.TextReader(stream, match[1]); diff --git a/source/onnx.js b/source/onnx.js index 6b4d44c557..ed87b0ab2f 100644 --- a/source/onnx.js +++ b/source/onnx.js @@ -2017,7 +2017,7 @@ onnx.TextReader = class { const reader = text.Reader.open(buffer); const lines = []; for (let i = 0; i < 32; i++) { - const line = reader.read(); + const line = reader.read('\n'); if (line === undefined) { break; } diff --git a/source/text.js b/source/text.js index cc192a477c..e0e68ae0a7 100644 --- a/source/text.js +++ b/source/text.js @@ -193,8 +193,8 @@ text.Decoder.Utf16LE = class { return String.fromCharCode(c); } if (c >= 0xD800 && c < 0xDBFF) { - if (this._position + 1 < this._length) { - const c2 = this._buffer[this._position++] | (this._buffer[this._position++] << 8); + if (this.position + 1 < this.length) { + const c2 = this.buffer[this.position++] | (this.buffer[this.position++] << 8); if (c >= 0xDC00 || c < 0xDFFF) { return String.fromCodePoint(0x10000 + ((c & 0x3ff) << 10) + (c2 & 0x3ff)); } @@ -225,8 +225,8 @@ text.Decoder.Utf16BE = class { return String.fromCharCode(c); } if (c >= 0xD800 && c < 0xDBFF) { - if (this._position + 1 < this._length) { - const c2 = (this._buffer[this._position++] << 8) | this._buffer[this._position++]; + if (this.position + 1 < this.length) { + const c2 = (this.buffer[this.position++] << 8) | this.buffer[this.position++]; if (c >= 0xDC00 || c < 0xDFFF) { return String.fromCodePoint(0x10000 + ((c & 0x3ff) << 10) + (c2 & 0x3ff)); } @@ -288,37 +288,56 @@ text.Decoder.Utf32BE = class { text.Reader = class { - constructor(data, length) { - this._decoder = text.Decoder.open(data); - this._position = 0; - this._length = length || Number.MAX_SAFE_INTEGER; + constructor(data + ) { + this.decoder = text.Decoder.open(data); + this.position = 0; } static open(data, length) { return new text.Reader(data, length); } - read() { - if (this._position >= this._length) { + seek(position) { + this.position = position; + this.decoder.position = 0; + if (position > 0) { + this.read(undefined, position); + delete this.length; + } + } + + peek(terminal, length) { + const position = this.position; + const offset = this.decoder.position; + const content = this.read(terminal, length); + this.decoder.position = offset; + this.position = position; + return content; + } + + read(terminal, length) { + length = length || this.length || Number.MAX_SAFE_INTEGER; + if (length && this.position >= length) { return undefined; } let line = ''; let buffer = null; for (;;) { - const c = this._decoder.decode(); + const c = this.decoder.decode(); if (c === undefined) { - this._length = this._position; + this.length = this.position; break; } - this._position++; - if (this._position > this._length) { + this.position++; + if (length && this.position > length) { break; } - if (c === '\n') { + if (c === terminal) { break; } line += c; - if (line.length >= 32) { + if (line.length >= 64) { buffer = buffer || []; buffer.push(line); line = ''; diff --git a/source/tnn.js b/source/tnn.js index 80b7ba9041..453e0aa058 100644 --- a/source/tnn.js +++ b/source/tnn.js @@ -1,6 +1,4 @@ -import * as text from './text.js'; - const tnn = {}; tnn.ModelFactory = class { @@ -10,15 +8,15 @@ tnn.ModelFactory = class { const stream = context.stream; if (stream && identifier.endsWith('.tnnproto')) { try { - const buffer = stream.peek(); - const reader = text.Reader.open(buffer, 2048); - const content = reader.read(); + const reader = context.peek('text'); + const content = reader.peek('\n', 2048); if (content !== undefined) { const line = content.trim(); if (line.startsWith('"') && line.endsWith('"')) { const header = line.replace(/(^")|("$)/g, '').split(',').shift().trim().split(' '); if (header.length === 3 || (header.length >= 4 && (header[3] === '4206624770' || header[3] === '4206624772'))) { context.type = 'tnn.model'; + context.target = reader; return; } } @@ -42,17 +40,19 @@ tnn.ModelFactory = class { switch (context.type) { case 'tnn.model': { const name = `${context.identifier.substring(0, context.identifier.length - 9)}.tnnmodel`; + const tnnproto = context.target; try { const content = await context.fetch(name); - return new tnn.Model(metadata, context, content); + return new tnn.Model(metadata, tnnproto, content); } catch { - return new tnn.Model(metadata, context, null); + return new tnn.Model(metadata, tnnproto, null); } } case 'tnn.params': { const name = `${context.identifier.substring(0, context.identifier.length - 9)}.tnnproto`; const content = await context.fetch(name, null); - return new tnn.Model(metadata, content, context); + const tnnproto = content.peek('text'); + return new tnn.Model(metadata, tnnproto, context); } default: { throw new tnn.Error(`Unsupported TNN format '${context.type}'.`); @@ -81,8 +81,8 @@ tnn.Graph = class { if (tnnmodel) { resources.read(tnnmodel); } - const reader = new tnn.TextProtoReader(tnnproto.stream); - reader.read(); + const reader = new tnn.TextProtoReader(tnnproto); + reader.read('\n'); const values = new Map(); values.map = (name, type, tensor) => { if (name.length === 0) { @@ -381,19 +381,18 @@ tnn.TensorShape = class { tnn.TextProtoReader = class { - constructor(stream) { - this.stream = stream; + constructor(reader) { + this.reader = reader; this.inputs = []; this.outputs = []; this.layers = []; } read() { - if (this.stream) { - const reader = text.Reader.open(this.stream); + if (this.reader) { let lines = []; for (;;) { - const line = reader.read(); + const line = this.reader.read('\n'); if (line === undefined) { break; } @@ -469,7 +468,7 @@ tnn.TextProtoReader = class { this.layers.push(layer); } } - delete this.stream; + delete this.reader; } } }; diff --git a/source/view.js b/source/view.js index b2c00e4462..f7ef8cac7c 100644 --- a/source/view.js +++ b/source/view.js @@ -3,6 +3,7 @@ import * as base from './base.js'; import * as zip from './zip.js'; import * as tar from './tar.js'; import * as json from './json.js'; +import * as text from './text.js'; import * as xml from './xml.js'; import * as protobuf from './protobuf.js'; import * as flatbuffers from './flatbuffers.js'; @@ -5369,6 +5370,11 @@ view.Context = class { } break; } + case 'text': { + const reader = text.Reader.open(stream); + this._content.set(type, reader); + break; + } default: { throw new view.Error(`Unsupported open format type '${type}'.`); } @@ -5599,7 +5605,7 @@ view.ModelFactoryService = class { this.register('./hickle', ['.h5', '.hkl']); this.register('./nnef', ['.nnef', '.dat']); this.register('./onednn', ['.json']); - this.register('./mlir', ['.mlir']); + this.register('./mlir', ['.mlir', '.mlir.txt']); this.register('./sentencepiece', ['.model']); this.register('./hailo', ['.hn', '.har', '.metadata.json']); this.register('./nnc', ['.nnc']);