From 799306f07865a24b5a16f1731adc40b0baab213a Mon Sep 17 00:00:00 2001 From: Jesse Drelick Date: Fri, 6 Sep 2024 22:32:03 -0400 Subject: [PATCH] finalize prompt --- lib/agens/message.ex | 7 +++---- lib/agens/serving.ex | 32 ++++++++++++++++++++++++++++++-- test/agens/message_test.exs | 13 ++++++++++--- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/lib/agens/message.ex b/lib/agens/message.ex index 13f1c8b..8059e66 100644 --- a/lib/agens/message.ex +++ b/lib/agens/message.ex @@ -55,10 +55,9 @@ defmodule Agens.Message do def send(%__MODULE__{} = message) do with {:ok, agent_config} <- maybe_get_agent_config(message.agent_name), - {:ok, serving_config} <- get_serving_config(agent_config, message) do - base = build_prompt(agent_config, message, serving_config.prefixes) - prompt = "[INST]#{base}[/INST]" - + {:ok, serving_config} <- get_serving_config(agent_config, message), + base <- build_prompt(agent_config, message, serving_config.prefixes), + {:ok, prompt} <- Serving.finalize(serving_config.name, base) do message = message |> Map.put(:prompt, prompt) diff --git a/lib/agens/serving.ex b/lib/agens/serving.ex index 6ba3142..6d439e5 100644 --- a/lib/agens/serving.ex +++ b/lib/agens/serving.ex @@ -25,11 +25,12 @@ defmodule Agens.Serving do name: atom(), serving: Nx.Serving.t() | module(), args: keyword(), - prefixes: Agens.Prefixes.t() | nil + prefixes: Agens.Prefixes.t() | nil, + finalize: {module(), atom()} | nil } @enforce_keys [:name, :serving] - defstruct [:name, :serving, :prefixes, args: []] + defstruct [:name, :serving, :prefixes, :finalize, args: []] end defmodule State do @@ -97,6 +98,19 @@ defmodule Agens.Serving do end) end + @doc false + @spec finalize(atom() | pid(), String.t()) :: {:ok, String.t()} | {:error, :serving_not_found} + def finalize(name, prompt) when is_atom(name) do + name + |> Agens.name_to_pid({:error, :serving_not_found}, fn pid -> + finalize(pid, prompt) + end) + end + + def finalize(pid, prompt) when is_pid(pid) do + {:ok, GenServer.call(pid, {:finalize, prompt})} + end + # =========================================================================== # Setup # =========================================================================== @@ -164,6 +178,20 @@ defmodule Agens.Serving do {:reply, result, state} end + @doc false + @impl true + @spec handle_call({:finalize, String.t()}, {pid, term}, State.t()) :: + {:reply, String.t(), State.t()} + def handle_call({:finalize, prompt}, _, %State{config: %Config{finalize: finalize}} = state) do + final = + case finalize do + {module, function} -> apply(module, function, [prompt]) + _ -> prompt + end + + {:reply, final, state} + end + # =========================================================================== # Private # =========================================================================== diff --git a/test/agens/message_test.exs b/test/agens/message_test.exs index 08c7f75..7e77f9e 100644 --- a/test/agens/message_test.exs +++ b/test/agens/message_test.exs @@ -3,6 +3,8 @@ defmodule Agens.MessageTest do alias Agens.Message + def wrap_prompt(prompt), do: "[INST]#{prompt}[/INST]" + defp start_agens(_ctx) do {:ok, _pid} = start_supervised({Agens.Supervisor, name: Agens.Supervisor}) :ok @@ -11,7 +13,8 @@ defmodule Agens.MessageTest do defp start_serving(_ctx) do %Agens.Serving.Config{ name: :text_generation, - serving: Test.Support.Serving.Stub + serving: Test.Support.Serving.Stub, + finalize: {__MODULE__, :wrap_prompt} } |> Agens.Serving.start() @@ -30,12 +33,16 @@ defmodule Agens.MessageTest do test "works with explicit serving" do serving_name = :text_generation + wrapped = + wrap_prompt( + "## Input\nThe following is the actual input from the user, system or another agent: test\n" + ) + assert %Message{ input: "test", serving_name: serving_name, result: "sent 'test' to: ", - prompt: - "[INST]## Input\nThe following is the actual input from the user, system or another agent: test\n[/INST]" + prompt: wrapped } == Message.send(%Message{serving_name: serving_name, input: "test"}) end end