Skip to content

Commit

Permalink
Add LLM mock support
Browse files Browse the repository at this point in the history
To support different LLMs, we start by adding a mocked LLM
implementation that can be used for unit testing purposes. This allows
us to test the prompt handling logic without needing to rely on a real
LLM service, which can be slow and expensive during development and
testing.
  • Loading branch information
PatWie committed Aug 9, 2024
1 parent b9f6be6 commit f10a11a
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 131 deletions.
13 changes: 8 additions & 5 deletions src/code_action_providers/lua_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@ use tower_lsp::lsp_types::{CodeAction, CodeActionKind, TextEdit, WorkspaceEdit};
use crate::code_action_providers::parsed_document::ParsedDocument;
use crate::code_action_providers::traits::ActionContext;
use crate::code_action_providers::traits::ActionProvider;
use crate::prompt_handlers::claude::BedrockConverse;
use crate::prompt_handlers::traits::PromptHandler;
use crate::prompt_handlers::traits::LLM;
use crate::ResolveAction;

use super::lua_jit::LuaInterface;

pub struct LuaProvider {
prompt_handler: Arc<BedrockConverse>,
prompt_handler: Arc<LLM>,
lua_source: String,
id: String,
}
Expand All @@ -33,7 +32,7 @@ pub enum LuaProviderError {
impl LuaProvider {
pub fn try_new(
file_name: &str,
prompt_handler: Arc<BedrockConverse>,
prompt_handler: Arc<LLM>,
) -> anyhow::Result<Self, LuaProviderError> {
Ok(Self {
prompt_handler,
Expand Down Expand Up @@ -73,7 +72,11 @@ impl ActionProvider for LuaProvider {
}
//log::info!("prompt {}", prompt);
//log::info!("range {:?}", range);
let mut new_text: String = self.prompt_handler.answer(&prompt).await.unwrap();
let new_text = self.prompt_handler.answer(&prompt).await;
if new_text.is_err() {
return Err(Error::request_cancelled());
}
let mut new_text = new_text.unwrap();
{
//log::info!("answer {}", new_text);
let lua = self.create_lua_interface(doc);
Expand Down
11 changes: 3 additions & 8 deletions src/code_action_providers/yaml_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ use tower_lsp::lsp_types::{CodeAction, CodeActionKind, TextEdit, WorkspaceEdit};
use crate::code_action_providers::traits::ActionContext;
use crate::code_action_providers::traits::ActionProvider;
use crate::code_action_providers::{helper, parsed_document::ParsedDocument};
use crate::prompt_handlers::claude::BedrockConverse;
use crate::prompt_handlers::traits::PromptHandler;
use crate::prompt_handlers::traits::LLM;
use crate::ResolveAction;

use super::config;
Expand All @@ -25,18 +24,14 @@ fn build_prompt(template: &str, hints: &HashMap<String, String>) -> String {
}

pub struct YamlProvider {
prompt_handler: Arc<BedrockConverse>,
prompt_handler: Arc<LLM>,

config: config::CodeAction,
id: String,
}

impl YamlProvider {
pub fn from_config(
config: config::CodeAction,
id: &str,
prompt_handler: Arc<BedrockConverse>,
) -> Self {
pub fn from_config(config: config::CodeAction, id: &str, prompt_handler: Arc<LLM>) -> Self {
Self {
prompt_handler,
config,
Expand Down
7 changes: 4 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use code_action_providers::parsed_document::ParsedDocument;
use code_action_providers::traits::ActionProvider;
use code_action_providers::yaml_provider::YamlProvider;
use nonsense::TextAdapter;
use prompt_handlers::claude::BedrockConverse;
use prompt_handlers::bedrock::BedrockConverse;
use prompt_handlers::traits::LLM;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::env;
Expand Down Expand Up @@ -374,11 +375,11 @@ async fn main() {

tracing_subscriber::fmt().init();
//log::info!("Start");
let prompt_handler = Arc::new(
let prompt_handler = Arc::new(LLM::Bedrock(
BedrockConverse::new(&polyglot_config.model.bedrock)
.await
.unwrap(),
);
));
let mut providers: HashMap<String, Vec<Box<dyn ActionProvider>>> = Default::default();

//log::info!("Processing config-dir: {:?}", config_dir);
Expand Down
73 changes: 73 additions & 0 deletions src/prompt_handlers/bedrock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use anyhow::bail;
use aws_config::{BehaviorVersion, Region};
use aws_sdk_bedrockruntime::{
operation::converse::{ ConverseOutput},
types::{ContentBlock, ConversationRole, Message},
Client,
};

use crate::configuration::BedrockConfig;

use super::traits::PromptHandler;


#[derive(Debug)]
pub struct BedrockConverse {
client: Client,
model_id: String,
}

impl BedrockConverse {
pub async fn new(config: &BedrockConfig) -> anyhow::Result<Self> {
let sdk_config = aws_config::defaults(BehaviorVersion::latest())
.region(Region::new(config.region.clone()))
.profile_name(config.aws_profile.clone())
.load()
.await;

let client = Client::new(&sdk_config);

Ok(BedrockConverse {
client,
model_id: config.model_id.clone(),
})
}
}

impl PromptHandler for BedrockConverse {
async fn answer(&self, prompt: &str) -> anyhow::Result<String> {
let response = self
.client
.converse()
.model_id(&self.model_id)
.messages(
Message::builder()
// .role(ConversationRole::Assistant)
// .content(ContentBlock::Text(prompt.to_string()))
.role(ConversationRole::User)
.content(ContentBlock::Text(prompt.to_string()))
.build()?, // .map_err(|_| "failed to build message")?,
)
.send()
.await;
let e = get_converse_output_text(response?);
match e {
Ok(s) => Ok(s),
Err(_) => bail!("failed to get response"),
}
}
}
fn get_converse_output_text(output: ConverseOutput) -> Result<String, String> {
let text = output
.output()
.ok_or("no output")?
.as_message()
.map_err(|_| "output not a message")?
.content()
.first()
.ok_or("no content in message")?
.as_text()
.map_err(|_| "content is not text")?
.to_string();
Ok(text)
}
111 changes: 0 additions & 111 deletions src/prompt_handlers/claude.rs

This file was deleted.

18 changes: 18 additions & 0 deletions src/prompt_handlers/mock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use super::traits::PromptHandler;

#[derive(Debug)]
pub struct MockLLM {
answer: String,
}

impl MockLLM {
pub fn new(answer: String) -> anyhow::Result<Self> {
Ok(MockLLM { answer })
}
}

impl PromptHandler for MockLLM {
async fn answer(&self, _: &str) -> anyhow::Result<String> {
Ok(self.answer.clone())
}
}
3 changes: 2 additions & 1 deletion src/prompt_handlers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod claude;
pub mod bedrock;
pub mod mock;
pub mod traits;
21 changes: 18 additions & 3 deletions src/prompt_handlers/traits.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,23 @@
pub trait PromptHandler {
type Error: std::error::Error;
use super::bedrock::BedrockConverse;
use super::mock::MockLLM;

pub trait PromptHandler {
fn answer(
&self,
prompt: &str,
) -> impl std::future::Future<Output = Result<String, Self::Error>> + Send;
) -> impl std::future::Future<Output = anyhow::Result<String>> + Send;
}

pub enum LLM {
Bedrock(BedrockConverse),
Mock(MockLLM),
}

impl LLM {
pub async fn answer<'a>(&'a self, prompt: &'a str) -> anyhow::Result<String> {
match self {
LLM::Bedrock(b) => b.answer(prompt).await,
LLM::Mock(b) => b.answer(prompt).await,
}
}
}

0 comments on commit f10a11a

Please sign in to comment.