diff --git a/src/langchain/llm.ts b/src/langchain/llm.ts
index a52c778..02ff701 100644
--- a/src/langchain/llm.ts
+++ b/src/langchain/llm.ts
@@ -6,24 +6,29 @@
 import type { LLMResult, Generation } from 'langchain/schema';
 import type { GenerateOutput } from '../client-types.js';
 import { GenerateInput } from '../client-types.js';
-export interface GenAIModelOptions {
-  modelId?: string;
+interface BaseGenAIModelOptions {
   stream?: boolean;
   parameters?: Record<string, any>;
   timeout?: number;
   configuration?: Configuration;
 }
 
+export type GenAIModelOptions =
+  | (BaseGenAIModelOptions & { modelId?: string; promptId?: never })
+  | (BaseGenAIModelOptions & { modelId?: never; promptId: string });
+
 export class GenAIModel extends BaseLLM {
   #client: Client;
 
   protected modelId?: string;
+  protected promptId?: string;
   protected stream: boolean;
   protected timeout: number | undefined;
   protected parameters: Record<string, any>;
 
   constructor({
     modelId,
+    promptId,
     stream = false,
     parameters,
     timeout,
@@ -33,6 +38,7 @@ export class GenAIModel extends BaseLLM {
     super(baseParams ?? {});
 
     this.modelId = modelId;
+    this.promptId = promptId;
     this.timeout = timeout;
     this.parameters = parameters || {};
     this.stream = Boolean(stream);
@@ -46,9 +52,15 @@
     const stopSequences = concatUnique(this.parameters.stop, options.stop);
 
     return prompts.map((input) => ({
-      ...(!isNullish(this.modelId) && {
-        model_id: this.modelId,
-      }),
+      ...(!isNullish(this.promptId)
+        ? {
+            prompt_id: this.promptId,
+          }
+        : !isNullish(this.modelId)
+          ? {
+              model_id: this.modelId,
+            }
+          : {}),
       input,
       parameters: {
         ...this.parameters,
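
Note for reviewers: the `promptId?: never` / `modelId?: never` pair in the new GenAIModelOptions union is what makes the two options mutually exclusive at compile time. A minimal standalone sketch of the pattern, with illustrative names and values that are not part of this diff:

    interface BaseOptions {
      stream?: boolean;
    }

    // At most one of modelId / promptId may be present;
    // typing the other as `never` rejects objects that set both.
    type Options =
      | (BaseOptions & { modelId?: string; promptId?: never })
      | (BaseOptions & { modelId?: never; promptId: string });

    const byModel: Options = { modelId: 'google/flan-ul2' }; // ok
    const byPrompt: Options = { promptId: 'saved-prompt-id' }; // ok
    // const both: Options = { modelId: 'a', promptId: 'b' }; // compile error

At request-build time, the rewritten conditional spread then prefers `prompt_id` when a prompt is configured, falls back to `model_id`, and omits both keys otherwise.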