-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
To support different LLMs, we start by adding a mocked LLM implementation that can be used for unit testing purposes. This allows us to test the prompt handling logic without needing to rely on a real LLM service, which can be slow and expensive during development and testing.
- Loading branch information
Showing
8 changed files
with
126 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
use anyhow::bail; | ||
use aws_config::{BehaviorVersion, Region}; | ||
use aws_sdk_bedrockruntime::{ | ||
operation::converse::{ ConverseOutput}, | ||
types::{ContentBlock, ConversationRole, Message}, | ||
Client, | ||
}; | ||
|
||
use crate::configuration::BedrockConfig; | ||
|
||
use super::traits::PromptHandler; | ||
|
||
|
||
#[derive(Debug)] | ||
pub struct BedrockConverse { | ||
client: Client, | ||
model_id: String, | ||
} | ||
|
||
impl BedrockConverse { | ||
pub async fn new(config: &BedrockConfig) -> anyhow::Result<Self> { | ||
let sdk_config = aws_config::defaults(BehaviorVersion::latest()) | ||
.region(Region::new(config.region.clone())) | ||
.profile_name(config.aws_profile.clone()) | ||
.load() | ||
.await; | ||
|
||
let client = Client::new(&sdk_config); | ||
|
||
Ok(BedrockConverse { | ||
client, | ||
model_id: config.model_id.clone(), | ||
}) | ||
} | ||
} | ||
|
||
impl PromptHandler for BedrockConverse { | ||
async fn answer(&self, prompt: &str) -> anyhow::Result<String> { | ||
let response = self | ||
.client | ||
.converse() | ||
.model_id(&self.model_id) | ||
.messages( | ||
Message::builder() | ||
// .role(ConversationRole::Assistant) | ||
// .content(ContentBlock::Text(prompt.to_string())) | ||
.role(ConversationRole::User) | ||
.content(ContentBlock::Text(prompt.to_string())) | ||
.build()?, // .map_err(|_| "failed to build message")?, | ||
) | ||
.send() | ||
.await; | ||
let e = get_converse_output_text(response?); | ||
match e { | ||
Ok(s) => Ok(s), | ||
Err(_) => bail!("failed to get response"), | ||
} | ||
} | ||
} | ||
fn get_converse_output_text(output: ConverseOutput) -> Result<String, String> { | ||
let text = output | ||
.output() | ||
.ok_or("no output")? | ||
.as_message() | ||
.map_err(|_| "output not a message")? | ||
.content() | ||
.first() | ||
.ok_or("no content in message")? | ||
.as_text() | ||
.map_err(|_| "content is not text")? | ||
.to_string(); | ||
Ok(text) | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
use super::traits::PromptHandler; | ||
|
||
#[derive(Debug)] | ||
pub struct MockLLM { | ||
answer: String, | ||
} | ||
|
||
impl MockLLM { | ||
pub fn new(answer: String) -> anyhow::Result<Self> { | ||
Ok(MockLLM { answer }) | ||
} | ||
} | ||
|
||
impl PromptHandler for MockLLM { | ||
async fn answer(&self, _: &str) -> anyhow::Result<String> { | ||
Ok(self.answer.clone()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
pub mod claude; | ||
pub mod bedrock; | ||
pub mod mock; | ||
pub mod traits; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,23 @@ | ||
pub trait PromptHandler { | ||
type Error: std::error::Error; | ||
use super::bedrock::BedrockConverse; | ||
use super::mock::MockLLM; | ||
|
||
pub trait PromptHandler { | ||
fn answer( | ||
&self, | ||
prompt: &str, | ||
) -> impl std::future::Future<Output = Result<String, Self::Error>> + Send; | ||
) -> impl std::future::Future<Output = anyhow::Result<String>> + Send; | ||
} | ||
|
||
pub enum LLM { | ||
Bedrock(BedrockConverse), | ||
Mock(MockLLM), | ||
} | ||
|
||
impl LLM { | ||
pub async fn answer<'a>(&'a self, prompt: &'a str) -> anyhow::Result<String> { | ||
match self { | ||
LLM::Bedrock(b) => b.answer(prompt).await, | ||
LLM::Mock(b) => b.answer(prompt).await, | ||
} | ||
} | ||
} |