feat(langchain): update langchain chat model (#82)
Signed-off-by: Tomas Pilar <tomas.pilar@ibm.com>
pilartomas authored Mar 6, 2024
1 parent 9d32dc9 commit 238fef9
Showing 5 changed files with 259 additions and 132 deletions.
15 changes: 3 additions & 12 deletions examples/langchain/llm-chat.ts
@@ -2,10 +2,9 @@ import { HumanMessage } from '@langchain/core/messages';

import { GenAIChatModel } from '../../src/langchain/llm-chat.js';

const makeClient = (stream?: boolean) =>
const makeClient = () =>
new GenAIChatModel({
modelId: 'meta-llama/llama-2-70b-chat',
stream,
model_id: 'meta-llama/llama-2-70b-chat',
configuration: {
endpoint: process.env.ENDPOINT,
apiKey: process.env.API_KEY,
@@ -16,14 +15,6 @@ const makeClient = (stream?: boolean) =>
max_new_tokens: 25,
repetition_penalty: 1.5,
},
rolesMapping: {
human: {
stopSequence: '<human>:',
},
system: {
stopSequence: '<bot>:',
},
},
});

{
@@ -41,7 +32,7 @@ const makeClient = (stream?: boolean) =>

{
// Streaming
const chat = makeClient(true);
const chat = makeClient();

await chat.invoke([new HumanMessage('Tell me a joke.')], {
callbacks: [
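
The reworked example collapses the old stream flag and rolesMapping options into a single client. A minimal consolidated sketch of the file as it reads after this commit; the simple-invocation block and the per-token callback body are not visible in the truncated diff above, so those parts are assumptions:

import { HumanMessage } from '@langchain/core/messages';

import { GenAIChatModel } from '../../src/langchain/llm-chat.js';

// ENDPOINT and API_KEY are read from the environment, as in the example file.
const makeClient = () =>
  new GenAIChatModel({
    model_id: 'meta-llama/llama-2-70b-chat',
    configuration: {
      endpoint: process.env.ENDPOINT,
      apiKey: process.env.API_KEY,
    },
    parameters: {
      max_new_tokens: 25,
      repetition_penalty: 1.5,
    },
  });

{
  // Simple invocation (this block is collapsed in the diff above; its body
  // here is an assumption).
  const chat = makeClient();
  const response = await chat.invoke([new HumanMessage('Tell me a joke.')]);
  console.log(response.content);
}

{
  // Streaming: the same client serves both cases now that the separate
  // `stream` constructor flag is gone. The handler body is an assumption,
  // since the diff is cut off at `callbacks: [`.
  const chat = makeClient();
  await chat.invoke([new HumanMessage('Tell me a joke.')], {
    callbacks: [
      {
        handleLLMNewToken(token) {
          console.log(token);
        },
      },
    ],
  });
}
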
4 changes: 2 additions & 2 deletions package.json
@@ -67,12 +67,12 @@
"example:models": "yarn run example:run examples/models.ts"
},
"peerDependencies": {
"@langchain/core": ">=0.1.11"
"@langchain/core": ">=0.1.0"
},
"devDependencies": {
"@commitlint/cli": "^18.0.0",
"@commitlint/config-conventional": "^18.0.0",
"@langchain/core": "^0.1.11",
"@langchain/core": "^0.1.0",
"@types/lodash": "^4.14.200",
"@types/node": "^20.11.19",
"@typescript-eslint/eslint-plugin": "^6.9.0",
266 changes: 185 additions & 81 deletions src/langchain/llm-chat.ts
@@ -3,117 +3,221 @@ import {
BaseChatModelParams,
} from '@langchain/core/language_models/chat_models';
import {
AIMessage,
AIMessageChunk,
BaseMessage,
MessageType,
SystemMessage,
} from '@langchain/core/messages';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { ChatResult } from '@langchain/core/outputs';
import { ChatGenerationChunk, ChatResult } from '@langchain/core/outputs';
import { BaseLanguageModelCallOptions as BaseChatModelCallOptions } from '@langchain/core/language_models/base';
import merge from 'lodash/merge.js';

import { InvalidInputError } from '../errors.js';
import { concatUnique } from '../helpers/common.js';
import type { RequiredPartial } from '../helpers/types.js';
import { TextGenerationCreateOutput } from '../schema.js';
import { Client, Configuration } from '../client.js';
import { TextChatCreateInput, TextChatCreateStreamInput } from '../schema.js';
import { InternalError, InvalidInputError } from '../errors.js';

import { GenAIModel, GenAIModelOptions } from './llm.js';
type TextChatInput = TextChatCreateInput & TextChatCreateStreamInput;

export type RolesMapping = RequiredPartial<
Record<
MessageType,
{
stopSequence: string;
}
>,
'system'
>;

type Options = BaseChatModelParams &
GenAIModelOptions & {
rolesMapping: RolesMapping;
export type GenAIChatModelParams = BaseChatModelParams &
Omit<TextChatInput, 'messages' | 'prompt_template_id'> & {
model_id: NonNullable<TextChatInput['model_id']>;
configuration?: Configuration;
};
export type GenAIChatModelOptions = BaseChatModelCallOptions &
Partial<Omit<GenAIChatModelParams, 'configuration'>>;

export class GenAIChatModel extends BaseChatModel {
readonly #model: GenAIModel;
readonly #rolesMapping: RolesMapping;
export class GenAIChatModel extends BaseChatModel<GenAIChatModelOptions> {
protected readonly client: Client;

constructor(options: Options) {
public readonly modelId: GenAIChatModelParams['model_id'];
public readonly promptId: GenAIChatModelParams['prompt_id'];
public readonly conversationId: GenAIChatModelParams['conversation_id'];
public readonly parameters: GenAIChatModelParams['parameters'];
public readonly moderations: GenAIChatModelParams['moderations'];
public readonly useConversationParameters: GenAIChatModelParams['use_conversation_parameters'];
public readonly parentId: GenAIChatModelParams['parent_id'];
public readonly trimMethod: GenAIChatModelParams['trim_method'];

constructor({
model_id,
prompt_id,
conversation_id,
parameters,
moderations,
parent_id,
use_conversation_parameters,
trim_method,
configuration,
...options
}: GenAIChatModelParams) {
super(options);

this.#rolesMapping = options.rolesMapping;
this.modelId = model_id;
this.promptId = prompt_id;
this.conversationId = conversation_id;
this.parameters = parameters;
this.moderations = moderations;
this.parentId = parent_id;
this.useConversationParameters = use_conversation_parameters;
this.trimMethod = trim_method;
this.client = new Client(configuration);
}

this.#model = new GenAIModel({
...options,
parameters: {
...options.parameters,
stop_sequences: concatUnique(
options.parameters?.stop_sequences,
Object.values(options.rolesMapping).map((role) => role.stopSequence),
),
async _generate(
messages: BaseMessage[],
options: this['ParsedCallOptions'],
_runManager?: CallbackManagerForLLMRun,
): Promise<ChatResult> {
const output = await this.client.text.chat.create(
{
...(this.conversationId
? { conversation_id: this.conversationId }
: { model_id: this.modelId, prompt_id: this.promptId }),
messages: this._convertMessages(messages),
parameters: merge(this.parameters, options.parameters),
},
configuration: {
...options.configuration,
// retries: options.maxRetries ?? options.configuration?.retries, TODO reintroduce when client has support
{ signal: options.signal },
);
if (output.results.length !== 1) throw new InternalError('Invalid result');
const result = output.results[0];
if (result.input_token_count == null)
throw new InternalError('Missing token count');
return {
generations: [
{
message: new AIMessage({ content: result.generated_text }),
text: result.generated_text,
generationInfo: {
conversationId: output.conversation_id,
inputTokens: result.input_tokens,
generatedTokens: result.generated_tokens,
seed: result.seed,
stopReason: result.stop_reason,
stopSequence: result.stop_sequence,
moderation: result.moderation,
},
},
],
llmOutput: {
tokenUsage: {
completionTokens: result.generated_token_count,
promptTokens: result.input_token_count,
totalTokens: result.generated_token_count + result.input_token_count,
},
},
});
};
}

async _generate(
async *_streamResponseChunks(
messages: BaseMessage[],
options: this['ParsedCallOptions'],
runManager?: CallbackManagerForLLMRun,
): Promise<ChatResult> {
const message = messages
.map((msg) => {
const type = this.#rolesMapping[msg._getType()];
if (!type) {
throw new InvalidInputError(
`Unsupported message type "${msg._getType()}"`,
);
_runManager?: CallbackManagerForLLMRun,
): AsyncGenerator<ChatGenerationChunk> {
const outputStream = await this.client.text.chat.create_stream(
GenAIChatModel._prepareRequest(
merge(
{
conversation_id: this.conversationId,
model_id: this.modelId,
prompt_id: this.promptId,
messages: this._convertMessages(messages),
moderations: this.moderations,
parameters: this.parameters,
use_conversation_parameters: this.useConversationParameters,
parent_id: this.parentId,
trim_method: this.trimMethod,
},
options,
),
),
{ signal: options.signal },
);
for await (const output of outputStream) {
if (output.results) {
for (const result of output.results) {
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: result.generated_text,
}),
text: result.generated_text,
generationInfo: {
conversationId: output.conversation_id,
inputTokens: result.input_tokens,
generatedTokens: result.generated_tokens,
seed: result.seed,
stopReason: result.stop_reason,
stopSequence: result.stop_sequence,
},
});
await _runManager?.handleText(result.generated_text);
}
return `${type.stopSequence}${msg.content}`;
})
.join('\n')
.concat(this.#rolesMapping.system.stopSequence);

const output = await this.#model._generate([message], options, runManager);
}
if (output.moderation) {
yield new ChatGenerationChunk({
message: new AIMessageChunk({
content: '',
}),
text: '',
generationInfo: {
conversationId: output.conversation_id,
moderation: output.moderation,
},
});
await _runManager?.handleText('');
}
}
}

private static _prepareRequest(
request: TextChatCreateInput & TextChatCreateStreamInput,
) {
const {
conversation_id,
model_id,
prompt_id,
use_conversation_parameters,
parameters,
...rest
} = request;
return {
generations: output.generations.map(([generation]) => ({
message: new SystemMessage(generation.text),
generationInfo: generation.generationInfo,
text: generation.text,
})),
llmOutput: output.llmOutput,
...(conversation_id
? { conversation_id }
: prompt_id
? { prompt_id }
: { model_id }),
...(use_conversation_parameters
? { use_conversation_parameters }
: { parameters }),
...rest,
};
}

_combineLLMOutput(...llmOutputs: TextGenerationCreateOutput[]) {
return llmOutputs
.flatMap((output) => output.results?.at(0) ?? [])
.reduce(
(acc, gen) => {
acc.tokenUsage.completionTokens += gen.generated_token_count || 0;
acc.tokenUsage.promptTokens += gen.input_token_count || 0;
acc.tokenUsage.totalTokens =
acc.tokenUsage.promptTokens + acc.tokenUsage.completionTokens;

return acc;
},
{
tokenUsage: {
completionTokens: 0,
promptTokens: 0,
totalTokens: 0,
},
},
);
private _convertMessages(
messages: BaseMessage[],
): TextChatCreateInput['messages'] & TextChatCreateStreamInput['messages'] {
return messages.map((message) => {
const content = message.content;
if (typeof content !== 'string')
throw new InvalidInputError('Multimodal messages are not supported.');
const type = message._getType();
switch (type) {
case 'system':
return { content, role: 'system' };
case 'human':
return { content, role: 'user' };
case 'ai':
return { content, role: 'assistant' };
default:
throw new InvalidInputError(`Unsupported message type "${type}"`);
}
});
}

_llmType(): string {
return 'GenAIChat';
}

_modelType(): string {
return this.#model._modelType();
return this.modelId;
}
}
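
Taken together, the rewritten class plugs the GenAI text chat endpoint into LangChain's standard chat-model interface: _generate converts messages via _convertMessages and calls client.text.chat.create, while _streamResponseChunks backs token streaming through the regular Runnable machinery. A rough usage sketch under those assumptions; the import path, prompt strings, and the max_new_tokens value are illustrative, not taken from the commit:

import {
  AIMessage,
  HumanMessage,
  SystemMessage,
} from '@langchain/core/messages';

// Illustrative relative path; adjust to wherever the package exposes the class.
import { GenAIChatModel } from './llm-chat.js';

const chat = new GenAIChatModel({
  model_id: 'meta-llama/llama-2-70b-chat',
  configuration: {
    endpoint: process.env.ENDPOINT,
    apiKey: process.env.API_KEY,
  },
  parameters: { max_new_tokens: 50 },
});

// _convertMessages maps LangChain message types to chat roles:
// SystemMessage -> "system", HumanMessage -> "user", AIMessage -> "assistant";
// any other type, or non-string (multimodal) content, throws InvalidInputError.
const messages = [
  new SystemMessage('You are a terse assistant.'),
  new HumanMessage('What is the capital of France?'),
  new AIMessage('Paris.'),
  new HumanMessage('And of Italy?'),
];

// Streaming now goes through the standard .stream() API rather than a
// constructor-level stream flag.
const stream = await chat.stream(messages);
for await (const chunk of stream) {
  process.stdout.write(typeof chunk.content === 'string' ? chunk.content : '');
}
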
