Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(chat): add chat support to the client #57

Merged
merged 3 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions examples/chat.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { Client } from '../src/index.js';

const client = new Client({
apiKey: process.env.GENAI_API_KEY,
});

const model_id = 'google/ul2';

{
// Start a conversation
const {
conversation_id,
result: { generated_text: answer1 },
} = await client.chat({
model_id,
messages: [
{
role: 'system',
content: 'Answer yes or no',
},
{
role: 'user',
content: 'Hello, are you a robot?',
},
],
});
console.log(answer1);

// Continue the conversation
const {
result: { generated_text: answer2 },
} = await client.chat({
conversation_id,
model_id,
messages: [
{
role: 'user',
content: 'Are you sure?',
},
],
});
console.log(answer2);
}

{
// Chat inteface has the same promise, streaming and callback variants as generate interface

// Promise
const data = await client.chat({
model_id,
messages: [{ role: 'user', content: 'How are you?' }],
});
console.log(data.result.generated_text);
// Callback
client.chat(
{ model_id, messages: [{ role: 'user', content: 'How are you?' }] },
(err, data) => {
if (err) console.error(err);
else console.log(data.result.generated_text);
},
);
// Stream
for await (const chunk of client.chat(
{ model_id, messages: [{ role: 'user', content: 'How are you?' }] },
{ stream: true },
)) {
console.log(chunk.result.generated_text);
}
// Streaming callbacks
client.chat(
{
model_id: 'google/ul2',
messages: [{ role: 'user', content: 'How are you?' }],
},
{ stream: true },
(err, data) => {
if (err) console.error(err);
else if (data) console.log(data.result.generated_text);
else console.log('EOS');
},
);
}
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@
"example:generate": "yarn run example:run examples/generate.ts",
"example:tune": "yarn run example:run examples/tune.ts",
"example:prompt-template": "yarn run example:run examples/prompt-templates.ts",
"example:file": "yarn run example:run examples/file.ts"
"example:file": "yarn run example:run examples/file.ts",
"example:chat": "yarn run example:run examples/chat.ts"
},
"peerDependencies": {
"langchain": ">=0.0.155"
Expand Down
40 changes: 39 additions & 1 deletion src/api-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ export interface UserGenerateDefaultOutput {

// GENERATE

const ParametersSchema = z.record(z.any());

export const GenerateInputSchema = z.object({
model_id: z.string().nullish(),
prompt_id: z.string().nullish(),
inputs: z.array(z.string()),
parameters: z.optional(z.record(z.any())),
parameters: z.optional(ParametersSchema),
use_default: z.optional(z.boolean()),
});
export type GenerateInput = z.infer<typeof GenerateInputSchema>;
Expand Down Expand Up @@ -399,3 +401,39 @@ export const FilesOutputSchema = PaginationOutputSchema.extend({
results: z.array(SingleFileOutputSchema),
});
export type FilesOutput = z.output<typeof FilesOutputSchema>;

// CHAT

export const ChatRoleSchema = z.enum(['user', 'system', 'assistant']);
export type ChatRole = z.infer<typeof ChatRoleSchema>;

export const ChatInputSchema = z.object({
model_id: z.string(),
messages: z.array(
z.object({
role: ChatRoleSchema,
content: z.string(),
}),
),
conversation_id: z.string().nullish(),
parent_id: z.string().nullish(),
prompt_id: z.string().nullish(),
parameters: ParametersSchema.nullish(),
});
export type ChatInput = z.input<typeof ChatInputSchema>;
export const ChatOutputSchema = z.object({
conversation_id: z.string(),
results: z.array(
z
.object({
generated_text: z.string(),
})
.partial(),
),
});
export type ChatOutput = z.output<typeof ChatOutputSchema>;

export const ChatStreamInputSchema = ChatInputSchema;
export type ChatStreamInput = z.input<typeof ChatStreamInputSchema>;
export const ChatStreamOutputSchema = ChatOutputSchema;
export type ChatStreamOutput = z.output<typeof ChatStreamOutputSchema>;
128 changes: 111 additions & 17 deletions src/client/client.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import http, { IncomingMessage } from 'node:http';
import https from 'node:https';
import { Transform, TransformCallback } from 'stream';
import { Transform, TransformCallback } from 'node:stream';

import axios, { AxiosError } from 'axios';
import FormData from 'form-data';
Expand Down Expand Up @@ -36,10 +36,13 @@ import {
handleGenerator,
paginator,
isEmptyObject,
callbackifyStream,
callbackifyPromise,
} from '../helpers/common.js';
import { TypedReadable } from '../utils/stream.js';
import { lookupApiKey, lookupEndpoint } from '../helpers/config.js';
import { RETRY_ATTEMPTS_DEFAULT } from '../constants.js';
import { Callback } from '../helpers/types.js';

import {
GenerateConfigInput,
Expand Down Expand Up @@ -92,6 +95,12 @@ import {
FilesInput,
FileDeleteOutput,
PromptTemplateDeleteOutput,
ChatInput,
ChatOutput,
ChatOptions,
ChatStreamOptions,
ChatStreamInput,
ChatStreamOutput,
} from './types.js';
import { CacheDiscriminator, generateCacheKey } from './cache.js';

Expand All @@ -116,10 +125,6 @@ export interface Configuration {
retries?: HttpHandlerOptions['retries'];
}

type ErrorCallback = (err: unknown) => void;
type DataCallback<T> = (err: unknown, result: T) => void;
export type Callback<T> = ErrorCallback | DataCallback<T>;

export class Client {
readonly #client: AxiosCacheInstance;
readonly #options: Required<Configuration>;
Expand Down Expand Up @@ -484,12 +489,7 @@ export class Client {
return stream;
}

stream.on('data', (data) => callback(null, data));
stream.on('error', (err) => (callback as ErrorCallback)(err));
stream.on('finish', () =>
(callback as DataCallback<GenerateOutput | null>)(null, null),
);

callbackifyStream<GenerateOutput>(stream)(callback);
return;
}

Expand Down Expand Up @@ -549,12 +549,7 @@ export class Client {
});

if (callback) {
promises.forEach((promise) =>
promise.then(
(data) => callback(null as never, data),
(err) => (callback as ErrorCallback)(err),
),
);
promises.forEach((promise) => callbackifyPromise(promise)(callback));
} else {
return Array.isArray(input) ? promises : promises[0];
}
Expand Down Expand Up @@ -1320,4 +1315,103 @@ export class Client {
return transformOutput(result);
});
}

