Allow distinct environment variable for OpenAI API key (#12)
* Enables the use of a specific environment variable containing an OpenAI API key, instead of only the default `OPENAI_API_KEY`. The project-specific variable is `CR_BOT_OPENAI_API_KEY`; if it is not set, we fall back to the default behaviour (see the sketch after these notes).

* CODE REVIEW: add documentation and user feedback around environment variable fallback logic

* CODE REVIEW: update imports
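
A minimal sketch of the lookup order described above, for orientation; it is not part of the commit. Only the two variable names `CR_BOT_OPENAI_API_KEY` and `OPENAI_API_KEY` come from the commit, and the helper name `resolve_api_key` is illustrative (the committed code instead lets `async_openai::Client::new()` read `OPENAI_API_KEY` itself):

use std::env;

// Prefer the project-specific variable; fall back to the standard one.
// (Illustrative only: the committed code delegates the fallback to async-openai.)
fn resolve_api_key() -> Option<String> {
    env::var("CR_BOT_OPENAI_API_KEY")
        .or_else(|_| env::var("OPENAI_API_KEY"))
        .ok()
}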
nihilok authored May 20, 2024
1 parent 3a8f834 commit 16a4359
Showing 1 changed file with 34 additions and 10 deletions.
src/ai_funcs.rs: 34 additions, 10 deletions

@@ -1,9 +1,13 @@
-use async_openai::types::{
-    ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
-    ChatCompletionRequestUserMessageArgs, ChatCompletionResponseStream,
-    CreateChatCompletionRequestArgs,
+use async_openai::{
+    config,
+    types::{
+        ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestSystemMessageArgs,
+        ChatCompletionRequestUserMessageArgs, ChatCompletionResponseStream,
+        CreateChatCompletionRequestArgs,
+    },
 };
 use futures::StreamExt;
+use std::env;
 use std::io::{stdout, Write};
 
 const COMPLETION_TOKENS: u16 = 1024;
@@ -13,6 +17,23 @@ const PR_SYSTEM_MESSAGE: &'static str = include_str!("pr-system-message.txt");
 
 const MODEL: &'static str = "gpt-4o";
 
+const OPENAI_API_KEY_VAR_NAME: &'static str = "CR_BOT_OPENAI_API_KEY";
+
+/// Helper function to create an OpenAI client using the appropriate API key
+fn get_client() -> async_openai::Client<config::OpenAIConfig> {
+    let token = env::var(OPENAI_API_KEY_VAR_NAME);
+    match token {
+        Ok(token) => {
+            async_openai::Client::with_config(config::OpenAIConfig::new().with_api_key(token))
+        }
+        Err(_) => {
+            println!("No '{}' environment variable supplied; falling back to default 'OPENAI_API_KEY' environment variable.", OPENAI_API_KEY_VAR_NAME);
+            async_openai::Client::new()
+        }
+    }
+}
+
+/// Print stream to stdout as it is returned (does not wait for full response before starting printing)
 async fn print_stream(
     stream: &mut ChatCompletionResponseStream,
 ) -> Result<(), Box<dyn std::error::Error>> {
@@ -34,8 +55,10 @@ async fn print_stream(
     }
     Ok(())
 }
-pub async fn code_review(output: String) -> Result<(), Box<dyn std::error::Error>> {
-    let client = async_openai::Client::new();
+
+/// Review PR changes (or local changes on current branch) supplied as `input`
+pub async fn code_review(input: String) -> Result<(), Box<dyn std::error::Error>> {
+    let client = get_client();
     let request = CreateChatCompletionRequestArgs::default()
         .max_tokens(COMPLETION_TOKENS)
         .model(MODEL)
@@ -53,7 +76,7 @@ pub async fn code_review(output: String) -> Result<(), Box<dyn std::error::Error
                 .build()?
                 .into(),
             ChatCompletionRequestUserMessageArgs::default()
-                .content(output.as_str())
+                .content(input.as_str())
                 .build()?
                 .into(),
         ])
@@ -64,8 +87,9 @@ pub async fn code_review(output: String) -> Result<(), Box<dyn std::error::Error
     print_stream(&mut stream).await
 }
 
-pub async fn implementation_details(output: String) -> Result<(), Box<dyn std::error::Error>> {
-    let client = async_openai::Client::new();
+/// Describe PR changes (or local changes on current branch) supplied as `input`
+pub async fn implementation_details(input: String) -> Result<(), Box<dyn std::error::Error>> {
+    let client = get_client();
     let request = CreateChatCompletionRequestArgs::default()
         .max_tokens(COMPLETION_TOKENS)
         .model(MODEL)
@@ -85,7 +109,7 @@ pub async fn implementation_details(output: String) -> Result<(), Box<dyn std::e
                 .build()?
                 .into(),
             ChatCompletionRequestUserMessageArgs::default()
-                .content(output.as_str())
+                .content(input.as_str())
                 .build()?
                 .into(),
         ])
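
For completeness, a hypothetical caller exercising the new fallback. Only `code_review` and the variable names come from this commit; the `ai_funcs` module path, the `tokio` runtime, and the stubbed diff text are assumptions:

// Editor's sketch, not from the commit. Assumes the crate exposes the
// ai_funcs module shown above and runs on tokio.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Run with e.g.:  CR_BOT_OPENAI_API_KEY=sk-... cargo run
    // If the variable is unset, get_client() prints a notice and the client
    // falls back to reading OPENAI_API_KEY.

    // The diff under review would normally come from git; stubbed here.
    let diff = String::from("diff --git a/src/main.rs b/src/main.rs\n...");
    ai_funcs::code_review(diff).await
}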
