*: replace line-by-line text loaders with chunk-by-chunk text loaders
discojs/src/dataset: implement and test repeat and batchWithOverlap
JulienVig committed Nov 14, 2024
1 parent c6beac9 commit c477bb3
Showing 15 changed files with 313 additions and 172 deletions.
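In short, the text loaders now yield raw text chunks instead of individual lines, and the preprocessing pipeline turns those chunks into overlapping token sequences. A condensed sketch of the new flow, mirroring the updated cli/src/train_gpt.ts below (all names are taken from the diff; treat this as an illustration, not the verbatim commit):

import { AutoTokenizer } from "@xenova/transformers";
import { Dataset, processing } from "@epfml/discojs";
import { List } from "immutable";

async function sketch(): Promise<void> {
  const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
  const blockSize = 16;

  const pipeline = new Dataset(["Lorem ipsum dolor sit amet, consectetur adipis"])
    .map((text) => processing.tokenize(tokenizer, text)) // one token list per text chunk
    .unbatch()                                           // flatten into a single token stream
    .batchWithOverlap(blockSize)                         // windows of blockSize + 1 tokens
    // split each window into the input sequence and the label of its last token
    .map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
    .repeat()                                            // cycle over the data indefinitely
    .batch(8);

  // model.train(pipeline, ...) then consumes the repeated, batched sequences
}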
13 changes: 6 additions & 7 deletions cli/src/benchmark_gpt.ts
@@ -1,3 +1,4 @@
import '@tensorflow/tfjs-node';
import { List } from "immutable";
import { parse } from "ts-command-line-args";
import { AutoTokenizer } from "@xenova/transformers";
@@ -41,7 +42,7 @@ const args = { ...defaultArgs, ...parsedArgs }
* Benchmark results are reported in https://github.com/epfml/disco/pull/659
*/

async function main(args: Required<CLIArguments>): Promise<void> {
const { inference: benchmarkInference, modelType,
contextLength, batchSize, modelPath } = args

@@ -77,10 +78,11 @@ async function main(args: Required<CLIArguments>): Promise<void> {
task.trainingInformation.batchSize = batchSize
task.trainingInformation.maxSequenceLength = contextLength
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')
.map(text => processing.tokenize(tokenizer, text))
.unbatch()
.batchWithOverlap(config.blockSize)

const maxLength = task.trainingInformation.maxSequenceLength ?? (tokenizer.model_max_length as number) + 1
const preprocessedDataset = dataset
.map((line) => processing.tokenizeAndLeftPad(line, tokenizer, maxLength))
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.batch(batchSize);

@@ -111,10 +113,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
const iterations = 10
console.log("Generating", maxNewTokens, "new tokens")

let tokens = List(
(tokenizer(prompt, { return_tensor: false }) as { input_ids: number[] })
.input_ids,
);
let tokens = processing.tokenize(tokenizer, prompt);

let inferenceTime = 0
for (let i = 0; i < iterations; i++) {
37 changes: 23 additions & 14 deletions cli/src/train_gpt.ts
@@ -1,11 +1,11 @@
import * as tf from "@tensorflow/tfjs-node"
import "@tensorflow/tfjs-node"
import { AutoTokenizer } from "@xenova/transformers";
import { models, processing } from "@epfml/discojs";
import { models, processing, Dataset } from "@epfml/discojs";
import { List } from "immutable";

async function main(): Promise<void> {
const data = "Lorem ipsum dolor sit amet, consectetur adipis"
const datasetSource = new tf.data.FileDataSource(Buffer.from(data))
const textDataset = new tf.data.TextLineDataset(datasetSource)
const seed = 42

const config: models.GPTConfig = {
modelType: 'gpt-nano',
@@ -14,25 +14,34 @@ async function main(): Promise<void> {
evaluateEvery:50,
maxEvalBatches: 10,
blockSize: 16,
vocabSize: 50257,
debug: false
seed
}

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/gpt2')
const tokenDataset = textDataset.map((text: string) => {
const tokens = processing.tokenizeAndLeftPad(text, tokenizer, config.blockSize + 1)
const ys = tf.oneHot(tokens.slice(1), tokenizer.model.vocab.length)
const xs = tf.tensor(tokens.slice(0, config.blockSize), undefined, 'int32')
return {xs, ys}
}).repeat().batch(16) as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>

const tokenDataset = new Dataset([data])
.map((text: string) => processing.tokenize(tokenizer, text))
.unbatch()
.batchWithOverlap(config.blockSize)
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.repeat()
.batch(8);

const model = new models.GPT(config)

for await (const logs of model.train(tokenDataset, undefined)) {
console.log(logs)
}

const generation = await model.generate("Lorem", tokenizer, { maxNewTokens: 10, doSample: false, topk: 5, temperature:0.1 })
let tokens = processing.tokenize(tokenizer, "Lorem");

const maxNewTokens = 14
for (let n = 0; n < maxNewTokens; n++) {
const next: number = (await model.predict(
List.of(tokens), { seed })
).first();
tokens = tokens.push(next)
}
const generation = tokenizer.decode(tokens.toArray(), { skip_special_tokens: true })
console.log(generation)
}

11 changes: 6 additions & 5 deletions discojs-node/src/loaders.spec.ts
@@ -51,12 +51,13 @@ describe("image directory parser", () => {

describe("text parser", () => {
it("parses basic file", async () => {
const text = ["a", "b", "c"].join("\n")
await withFile(async ({ path }) => {
await fs.writeFile(path, ["a", "b", "c"].join("\n"));

const parsed = loadText(path);

expect(await parsed.size()).to.equal(3);
await fs.writeFile(path, text);
const sequences = await arrayFromAsync(loadText(path))
expect(sequences.length).to.equal(1);
expect(sequences[0]).to.equal(text);
});
});
});
36 changes: 28 additions & 8 deletions discojs-node/src/loaders/text.ts
@@ -1,14 +1,34 @@
import * as fs from "node:fs/promises";
import * as readline from "node:readline/promises";

import createDebug from "debug";
import { createReadStream } from 'node:fs';
import { Dataset, Text } from "@epfml/discojs";

export function load(path: string): Dataset<Text> {
const debug = createDebug("discojs-node:loaders:text");

/**
* Returns chunks of text. Use `minChunkSize` to ensure that
* each chunk is bigger than the expected sequence length.
*
* @param path path to the text file to read
* @param minChunkSize defaults to 16 KiB, the minimum size of each chunk in bytes
* @returns a dataset of raw text chunks
*/
export function load(path: string, minChunkSize = 16384): Dataset<Text> {
return new Dataset(async function* () {
const input = (await fs.open(path)).createReadStream({ encoding: "utf8" });
if (minChunkSize < 1 || !Number.isInteger(minChunkSize))
throw new Error("minChunkSize must be positive integers");

debug("Setting the chunk size to %o bits", minChunkSize)
// Create a stream to read the text file chunk by chunk
const stream = createReadStream(path, {
encoding: "utf8",
highWaterMark: minChunkSize
});
for await (const chunk of stream) {
if (typeof chunk !== 'string')
throw new Error('Expected file stream to yield string')

// `readline` is a bit overkill but seems standard
// https://nodejs.org/api/readline.html#example-read-file-stream-line-by-line
yield* readline.createInterface({ input, crlfDelay: Infinity });
debug("yield chunk of length: %o", chunk.length);
yield chunk
}
});
}
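For reference, a usage sketch of the chunked loader; the wikitext path is the one used by the CLI benchmark above, and the export path follows loaders.spec.ts:

import { loadText } from "./loaders/index.js";

async function preview(): Promise<void> {
  // chunks arrive as decoded utf8 strings: .length counts characters,
  // while the 16384 highWaterMark passed below counts bytes
  for await (const chunk of loadText("../datasets/wikitext/wiki.train.tokens", 16384)) {
    console.log("got a chunk of length", chunk.length);
  }
}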
23 changes: 8 additions & 15 deletions discojs-web/src/loaders.spec.ts
@@ -1,5 +1,4 @@
import { describe, it, expect } from "vitest";

import { loadCSV, loadText } from "./loaders/index.js";

async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
@@ -22,22 +21,16 @@ describe("csv parser", () => {
});

describe("text parser", () => {
it("loads", async () => {
it("loads a simple sequence", async () => {
const text = ["first", "second", "third"].join("\n")

// jsdom doesn't implement .text on File/Blob
// trick from https://github.com/jsdom/jsdom/issues/2555
const text = await (
await fetch(
// data URL content need to be url-encoded
["data:,first", "second", "third"].join("%0A"),
)
const file = await (
await fetch("data:," + encodeURIComponent(text))
).blob();

const parsed = loadText(text);

expect(await arrayFromAsync(parsed)).to.have.ordered.members([
"first",
"second",
"third",
]);
const parsed = loadText(file)
expect(await parsed.size()).to.equal(1);
expect((await arrayFromAsync(parsed))[0]).to.equal(text);
});
});
25 changes: 0 additions & 25 deletions discojs-web/src/loaders/text.ts
@@ -1,35 +1,10 @@
import { Dataset, Text } from "@epfml/discojs";

class LineStream extends TransformStream<string, string> {
constructor() {
let current_line = "";

super({
transform: (chunk, controller) => {
const [head, ...lines] = chunk.split(/\r\n|\r|\n/);
const first_line = current_line + head;

if (lines.length === 0) {
current_line = first_line;
return;
}

controller.enqueue(first_line);
for (const line of lines.slice(0, -1)) controller.enqueue(line);

current_line = lines[lines.length - 1];
},
flush: (controller) => controller.enqueue(current_line),
});
}
}

export function load(file: Blob): Dataset<Text> {
return new Dataset(async function* () {
const reader = file
.stream()
.pipeThrough(new TextDecoderStream())
.pipeThrough(new LineStream())
.getReader();

while (true) {
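The rest of the new web loader is cut off above; presumably it simply drains the decoded stream and yields each chunk as-is, mirroring the node version (a sketch under that assumption, not the verbatim commit):

export function load(file: Blob): Dataset<Text> {
  return new Dataset(async function* () {
    const reader = file
      .stream()
      .pipeThrough(new TextDecoderStream())
      .getReader();

    while (true) {
      const { value, done } = await reader.read();
      if (done) break;
      yield value; // each decoded text chunk, unsplit
    }
  });
}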
56 changes: 55 additions & 1 deletion discojs/src/dataset/dataset.spec.ts
@@ -1,6 +1,6 @@
import { expect } from "chai";
import { Dataset } from "./dataset.js";
import { Range } from "immutable";
import { List, Range } from "immutable";

// Array.fromAsync not yet widely used (2024)
async function arrayFromAsync<T>(iter: AsyncIterable<T>): Promise<T[]> {
@@ -139,4 +139,58 @@ describe("dataset", () => {
[3, 2],
]);
});

it("batches with overlap", async () => {
const dataset = new Dataset([1, 2, 3]);

const batched = dataset.batchWithOverlap(1);

expect(
(await arrayFromAsync(batched)).map((l) => l.toArray()),
).to.have.deep.ordered.members([[1, 2], [2, 3]]);
});

it("batchWithOverlap yields correct batches", async () => {
const expectedTokens = Range(0, 53).toList()
const blockSize = 4

const parsed = new Dataset([expectedTokens])
.unbatch()
.batchWithOverlap(blockSize)

// -1 because the last sequence is dropped as there is no next token label
const expectedLength = Math.ceil(expectedTokens.size / blockSize) - 1
expect(await parsed.size()).to.equal(expectedLength);

// check the last sequence separately
let sequences = List(await arrayFromAsync(parsed))
// thanks to the overlap, the last sequence also has blockSize + 1 tokens
expect(sequences.last()?.size).to.equal(blockSize + 1)
sequences = sequences.pop()
let i = 0
for await (const tokens of sequences) {
// each sequence has length blockSize + 1 (for the label)
expect(tokens.toArray()).to.deep.equal(
expectedTokens.slice(i, i + blockSize + 1).toArray()
);
// but the window should move by blockSize only
i += blockSize
}
})

it("repeats content infinitely", async () => {
const dataset = new Dataset([0, 1, 2]).repeat();
const iter = dataset[Symbol.asyncIterator]()

for (const i of Range(0, 10)) {
const e = await iter.next()
expect(e.done).to.be.false
expect(e.value).to.equal(i % 3)
}
});

it("repeats content a fixed number of times", async () => {
const dataset = new Dataset([0, 1]).repeat(3);
expect([0,1,0,1,0,1]).to.deep.equal(await arrayFromAsync(dataset))
});
});
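Written out, the windows that the overlap test above expects (Dataset and Range as imported in the spec):

const windows = new Dataset([Range(0, 53).toList()])
  .unbatch()
  .batchWithOverlap(4);
// yields [0,1,2,3,4], [4,5,6,7,8], ..., [48,49,50,51,52]:
// 13 windows of blockSize + 1 = 5 tokens, each advancing by blockSize = 4,
// so token 52 only ever appears as the label of the final window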
57 changes: 57 additions & 0 deletions discojs/src/dataset/dataset.ts
@@ -144,6 +144,46 @@ export class Dataset<T> implements AsyncIterable<T> {
);
}

/**
* Create batches of size `size + 1` which overlap on one element:
* the last element of one batch is the same as the first element of the next
* Notes:
* - The resulting dataset has a batch size of `size + 1`
* - The last batch is dropped as there is no next element to append.
*
* This method is tailored to create text sequences where each token's label is the following token.
* In order to have a label for the last token of the input sequence, we include the first token
* of the next sequence.
*
* @param size batch size excluding the overlapping element, at least 1
* @returns a dataset with batch size `size + 1`
*/
batchWithOverlap(size: number): Dataset<Batched<T>> {
if (size <= 0 || !Number.isInteger(size)) throw new Error("invalid size");

return new Dataset(
async function* (this: Dataset<T>) {
const iter = this.batch(size)[Symbol.asyncIterator]();
// get the first batch
const firstRes = await iter.next()
if (firstRes.done) return
let currentBatch = firstRes.value
for (;;) {
// get the next batch
const res = await iter.next()
if (res.done) break;
const nextBatch = res.value
// get the first element of the next batch
const nextFirstElement = nextBatch.first()
if (nextFirstElement === undefined) break
// yield the current batch with the first element of the next batch
yield currentBatch.concat(nextFirstElement);
currentBatch = nextBatch
}
}.bind(this),
);
}

/** Flatten chunks */
unbatch<U>(this: Dataset<Batched<U>>): Dataset<U> {
return new Dataset(
@@ -176,6 +216,23 @@ export class Dataset<T> implements AsyncIterable<T> {
);
}

/**
* Repeat the dataset `times` times
* @param times number of times to repeat the dataset; if undefined, the dataset is repeated indefinitely
* @returns a dataset repeated `times` times
*/
repeat(times?: number): Dataset<T> {
return new Dataset(
async function* (this: Dataset<T>) {
let loop = 0;
do {
for await (const e of this) yield e;
loop++
} while (times === undefined || loop < times)
}.bind(this),
);
}

/** Compute size
*
* This is a costly operation as we need to go through the whole Dataset.
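Note the ordering chosen in train_gpt.ts above: repeat() is applied before batch(), so batches may span epoch boundaries and always come out full. A two-line sketch of repeat's modes, matching the spec:

const threeTimes = new Dataset([0, 1]).repeat(3); // yields 0, 1, 0, 1, 0, 1
const forever = new Dataset([0, 1, 2]).repeat();  // cycles 0, 1, 2, 0, 1, ... indefinitely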
1 change: 1 addition & 0 deletions discojs/src/dataset/types.ts
@@ -7,3 +7,4 @@ export type Batched<T> = List<T>;
export { Image };
export type Tabular = Partial<Record<string, string>>;
export type Text = string;
export type TokenizedText = List<number>;