chat(input: ChatInput, callback: Callback<ChatOutput>): void;
chat(
input: ChatInput,
options: ChatOptions,
callback: Callback<ChatOutput>,
): void;
chat(
input: ChatStreamInput,
options: ChatStreamOptions,
callback: Callback<ChatStreamOutput | null>,
): void;
chat(input: ChatInput, options?: ChatOptions): Promise<ChatOutput>;
chat(
input: ChatStreamInput,
options?: ChatStreamOptions,
): TypedReadable<ChatStreamOutput>;
chat(
input: ChatInput | ChatStreamInput,
optionsOrCallback?:
| ChatOptions
| ChatStreamOptions
| Callback<ChatOutput>
| Callback<ChatStreamOutput>,
callback?: Callback<ChatOutput>,
): TypedReadable<ChatStreamOutput> | Promise<ChatOutput> | void {
const { callback: cb, options } = parseFunctionOverloads(
undefined,
optionsOrCallback,
callback,
);

if (options?.stream) {
const stream = new Transform({
autoDestroy: true,
objectMode: true,
transform(
chunk: ApiTypes.ChatStreamOutput,
encoding: BufferEncoding,
callback: TransformCallback,
) {
const { results, ...rest } = chunk;
callback(null, {
...rest,
result: results[0],
} as ChatStreamOutput);
},
});
this.#fetcher<ApiTypes.ChatStreamOutput, ApiTypes.ChatStreamInput>({
...options,
method: 'POST',
url: '/v0/generate/chat',
data: {
...input,
parameters: {
...input.parameters,
stream: true,
},
},
stream: true,
})
.on('error', (err) => stream.emit('error', errorTransformer(err)))
.pipe(stream);

