-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
*: replace line by line text loaders by chunk by chunk text loaders. …
…Loaders now yield token sequences of length blockSize
- Loading branch information
Showing
21 changed files
with
547 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,114 @@ | ||
import * as fs from "node:fs/promises"; | ||
import * as readline from "node:readline/promises"; | ||
import createDebug from "debug"; | ||
import { createReadStream } from 'node:fs'; | ||
import { PreTrainedTokenizer } from '@xenova/transformers'; | ||
import { Dataset, Text, processing } from "@epfml/discojs"; | ||
|
||
import { Dataset, Text } from "@epfml/discojs"; | ||
const debug = createDebug("discojs-node:loaders:text"); | ||
|
||
export function load(path: string): Dataset<Text> { | ||
/** | ||
* Returns a Dataset that streams and tokenizes text to yield tokenized sequences | ||
* one at a time. | ||
* The sequences returned are going to be split into input and label sequences of size `blockSize` | ||
* The label sequences are the input sequences shifted by one token. | ||
* Since the last token of the input sequence needs a label, | ||
* we include one more token (`blockSize` + 1 total) in the sequences returned. | ||
* * Thus, each sequence yielded has size `blockSize` + 1, where the last token | ||
* is included only to be the label of the last input token: | ||
* xs = tokens[0:blockSize] | ||
* ys = tokens[1:blockSize+1] | ||
* | ||
* Because the `blockSize+1`nth token is only used as label and not as input, | ||
* the next sequence will be shifted by `blockSize` (and not `blockSize + 1`) | ||
* In other words, the dataset yields sequences of size `blockSize` + 1 | ||
* with an overlap of 1 token between each sequence. | ||
* | ||
* @param path path to the text file to read | ||
* @param tokenizer the tokenizer to use, should match the model that will be trained | ||
* @param blockSize the context length, the maximum number of tokens of input sequences | ||
* @param batchSize default to 1, the number of input sequences (of `blockSize` tokens) in each batch. | ||
* The batch size is only used to configure the chunk size of the file stream such that each chunk is | ||
* big enough to contain at least one batch. | ||
* @param minChunkSize default to 16KiB, the minimum size of each chunk in bits | ||
* @returns a dataset of tokenized input and label sequences | ||
*/ | ||
export function load(path: string, tokenizer: PreTrainedTokenizer, | ||
blockSize: number, batchSize: number = 1, minChunkSize = 16384): Dataset<Text> { | ||
return new Dataset(async function* () { | ||
const input = (await fs.open(path)).createReadStream({ encoding: "utf8" }); | ||
if (batchSize < 1 || !Number.isInteger(batchSize) || | ||
blockSize < 1 || !Number.isInteger(blockSize) || | ||
minChunkSize < 1 || !Number.isInteger(minChunkSize)) | ||
throw new Error("batchSize, blockSize and minChunkSize must be positive integers"); | ||
const sequenceLength = blockSize + 1 // + 1 for the blockSize'nth token's label | ||
// we want each chunk to be at least bigger than the block size (each chunk corresponds to a block) | ||
// (or event bigger than batch size * block size so that each chunk corresponds to a batch) | ||
const chunkTokenSize = batchSize * (sequenceLength) | ||
// We read 8*8 = 8 bytes per expected token to ensure we have enough tokens | ||
// For reference, the GPT-2 tokenizer encodes 3 to 4 bytes per token on average | ||
const chunkBitSize = Math.max(minChunkSize, chunkTokenSize * 8 * 8); | ||
debug("Setting the chunk size to %o bits", chunkBitSize) | ||
// Create a stream to read the text file chunk by chunk | ||
const stream = createReadStream(path, { | ||
encoding: "utf8", | ||
highWaterMark: chunkBitSize | ||
}); | ||
|
||
// `readline` is a bit overkill but seems standard | ||
// https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line | ||
yield* readline.createInterface({ input, crlfDelay: Infinity }); | ||
// iterate over the chunks | ||
let endOfPreviousChunk = "" | ||
let alreadyAppliedPadding = false | ||
for await (const chunk of stream) { | ||
if (typeof chunk !== 'string') throw new Error('Expected file stream to yield string') | ||
debug("Reading chunk of size %o", chunk.length) | ||
// tokenize the whole chunk at once | ||
// Concatenate with potential leftovers from the previous chunk | ||
let tokens = processing.tokenize(tokenizer, endOfPreviousChunk + chunk) | ||
if (tokens.size < sequenceLength) { | ||
// throw if we need to apply padding more than once | ||
// We can pad if the whole text is smaller than block size or | ||
// if the very last chunk is smaller than block size | ||
if (alreadyAppliedPadding) | ||
throw new Error(`the chunk (${tokens.size} tokens) is too small ` + | ||
`to get a sequence of length blockSize (${sequenceLength} tokens). ` + | ||
`Either the text file or the chunk size (${chunkBitSize} bits) is too small.`); | ||
// if this isn't the first iteration we simply skip | ||
// as we expect the last chunk to be potentially smaller than the block size | ||
debug("chunk smaller than block size, padding to blockSize") | ||
yield processing.tokenize(tokenizer, endOfPreviousChunk + chunk, { | ||
padding: true, max_length: sequenceLength | ||
}) | ||
alreadyAppliedPadding = true | ||
continue | ||
} | ||
debug("batch per chunk: %o", tokens.size / (batchSize * blockSize)) | ||
// yield one block of tokens at a time | ||
while (tokens.size >= sequenceLength) { | ||
yield tokens.take(sequenceLength); | ||
tokens = tokens.slice(blockSize); // only shift by blockSize rather than sequenceLength | ||
} | ||
// keep the last tokens for the next chunk | ||
// if this was the last one the remaining tokens are discarded | ||
if (tokens.size > 0) { | ||
// We actually need to decode the tokens to get the leftover text | ||
// instead of simply keeping the remaining tokens. | ||
// this is because the tokens may be different once prepended to the next chunk | ||
// e.g. if the remaining text is ". A" and the next chunk starts with "nother" | ||
// the tokenization will be different than if we simply concatenate the remaining tokens | ||
endOfPreviousChunk = tokenizer.decode( | ||
tokens.toArray(), | ||
{ skip_special_tokens: true } | ||
) | ||
debug("End of chunk, remaining text: '%s'", endOfPreviousChunk) | ||
} else { | ||
// Note that the difference between tokenizing and then concatenating | ||
// vs concatenating and then tokenizing can happen if their is no | ||
// remaining text. We consider this difference negligible | ||
endOfPreviousChunk = ""; | ||
} | ||
} | ||
if (endOfPreviousChunk.length === 0) return | ||
|
||
// flush the remaining text after the last chunk | ||
yield processing.tokenize(tokenizer, endOfPreviousChunk, { | ||
padding: true, max_length: sequenceLength | ||
}) | ||
}); | ||
} |
Oops, something went wrong.