Skip to content

Commit

Permalink
✨ adds prompt to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
chriamue committed Dec 20, 2023
1 parent 1f5f4ef commit 08cc9d9
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
2 changes: 0 additions & 2 deletions src/llm/text_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ impl TextGeneration {
trace!("{t}")
}
}
std::io::stdout().flush()?;

let mut generated_tokens = 0usize;
let eos_token = match tokenizer.get_token("</s>") {
Some(token) => token,
Expand Down
49 changes: 37 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use chat_flame_backend::{config::load_config, server::server};
use chat_flame_backend::{
config::{load_config, Config},
server::server,
};
use log::{error, info};
use std::net::SocketAddr;
use structopt::StructOpt;
Expand All @@ -16,6 +19,33 @@ struct Opt {
help = "Specify the path to the configuration file"
)]
config: String,
/// Optional text prompt for immediate text generation. If provided, runs text generation instead of starting the server.
#[structopt(short, long)]
prompt: Option<String>,
}

async fn generate_text(prompt: String, config: Config) {
info!("Generating text for prompt: {}", prompt);
let mut text_generation =
chat_flame_backend::llm::create_text_generation(None, None, 0.0, 0, &config.cache_dir)
.unwrap();
let generated_text = text_generation.run(&prompt, 50).unwrap();
println!("{}", generated_text.unwrap_or_default());
}

async fn start_server(config: Config) {
info!("Starting server");
info!("preload model");
let _ = chat_flame_backend::llm::create_model(&config.cache_dir);

info!("Running on port: {}", config.port);
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
let app = server(config);

info!("Server running at http://{}", addr);

let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}

#[tokio::main]
Expand All @@ -26,17 +56,12 @@ async fn main() {
match load_config(&opt.config) {
Ok(config) => {
info!("Loaded config: {:?}", config);
info!("preload model");
let _ = chat_flame_backend::llm::create_model(&config.cache_dir);

info!("Running on port: {}", config.port);
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
let app = server(config);

info!("Server running at http://{}", addr);

let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
if let Some(prompt) = opt.prompt {
generate_text(prompt, config).await;
return;
} else {
start_server(config).await;
}
}
Err(e) => {
error!("Failed to load config: {}", e);
Expand Down

0 comments on commit 08cc9d9

Please sign in to comment.