if (cb) {
callbackifyStream<ChatStreamOutput>(stream)(cb);
return;
} else {
return stream;
}
} else {
const promise = (async () => {
const { results, ...rest } = await this.#fetcher<
ApiTypes.ChatOutput,
ApiTypes.ChatInput
>(
{
...options,
method: 'POST',
url: '/v0/generate/chat',
data: input,
stream: false,
},
ApiTypes.ChatOutputSchema,
);
if (results.length !== 1) {
throw new InternalError('Unexpected number of results');
}
return { ...rest, result: results[0] };
})();

if (cb) {
callbackifyPromise(promise)(cb);
return;
} else {
return promise;
}
}
}
}
21 changes: 21 additions & 0 deletions src/client/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,24 @@ export type FileDeleteOptions = HttpHandlerOptions & FlagOption<'delete', true>;
export const FilesOutputSchema =
ApiTypes.FilesOutputSchema.shape.results.element;
export type FilesOutput = z.output<typeof FilesOutputSchema>;

// CHAT

export const ChatInputSchema = z.union([
ApiTypes.ChatInputSchema,
ApiTypes.ChatStreamInputSchema,
]);
export type ChatInput = z.input<typeof ChatInputSchema>;
export type ChatOptions = HttpHandlerNoStreamOptions;
export const ChatOutputSchema = ApiTypes.ChatOutputSchema.omit({
results: true,
}).extend({ result: ApiTypes.ChatOutputSchema.shape.results.element });
export type ChatOutput = z.output<typeof ChatOutputSchema>;

export const ChatStreamInputSchema = ApiTypes.ChatStreamInputSchema;
export type ChatStreamInput = z.input<typeof ChatStreamInputSchema>;
export type ChatStreamOptions = HttpHandlerStreamOptions;
export const ChatStreamOutputSchema = ApiTypes.ChatStreamOutputSchema.omit({
results: true,
}).extend({ result: ApiTypes.ChatOutputSchema.shape.results.element });
export type ChatStreamOutput = z.output<typeof ChatStreamOutputSchema>;
23 changes: 21 additions & 2 deletions src/helpers/common.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { callbackify } from 'node:util';
import { URLSearchParams } from 'node:url';
import { Readable } from 'node:stream';

import { z } from 'zod';

export type FalsyValues = false | '' | 0 | null | undefined;
export type Truthy<T> = T extends FalsyValues ? never : T;
import { ErrorCallback, DataCallback, Truthy, Callback } from './types.js';

export function isTruthy<T>(value: T): value is Truthy<T> {
return Boolean(value);
Expand Down Expand Up @@ -151,6 +151,25 @@ export function callbackifyGenerator<T>(generatorFn: () => AsyncGenerator<T>) {
};
}

export function callbackifyStream<T>(stream: Readable) {
return (callbackFn: Callback<T>) => {
stream.on('data', (data) => callbackFn(null, data));
stream.on('error', (err) => (callbackFn as ErrorCallback)(err));
stream.on('finish', () =>
(callbackFn as DataCallback<T | null>)(null, null),
);
};
}

export function callbackifyPromise<T>(promise: Promise<T>) {
return (callbackFn: Callback<T>) => {
promise.then(
(data) => callbackFn(null, data),
(err) => (callbackFn as ErrorCallback)(err),
);
};
}

export async function* paginator<T>(
executor: (searchParams: URLSearchParams) => Promise<{
results: T[];
Expand Down
Loading