Use Thread and various performance improvements
imshashank committed Nov 13, 2024
1 parent f54584a commit c0e7aa6
Showing 6 changed files with 232 additions and 140 deletions.
32 changes: 19 additions & 13 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "magicapi-ai-gateway"
version = "0.1.6"
version = "0.1.7"
edition = "2021"
description = "A high-performance AI Gateway proxy for routing requests to various AI providers, offering seamless integration and management of multiple AI services"
authors = ["MagicAPI Team <team@magicapi.com>"]
@@ -12,6 +12,7 @@ readme = "README.md"
keywords = ["ai", "gateway", "proxy", "openai", "llm"]
categories = ["web-programming", "api-bindings", "asynchronous"]
exclude = [
".env",
".cursorrules",
".github/**/*",
".cargo_vcs_info.json",
@@ -20,22 +21,27 @@ exclude = [

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[profile.release]
opt-level = 3
lto = "fat"
codegen-units = 1
panic = "abort"
strip = true

[dependencies]
axum = { version = "0.7", features = ["http2"] }
tokio = { version = "1.0", features = ["full"] }
tower-http = { version = "0.5", features = ["cors"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
reqwest = { version = "0.11", features = ["stream", "json"] }
futures = "0.3"
axum = { version = "0.7", features = ["http2", "tokio"] }
tokio = { version = "1.0", features = ["full", "parking_lot", "rt-multi-thread"] }
tower-http = { version = "0.5", features = ["cors", "compression-full"] }
tracing = { version = "0.1", features = ["attributes"] }
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
reqwest = { version = "0.12.9", features = ["stream", "json", "rustls-tls", "http2", "gzip", "brotli"], default-features = false }
http = "1.0"
tower = "0.4"
bytes = "1.0"
bytes = { version = "1.0", features = ["serde"] }
dotenv = "0.15"
futures-util = "0.3"
futures-util = { version = "0.3", features = ["io"] }
once_cell = "1.18"
hyper = { version = "1.0", features = ["full"] }
async-trait = "0.1"
thiserror = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_json = "1.0"
num_cpus = "1.15"
53 changes: 49 additions & 4 deletions src/config.rs
@@ -1,21 +1,66 @@
use dotenv::dotenv;
use std::env;
use num_cpus;
use tracing::info;
use tracing::debug;

pub struct AppConfig {
pub port: u16,
pub host: String,
pub worker_threads: usize,
pub max_connections: usize,
pub tcp_keepalive_interval: u64,
pub tcp_nodelay: bool,
pub buffer_size: usize,
}

impl AppConfig {
pub fn new() -> Self {
dotenv().ok();
info!("Loading environment configuration");
dotenv::dotenv().ok();

// Optimize thread count based on CPU cores
let cpu_count = num_cpus::get();
debug!("Detected {} CPU cores", cpu_count);

let default_workers = if cpu_count <= 4 {
cpu_count * 2
} else {
cpu_count + 4
};
debug!("Calculated default worker threads: {}", default_workers);

Self {
let config = Self {
port: env::var("PORT")
.unwrap_or_else(|_| "3000".to_string())
.parse()
.expect("PORT must be a number"),
host: env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()),
}
worker_threads: env::var("WORKER_THREADS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(default_workers),
max_connections: env::var("MAX_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(10_000),
tcp_keepalive_interval: env::var("TCP_KEEPALIVE_INTERVAL")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(30),
tcp_nodelay: env::var("TCP_NODELAY")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(true),
buffer_size: env::var("BUFFER_SIZE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(8 * 1024), // 8KB default
};

info!("Configuration loaded: port={}, host={}", config.port, config.host);
debug!("Advanced settings: workers={}, max_conn={}, buffer_size={}",
config.worker_threads, config.max_connections, config.buffer_size);

config
}
}
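
Every tuning knob added to AppConfig follows the same pattern: read an environment variable, try to parse it, and fall back to a computed default when it is missing or malformed. Below is a minimal, self-contained sketch of that parse-with-fallback pattern, assuming the num_cpus crate the commit adds; the helper name env_or and the printed output are illustrative only, not part of this commit.

use std::env;
use std::str::FromStr;

// Parse an environment variable, falling back to `default` when the
// variable is unset or cannot be parsed as T.
fn env_or<T: FromStr>(key: &str, default: T) -> T {
    env::var(key)
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(default)
}

fn main() {
    // Same worker-thread heuristic as the commit: 2x cores on small
    // machines, cores + 4 on larger ones.
    let cpu_count = num_cpus::get();
    let default_workers = if cpu_count <= 4 { cpu_count * 2 } else { cpu_count + 4 };

    let worker_threads: usize = env_or("WORKER_THREADS", default_workers);
    let buffer_size: usize = env_or("BUFFER_SIZE", 8 * 1024);
    println!("workers={worker_threads}, buffer_size={buffer_size}");
}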
43 changes: 25 additions & 18 deletions src/handlers.rs
@@ -1,48 +1,55 @@
use crate::{config::AppConfig, proxy::proxy_request_to_provider};
use axum::{
body::Body,
extract::State,
extract::{State, ConnectInfo},
http::{HeaderMap, Request},
response::IntoResponse,
Json,
};
use serde_json::json;
use std::sync::Arc;
use tracing::{error, info};
use std::{sync::Arc, net::SocketAddr};
use tracing::{error, Instrument, debug};

pub async fn health_check() -> impl IntoResponse {
Json(json!({
"status": "healthy",
"version": env!("CARGO_PKG_VERSION")
}))
debug!("Health check endpoint called");
Json(json!({ "status": "healthy", "version": env!("CARGO_PKG_VERSION") }))
}

pub async fn proxy_request(
State(config): State<Arc<AppConfig>>,
headers: HeaderMap,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
request: Request<Body>,
) -> impl IntoResponse {
let provider = headers
.get("x-provider")
.and_then(|h| h.to_str().ok())
.unwrap_or("openai");

info!(
debug!(
"Received request for provider: {}, client: {}, path: {}",
provider,
addr,
request.uri().path()
);

let span = tracing::info_span!(
"proxy_request",
provider = provider,
method = %request.method(),
path = %request.uri().path(),
"Incoming proxy request"
client = %addr
);

match proxy_request_to_provider(config, provider, request).await {
Ok(response) => response,
Err(e) => {
error!(
error = %e,
provider = provider,
"Proxy request failed"
);
e.into_response()
async move {
match proxy_request_to_provider(config, provider, request).await {
Ok(response) => response,
Err(e) => {
error!(error = %e, "Proxy request failed");
e.into_response()
}
}
}
.instrument(span)
.await
}
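
The handler no longer logs and calls the proxy inline; it builds a tracing span carrying the provider, method, path, and client address, then drives the proxy future inside that span via Instrument, so every event emitted while the request is in flight inherits those fields. A stripped-down sketch of the span-plus-instrument pattern follows; forward_request is a stand-in for illustration, not the project's real function.

use tracing::{error, info, info_span, Instrument};

async fn forward_request(provider: &str) -> Result<(), String> {
    // Events emitted here inherit the fields of the enclosing span.
    info!(provider, "forwarding request upstream");
    Ok(())
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().init();

    let span = info_span!("proxy_request", provider = "openai", client = "127.0.0.1:54321");
    async move {
        if let Err(e) = forward_request("openai").await {
            error!(error = %e, "Proxy request failed");
        }
    }
    .instrument(span)
    .await;
}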
81 changes: 43 additions & 38 deletions src/main.rs
@@ -4,8 +4,9 @@ use axum::{
};
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use tracing::{error, info};
use tracing::{error, info, debug};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use std::time::Duration;

mod config;
mod error;
@@ -17,65 +18,69 @@ use crate::config::AppConfig;

#[tokio::main]
async fn main() {
// Initialize tracing with more detailed format
// Initialize tracing
info!("Initializing tracing system");
tracing_subscriber::registry()
.with(tracing_subscriber::EnvFilter::new(
std::env::var("RUST_LOG")
.unwrap_or_else(|_| "info,tower_http=debug,axum::rejection=trace".into()),
std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()),
))
.with(
tracing_subscriber::fmt::layer()
.with_file(true)
.with_line_number(true)
.with_thread_ids(true)
.with_thread_names(true),
)
.with(tracing_subscriber::fmt::layer().compact())
.init();

// Load configuration
info!("Loading application configuration");
let config = Arc::new(AppConfig::new());
debug!("Configuration loaded: port={}, host={}", config.port, config.host);

info!(
host = %config.port,
port = %config.port,
"Starting server with configuration"
);
// Optimize tokio runtime
info!("Configuring tokio runtime with {} worker threads", config.worker_threads);
std::env::set_var("TOKIO_WORKER_THREADS", config.worker_threads.to_string());
std::env::set_var("TOKIO_THREAD_STACK_SIZE", (2 * 1024 * 1024).to_string());

// Setup CORS
debug!("Setting up CORS layer with 1-hour max age");
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
.allow_headers(Any)
.max_age(Duration::from_secs(3600));

info!("CORS configuration: allowing all origins, methods, and headers");

// Create router
// Create router with optimized settings
let app = Router::new()
.route("/health", get(handlers::health_check))
.route("/v1/*path", any(handlers::proxy_request))
.with_state(config.clone())
.layer(cors);

info!("Router configured with health check and proxy endpoints");
.layer(cors)
.into_make_service_with_connect_info::<std::net::SocketAddr>();

// Start server
// Start server with optimized TCP settings
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], config.port));
info!("Setting up TCP listener with non-blocking mode");
let tcp_listener = std::net::TcpListener::bind(addr).expect("Failed to bind address");
tcp_listener.set_nonblocking(true).expect("Failed to set non-blocking");

debug!("Converting to tokio TCP listener");
let listener = tokio::net::TcpListener::from_std(tcp_listener)
.expect("Failed to create Tokio TCP listener");

info!(
address = %addr,
"Starting server"
"AI Gateway listening on {}:{} with {} worker threads",
config.host, config.port, config.worker_threads
);

match tokio::net::TcpListener::bind(addr).await {
Ok(listener) => {
info!("Server successfully bound to address");
if let Err(e) = axum::serve(listener, app).await {
error!(error = %e, "Server error occurred");
std::process::exit(1);
}
}
Err(e) => {
error!(error = %e, "Failed to bind server to address");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.unwrap_or_else(|e| {
error!("Server error: {}", e);
std::process::exit(1);
}
}
});
}

async fn shutdown_signal() {
info!("Registering shutdown signal handler");
tokio::signal::ctrl_c()
.await
.expect("Failed to install CTRL+C signal handler");
info!("Shutdown signal received, starting graceful shutdown");
}
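
main now binds a std TcpListener, switches it to non-blocking mode, and only then hands it to tokio; from_std requires the socket to already be non-blocking, which is why the extra step exists. The sketch below isolates that handoff under those assumptions; the accept loop is illustrative only, since the commit passes the listener to axum::serve instead.

use tokio::net::TcpListener;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Bind with std so the socket can be configured before tokio owns it.
    let std_listener = std::net::TcpListener::bind(("0.0.0.0", 3000))?;
    // tokio's from_std requires the socket to already be non-blocking.
    std_listener.set_nonblocking(true)?;
    let listener = TcpListener::from_std(std_listener)?;

    loop {
        let (socket, peer) = listener.accept().await?;
        // Mirror the TCP_NODELAY option exposed through AppConfig.
        socket.set_nodelay(true)?;
        println!("accepted connection from {peer}");
    }
}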
35 changes: 30 additions & 5 deletions src/proxy/client.rs
@@ -1,12 +1,37 @@
use once_cell::sync::Lazy;
use std::time::Duration;
use crate::config::AppConfig;
use tracing::info;
use tracing::debug;

pub fn create_client(config: &AppConfig) -> reqwest::Client {
info!("Creating HTTP client with optimized settings");
debug!(
"Client config: max_connections={}, keepalive={}s, nodelay={}",
config.max_connections,
config.tcp_keepalive_interval,
config.tcp_nodelay
);

pub static CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
reqwest::Client::builder()
.pool_idle_timeout(Duration::from_secs(60))
.pool_max_idle_per_host(32)
.tcp_keepalive(Duration::from_secs(60))
.timeout(Duration::from_secs(60))
.pool_max_idle_per_host(config.max_connections)
.pool_idle_timeout(Duration::from_secs(30))
.http2_prior_knowledge()
.http2_keep_alive_interval(Duration::from_secs(config.tcp_keepalive_interval))
.http2_keep_alive_timeout(Duration::from_secs(30))
.http2_adaptive_window(true)
.tcp_keepalive(Duration::from_secs(config.tcp_keepalive_interval))
.tcp_nodelay(config.tcp_nodelay)
.use_rustls_tls()
.timeout(Duration::from_secs(30))
.connect_timeout(Duration::from_secs(10))
.gzip(true)
.brotli(true)
.build()
.expect("Failed to create HTTP client")
}

pub static CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
let config = AppConfig::new();
create_client(&config)
});
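
client.rs now exposes create_client(&AppConfig) and keeps the shared CLIENT behind once_cell's Lazy, so the connection pool is built once on first use and reused across requests instead of paying a new handshake per call. A pared-down sketch of that shared-client pattern follows; it keeps only a few of the builder options from the commit, and the health-check URL is purely illustrative.

use once_cell::sync::Lazy;
use std::time::Duration;

// Built once on first access; reqwest::Client pools connections internally,
// so re-creating it per request would waste handshakes.
static CLIENT: Lazy<reqwest::Client> = Lazy::new(|| {
    reqwest::Client::builder()
        .pool_idle_timeout(Duration::from_secs(30))
        .tcp_nodelay(true)
        .timeout(Duration::from_secs(30))
        .connect_timeout(Duration::from_secs(10))
        .build()
        .expect("Failed to create HTTP client")
});

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let status = CLIENT.get("https://example.com/health").send().await?.status();
    println!("upstream responded with {status}");
    Ok(())
}

The builder in the commit additionally forces HTTP/2 via http2_prior_knowledge(), enables gzip and brotli decompression, and switches to rustls; the sketch omits those options to stay minimal.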