feat(memory): add sync mechanism to the TokenMemory
Tomas2D committed Sep 6, 2024
1 parent b6e4b02 commit 63975c9
Showing 2 changed files with 125 additions and 17 deletions.
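
As the diff below shows, the commit makes TokenMemory estimate token counts locally when a message is added and defer exact tokenization to a sync() call: each message is stored with a dirty flag, and once the estimated usage reaches syncThreshold × maxTokens the memory re-tokenizes dirty messages through the LLM. A minimal usage sketch follows; the import paths, the declared llm, and the option values are assumptions drawn from this diff, not documented API.

// Sketch only: import paths, the declared llm, and option values are assumptions.
import { TokenMemory } from "bee-agent-framework/memory/tokenMemory.js";
import { BaseMessage, Role } from "bee-agent-framework/llms/primitives/message.js";
import type { ChatLLM, ChatLLMOutput } from "bee-agent-framework/llms/chat.js";

declare const llm: ChatLLM<ChatLLMOutput>; // any chat LLM, e.g. the BAMChatLLM used in the tests

const memory = new TokenMemory({
  llm,
  maxTokens: 1000,
  // run sync() automatically once locally estimated usage reaches 50% of maxTokens
  syncThreshold: 0.5,
  handlers: {
    // cheap local estimate used for newly added messages until the next sync
    estimate: (msg) => Math.ceil((msg.role.length + msg.text.length) / 4),
  },
});

await memory.add(BaseMessage.of({ role: Role.USER, text: "Hello!" }));
console.log(memory.stats()); // { tokensUsed, maxTokens, messagesCount, isDirty }

await memory.sync(); // replaces estimates with exact counts from llm.tokenize
console.log(memory.stats()); // isDirty is now false

The "Auto sync" test below exercises the same mechanism with small numbers: four one-character messages estimated at 1 token each against maxTokens 4 and syncThreshold 0.5 trigger an automatic sync, the mocked LLM then counts 2 tokens per message, so only two messages fit and stats() reports tokensUsed 4, messagesCount 2, isDirty false.
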
80 changes: 75 additions & 5 deletions src/memory/tokenMemory.test.ts
@@ -23,7 +23,12 @@ import * as R from "remeda";
import { verifyDeserialization } from "@tests/e2e/utils.js";

describe("Token Memory", () => {
const getInstance = () => {
const getInstance = (config: {
llmFactor: number;
localFactor: number;
syncThreshold: number;
maxTokens: number;
}) => {
const llm = new BAMChatLLM({
llm: new BAMLLM({
client: new Client(),
@@ -35,22 +40,87 @@ describe("Token Memory", () => {
},
});

const estimateLLM = (msg: BaseMessage) => Math.ceil(msg.text.length * config.llmFactor);
const estimateLocal = (msg: BaseMessage) => Math.ceil(msg.text.length * config.localFactor);

vi.spyOn(llm, "tokenize").mockImplementation(async (messages: BaseMessage[]) => ({
tokensCount: R.sum(messages.map((msg) => [msg.role, msg.text].join("").length)),
tokensCount: R.sum(messages.map(estimateLLM)),
}));

return new TokenMemory({
llm,
maxTokens: 1000,
maxTokens: config.maxTokens,
syncThreshold: config.syncThreshold,
handlers: {
estimate: estimateLocal,
},
});
};

it("Auto sync", async () => {
const instance = getInstance({
llmFactor: 2,
localFactor: 1,
maxTokens: 4,
syncThreshold: 0.5,
});
await instance.addMany([
BaseMessage.of({ role: Role.USER, text: "A" }),
BaseMessage.of({ role: Role.USER, text: "B" }),
BaseMessage.of({ role: Role.USER, text: "C" }),
BaseMessage.of({ role: Role.USER, text: "D" }),
]);
expect(instance.stats()).toMatchObject({
isDirty: false,
tokensUsed: 4,
messagesCount: 2,
});
});

it("Synchronizes", async () => {
const instance = getInstance({
llmFactor: 2,
localFactor: 1,
maxTokens: 10,
syncThreshold: 1,
});
expect(instance.stats()).toMatchObject({
isDirty: false,
tokensUsed: 0,
messagesCount: 0,
});
await instance.addMany([
BaseMessage.of({ role: Role.USER, text: "A" }),
BaseMessage.of({ role: Role.USER, text: "B" }),
BaseMessage.of({ role: Role.USER, text: "C" }),
BaseMessage.of({ role: Role.USER, text: "D" }),
BaseMessage.of({ role: Role.USER, text: "E" }),
BaseMessage.of({ role: Role.USER, text: "F" }),
]);
expect(instance.stats()).toMatchObject({
isDirty: true,
tokensUsed: 6,
messagesCount: 6,
});
await instance.sync();
expect(instance.stats()).toMatchObject({
isDirty: false,
tokensUsed: 10,
messagesCount: 5,
});
});

it("Serializes", async () => {
vi.stubEnv("GENAI_API_KEY", "123");
const instance = getInstance();
const instance = getInstance({
llmFactor: 2,
localFactor: 1,
maxTokens: 10,
syncThreshold: 1,
});
await instance.add(
BaseMessage.of({
text: "I am a Batman!",
text: "Hello!",
role: Role.USER,
}),
);
62 changes: 50 additions & 12 deletions src/memory/tokenMemory.ts
@@ -20,35 +20,46 @@ import { ChatLLM, ChatLLMOutput } from "@/llms/chat.js";
import * as R from "remeda";
import { shallowCopy } from "@/serializer/utils.js";
import { removeFromArray } from "@/internals/helpers/array.js";
import { sum } from "remeda";

export interface Handlers {
estimate: (messages: BaseMessage) => number;
removalSelector: (messages: BaseMessage[]) => BaseMessage;
}

export interface TokenMemoryInput {
llm: ChatLLM<ChatLLMOutput>;
maxTokens?: number;
syncThreshold?: number;
capacityThreshold?: number;
handlers?: Handlers;
handlers?: Partial<Handlers>;
}

interface TokenByMessage {
tokensCount: number;
dirty: boolean;
}

export class TokenMemory extends BaseMemory {
public readonly messages: BaseMessage[] = [];

protected llm: ChatLLM<ChatLLMOutput>;
protected threshold = 1;
protected threshold;
protected syncThreshold;
protected maxTokens: number | null = null;
protected tokensUsed = 0;
protected tokensByMessage = new WeakMap<BaseMessage, number>();
protected tokensByMessage = new WeakMap<BaseMessage, TokenByMessage>();
public readonly handlers: Handlers;

constructor(config: TokenMemoryInput) {
super();
this.llm = config.llm;
this.maxTokens = config.maxTokens ?? null;
this.threshold = config.capacityThreshold ?? 0.75;
this.syncThreshold = config.syncThreshold ?? 0.25;
this.handlers = {
...config?.handlers,
estimate:
config?.handlers?.estimate || ((msg) => Math.ceil((msg.role.length + msg.text.length) / 4)),
removalSelector: config.handlers?.removalSelector || ((messages) => messages[0]),
};
if (!R.clamp({ min: 0, max: 1 })(this.threshold)) {
@@ -60,13 +71,24 @@ export class TokenMemory extends BaseMemory {
this.register();
}

get tokensUsed(): number {
return sum(this.messages.map((msg) => this.tokensByMessage.get(msg)!.tokensCount!));
}

get isDirty(): boolean {
return this.messages.some((msg) => this.tokensByMessage.get(msg)?.dirty !== false);
}

async add(message: BaseMessage) {
if (this.maxTokens === null) {
const meta = await this.llm.meta();
this.maxTokens = Math.ceil((meta.tokenLimit ?? Infinity) * this.threshold);
}

const meta = await this.llm.tokenize([message]);
const meta = this.tokensByMessage.has(message)
? this.tokensByMessage.get(message)!
: { tokensCount: this.handlers.estimate(message), dirty: true };

if (meta.tokensCount > this.maxTokens) {
throw new MemoryFatalError(
`Retrieved message (${meta.tokensCount} tokens) cannot fit inside current memory (${this.maxTokens} tokens)`,
@@ -80,38 +102,54 @@
if (!messageToDelete || !exists) {
throw new MemoryFatalError('The "removalSelector" handler must return a valid message!');
}

const tokensCount = this.tokensByMessage.get(messageToDelete) ?? 0;
this.tokensUsed -= tokensCount;
}

this.tokensUsed += meta.tokensCount;
this.tokensByMessage.set(message, meta.tokensCount);
this.tokensByMessage.set(message, meta);
this.messages.push(message);

if (this.isDirty && this.tokensUsed / this.maxTokens >= this.syncThreshold) {
await this.sync();
}
}

async sync() {
const messages = await Promise.all(
this.messages.map(async (msg) => {
const cache = this.tokensByMessage.get(msg);
if (cache?.dirty !== false) {
const result = await this.llm.tokenize([msg]);
this.tokensByMessage.set(msg, { tokensCount: result.tokensCount, dirty: false });
}
return msg;
}),
);

this.messages.length = 0;
await this.addMany(messages);
}

reset() {
for (const msg of this.messages) {
this.tokensByMessage.delete(msg);
}
this.messages.length = 0;
this.tokensUsed = 0;
}

stats() {
return {
tokensUsed: this.tokensUsed,
maxTokens: this.maxTokens,
messagesCount: this.messages.length,
isDirty: this.isDirty,
};
}

createSnapshot() {
return {
tokensUsed: this.tokensUsed,
llm: this.llm,
maxTokens: this.maxTokens,
threshold: this.threshold,
syncThreshold: this.syncThreshold,
messages: shallowCopy(this.messages),
handlers: this.handlers,
tokensByMessage: this.messages
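
Because handlers is now typed as Partial&lt;Handlers&gt;, estimate and removalSelector can be overridden independently. As a hypothetical example not taken from this commit, a removalSelector that prefers to evict the oldest non-system message when the memory overflows; it reuses the llm and imports from the sketch above, and Role.SYSTEM is an assumption.

// Hypothetical override of the removal policy; Role.SYSTEM and llm are assumed.
const boundedMemory = new TokenMemory({
  llm,
  maxTokens: 512,
  handlers: {
    // on overflow, evict the oldest non-system message first,
    // falling back to the oldest message overall
    removalSelector: (messages) => messages.find((msg) => msg.role !== Role.SYSTEM) ?? messages[0],
  },
});
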
