Skip to content

Commit

Permalink
finalize prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
jessedrelick committed Sep 7, 2024
1 parent 8fb7a8b commit 799306f
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 9 deletions.
7 changes: 3 additions & 4 deletions lib/agens/message.ex
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ defmodule Agens.Message do

def send(%__MODULE__{} = message) do
with {:ok, agent_config} <- maybe_get_agent_config(message.agent_name),
{:ok, serving_config} <- get_serving_config(agent_config, message) do
base = build_prompt(agent_config, message, serving_config.prefixes)
prompt = "<s>[INST]#{base}[/INST]"

{:ok, serving_config} <- get_serving_config(agent_config, message),
base <- build_prompt(agent_config, message, serving_config.prefixes),
{:ok, prompt} <- Serving.finalize(serving_config.name, base) do
message =
message
|> Map.put(:prompt, prompt)
Expand Down
32 changes: 30 additions & 2 deletions lib/agens/serving.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ defmodule Agens.Serving do
name: atom(),
serving: Nx.Serving.t() | module(),
args: keyword(),
prefixes: Agens.Prefixes.t() | nil
prefixes: Agens.Prefixes.t() | nil,
finalize: {module(), atom()} | nil
}

@enforce_keys [:name, :serving]
defstruct [:name, :serving, :prefixes, args: []]
defstruct [:name, :serving, :prefixes, :finalize, args: []]
end

defmodule State do
Expand Down Expand Up @@ -97,6 +98,19 @@ defmodule Agens.Serving do
end)
end

@doc false
@spec finalize(atom() | pid(), String.t()) :: {:ok, String.t()} | {:error, :serving_not_found}
def finalize(name, prompt) when is_atom(name) do
name
|> Agens.name_to_pid({:error, :serving_not_found}, fn pid ->
finalize(pid, prompt)
end)
end

def finalize(pid, prompt) when is_pid(pid) do
{:ok, GenServer.call(pid, {:finalize, prompt})}
end

# ===========================================================================
# Setup
# ===========================================================================
Expand Down Expand Up @@ -164,6 +178,20 @@ defmodule Agens.Serving do
{:reply, result, state}
end

@doc false
@impl true
@spec handle_call({:finalize, String.t()}, {pid, term}, State.t()) ::
{:reply, String.t(), State.t()}
def handle_call({:finalize, prompt}, _, %State{config: %Config{finalize: finalize}} = state) do
final =
case finalize do
{module, function} -> apply(module, function, [prompt])
_ -> prompt
end

{:reply, final, state}
end

# ===========================================================================
# Private
# ===========================================================================
Expand Down
13 changes: 10 additions & 3 deletions test/agens/message_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ defmodule Agens.MessageTest do

alias Agens.Message

def wrap_prompt(prompt), do: "<s>[INST]#{prompt}[/INST]"

defp start_agens(_ctx) do
{:ok, _pid} = start_supervised({Agens.Supervisor, name: Agens.Supervisor})
:ok
Expand All @@ -11,7 +13,8 @@ defmodule Agens.MessageTest do
defp start_serving(_ctx) do
%Agens.Serving.Config{
name: :text_generation,
serving: Test.Support.Serving.Stub
serving: Test.Support.Serving.Stub,
finalize: {__MODULE__, :wrap_prompt}
}
|> Agens.Serving.start()

Expand All @@ -30,12 +33,16 @@ defmodule Agens.MessageTest do
test "works with explicit serving" do
serving_name = :text_generation

wrapped =
wrap_prompt(
"## Input\nThe following is the actual input from the user, system or another agent: test\n"
)

assert %Message{
input: "test",
serving_name: serving_name,
result: "sent 'test' to: ",
prompt:
"<s>[INST]## Input\nThe following is the actual input from the user, system or another agent: test\n[/INST]"
prompt: wrapped
} == Message.send(%Message{serving_name: serving_name, input: "test"})
end
end
Expand Down

0 comments on commit 799306f

Please sign in to comment.