diff --git a/Cargo.toml b/Cargo.toml index 0a3e8b9..272365b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "magicapi-ai-gateway" -version = "0.1.6" +version = "0.1.7" edition = "2021" description = "A high-performance AI Gateway proxy for routing requests to various AI providers, offering seamless integration and management of multiple AI services" authors = ["MagicAPI Team "] @@ -12,6 +12,7 @@ readme = "README.md" keywords = ["ai", "gateway", "proxy", "openai", "llm"] categories = ["web-programming", "api-bindings", "asynchronous"] exclude = [ + ".env", ".cursorrules", ".github/**/*", ".cargo_vcs_info.json", @@ -20,22 +21,27 @@ exclude = [ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[profile.release] +opt-level = 3 +lto = "fat" +codegen-units = 1 +panic = "abort" +strip = true + [dependencies] -axum = { version = "0.7", features = ["http2"] } -tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5", features = ["cors"] } -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -reqwest = { version = "0.11", features = ["stream", "json"] } -futures = "0.3" +axum = { version = "0.7", features = ["http2", "tokio"] } +tokio = { version = "1.0", features = ["full", "parking_lot", "rt-multi-thread"] } +tower-http = { version = "0.5", features = ["cors", "compression-full"] } +tracing = { version = "0.1", features = ["attributes"] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } +reqwest = { version = "0.12.9", features = ["stream", "json", "rustls-tls", "http2", "gzip", "brotli"], default-features = false } http = "1.0" -tower = "0.4" -bytes = "1.0" +bytes = { version = "1.0", features = ["serde"] } dotenv = "0.15" -futures-util = "0.3" +futures-util = { version = "0.3", features = ["io"] } once_cell = "1.18" -hyper = { version = "1.0", features = ["full"] } async-trait = "0.1" thiserror = "1.0" serde = { 
version = "1.0", features = ["derive"] } -serde_json = "1.0" \ No newline at end of file +serde_json = "1.0" +num_cpus = "1.15" \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index f83ae75..91e52e4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,21 +1,66 @@ -use dotenv::dotenv; use std::env; +use num_cpus; +use tracing::info; +use tracing::debug; pub struct AppConfig { pub port: u16, pub host: String, + pub worker_threads: usize, + pub max_connections: usize, + pub tcp_keepalive_interval: u64, + pub tcp_nodelay: bool, + pub buffer_size: usize, } impl AppConfig { pub fn new() -> Self { - dotenv().ok(); + info!("Loading environment configuration"); + dotenv::dotenv().ok(); + + // Optimize thread count based on CPU cores + let cpu_count = num_cpus::get(); + debug!("Detected {} CPU cores", cpu_count); + + let default_workers = if cpu_count <= 4 { + cpu_count * 2 + } else { + cpu_count + 4 + }; + debug!("Calculated default worker threads: {}", default_workers); - Self { + let config = Self { port: env::var("PORT") .unwrap_or_else(|_| "3000".to_string()) .parse() .expect("PORT must be a number"), host: env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()), - } + worker_threads: env::var("WORKER_THREADS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(default_workers), + max_connections: env::var("MAX_CONNECTIONS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(10_000), + tcp_keepalive_interval: env::var("TCP_KEEPALIVE_INTERVAL") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(30), + tcp_nodelay: env::var("TCP_NODELAY") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(true), + buffer_size: env::var("BUFFER_SIZE") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(8 * 1024), // 8KB default + }; + + info!("Configuration loaded: port={}, host={}", config.port, config.host); + debug!("Advanced settings: workers={}, max_conn={}, buffer_size={}", + config.worker_threads, config.max_connections, config.buffer_size); + + 
config } } diff --git a/src/handlers.rs b/src/handlers.rs index 763c16e..157aff1 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -1,25 +1,24 @@ use crate::{config::AppConfig, proxy::proxy_request_to_provider}; use axum::{ body::Body, - extract::State, + extract::{State, ConnectInfo}, http::{HeaderMap, Request}, response::IntoResponse, Json, }; use serde_json::json; -use std::sync::Arc; -use tracing::{error, info}; +use std::{sync::Arc, net::SocketAddr}; +use tracing::{error, Instrument, debug}; pub async fn health_check() -> impl IntoResponse { - Json(json!({ - "status": "healthy", - "version": env!("CARGO_PKG_VERSION") - })) + debug!("Health check endpoint called"); + Json(json!({ "status": "healthy", "version": env!("CARGO_PKG_VERSION") })) } pub async fn proxy_request( State(config): State>, headers: HeaderMap, + ConnectInfo(addr): ConnectInfo, request: Request, ) -> impl IntoResponse { let provider = headers @@ -27,22 +26,30 @@ pub async fn proxy_request( .and_then(|h| h.to_str().ok()) .unwrap_or("openai"); - info!( + debug!( + "Received request for provider: {}, client: {}, path: {}", + provider, + addr, + request.uri().path() + ); + + let span = tracing::info_span!( + "proxy_request", provider = provider, method = %request.method(), path = %request.uri().path(), - "Incoming proxy request" + client = %addr ); - match proxy_request_to_provider(config, provider, request).await { - Ok(response) => response, - Err(e) => { - error!( - error = %e, - provider = provider, - "Proxy request failed" - ); - e.into_response() + async move { + match proxy_request_to_provider(config, provider, request).await { + Ok(response) => response, + Err(e) => { + error!(error = %e, "Proxy request failed"); + e.into_response() + } } } + .instrument(span) + .await } diff --git a/src/main.rs b/src/main.rs index feb6a99..423efb0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,8 +4,9 @@ use axum::{ }; use std::sync::Arc; use tower_http::cors::{Any, CorsLayer}; -use tracing::{error, 
info}; +use tracing::{error, info, debug}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use std::time::Duration; mod config; mod error; @@ -17,65 +18,69 @@ use crate::config::AppConfig; #[tokio::main] async fn main() { - // Initialize tracing with more detailed format + // Initialize tracing + // (no logging before .init(): a subscriber is not installed yet, so events would be dropped) tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( - std::env::var("RUST_LOG") - .unwrap_or_else(|_| "info,tower_http=debug,axum::rejection=trace".into()), + std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()), )) - .with( - tracing_subscriber::fmt::layer() - .with_file(true) - .with_line_number(true) - .with_thread_ids(true) - .with_thread_names(true), - ) + .with(tracing_subscriber::fmt::layer().compact()) .init(); // Load configuration + info!("Loading application configuration"); let config = Arc::new(AppConfig::new()); + debug!("Configuration loaded: port={}, host={}", config.port, config.host); - info!( - host = %config.port, - port = %config.port, - "Starting server with configuration" - ); + // NOTE: TOKIO_WORKER_THREADS is only read when the runtime is constructed; #[tokio::main] has already + // started the runtime at this point, so mutating the env var here would have no effect. To honor + // config.worker_threads, build the runtime manually via tokio::runtime::Builder::new_multi_thread().worker_threads(n). + info!("Configured worker threads: {} (requires manual runtime construction to take effect)", config.worker_threads); // Setup CORS + debug!("Setting up CORS layer with 1-hour max age"); let cors = CorsLayer::new() .allow_origin(Any) .allow_methods(Any) - .allow_headers(Any); + .allow_headers(Any) + .max_age(Duration::from_secs(3600)); - info!("CORS configuration: allowing all origins, methods, and headers"); - - // Create router + // Create router with optimized settings let app = Router::new() .route("/health", get(handlers::health_check)) .route("/v1/*path", any(handlers::proxy_request)) .with_state(config.clone()) - .layer(cors); - - info!("Router configured with health check and proxy endpoints"); + .layer(cors) + 
.into_make_service_with_connect_info::(); - // Start server + // Start server with optimized TCP settings let addr = std::net::SocketAddr::from(([0, 0, 0, 0], config.port)); + info!("Setting up TCP listener with non-blocking mode"); + let tcp_listener = std::net::TcpListener::bind(addr).expect("Failed to bind address"); + tcp_listener.set_nonblocking(true).expect("Failed to set non-blocking"); + + debug!("Converting to tokio TCP listener"); + let listener = tokio::net::TcpListener::from_std(tcp_listener) + .expect("Failed to create Tokio TCP listener"); + info!( - address = %addr, - "Starting server" + "AI Gateway listening on {}:{} with {} worker threads", + config.host, config.port, config.worker_threads ); - match tokio::net::TcpListener::bind(addr).await { - Ok(listener) => { - info!("Server successfully bound to address"); - if let Err(e) = axum::serve(listener, app).await { - error!(error = %e, "Server error occurred"); - std::process::exit(1); - } - } - Err(e) => { - error!(error = %e, "Failed to bind server to address"); + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap_or_else(|e| { + error!("Server error: {}", e); std::process::exit(1); - } - } + }); +} + +async fn shutdown_signal() { + info!("Registering shutdown signal handler"); + tokio::signal::ctrl_c() + .await + .expect("Failed to install CTRL+C signal handler"); + info!("Shutdown signal received, starting graceful shutdown"); } diff --git a/src/proxy/client.rs b/src/proxy/client.rs index d7931a7..6b226c5 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,12 +1,37 @@ use once_cell::sync::Lazy; use std::time::Duration; +use crate::config::AppConfig; +use tracing::info; +use tracing::debug; + +pub fn create_client(config: &AppConfig) -> reqwest::Client { + info!("Creating HTTP client with optimized settings"); + debug!( + "Client config: max_connections={}, keepalive={}s, nodelay={}", + config.max_connections, + config.tcp_keepalive_interval, + 
config.tcp_nodelay + ); -pub static CLIENT: Lazy = Lazy::new(|| { reqwest::Client::builder() - .pool_idle_timeout(Duration::from_secs(60)) - .pool_max_idle_per_host(32) - .tcp_keepalive(Duration::from_secs(60)) - .timeout(Duration::from_secs(60)) + .pool_max_idle_per_host(config.max_connections) + .pool_idle_timeout(Duration::from_secs(30)) + .http2_prior_knowledge() + .http2_keep_alive_interval(Duration::from_secs(config.tcp_keepalive_interval)) + .http2_keep_alive_timeout(Duration::from_secs(30)) + .http2_adaptive_window(true) + .tcp_keepalive(Duration::from_secs(config.tcp_keepalive_interval)) + .tcp_nodelay(config.tcp_nodelay) + .use_rustls_tls() + .timeout(Duration::from_secs(30)) + .connect_timeout(Duration::from_secs(10)) + .gzip(true) + .brotli(true) .build() .expect("Failed to create HTTP client") +} + +pub static CLIENT: Lazy = Lazy::new(|| { + let config = AppConfig::new(); + create_client(&config) }); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 9d523b2..ebc46cc 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,11 +1,12 @@ use axum::{ - body::{self, Body, Bytes}, + body::{self, Body}, http::{HeaderMap, HeaderValue, Request, Response, StatusCode}, }; use futures_util::StreamExt; use reqwest::Method; use std::sync::Arc; use tracing::{debug, error, info}; +use bytes::BytesMut; use crate::{config::AppConfig, error::AppError, providers::create_provider}; @@ -13,7 +14,7 @@ mod client; pub use client::CLIENT; pub async fn proxy_request_to_provider( - _config: Arc, + config: Arc, provider_name: &str, original_request: Request, ) -> Result, AppError> { @@ -24,10 +25,10 @@ pub async fn proxy_request_to_provider( "Incoming request" ); - // Create provider instance + debug!("Creating provider instance for: {}", provider_name); let provider = create_provider(provider_name)?; - // Call before request hook + debug!("Executing before_request hook"); provider.before_request(&original_request).await?; let path = original_request.uri().path(); @@ -50,12 
+51,13 @@ pub async fn proxy_request_to_provider( // Process headers let headers = provider.process_headers(original_request.headers())?; - // Create and send request + // Create and send request with optimized buffer handling let response = send_provider_request( original_request.method().clone(), url, headers, original_request.into_body(), + config.clone(), ) .await?; @@ -68,73 +70,86 @@ pub async fn proxy_request_to_provider( Ok(processed_response) } -// Helper function to send the actual request async fn send_provider_request( method: http::Method, url: String, headers: HeaderMap, body: Body, + config: Arc, ) -> Result, AppError> { + debug!("Preparing to send request: {} {}", method, url); + let body_bytes = body::to_bytes(body, usize::MAX).await?; + debug!("Request body size: {} bytes", body_bytes.len()); let client = &*CLIENT; - let method = - Method::from_bytes(method.as_str().as_bytes()).map_err(|_| AppError::InvalidMethod)?; + let method = Method::from_bytes(method.as_str().as_bytes()) + .map_err(|_| AppError::InvalidMethod)?; - // Convert http::HeaderMap to reqwest::HeaderMap - let mut reqwest_headers = reqwest::header::HeaderMap::new(); + // Pre-allocate headers map with known capacity + let mut reqwest_headers = reqwest::header::HeaderMap::with_capacity(headers.len()); + + // Batch process headers for (name, value) in headers.iter() { - if let Ok(v) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) { - // Convert the header name to a string first - if let Ok(name_str) = name.as_str().parse::() { - reqwest_headers.insert(name_str, v); - } + if let (Ok(name_str), Ok(v)) = ( + name.as_str().parse::(), + reqwest::header::HeaderValue::from_bytes(value.as_bytes()), + ) { + reqwest_headers.insert(name_str, v); } } let response = client .request(method, url) - .headers(reqwest_headers) // Now using the converted reqwest::HeaderMap + .headers(reqwest_headers) .body(body_bytes.to_vec()) .send() .await?; - process_response(response).await + 
process_response(response, config).await } -// Add this function after the send_provider_request function - -async fn process_response(response: reqwest::Response) -> Result, AppError> { +async fn process_response( + response: reqwest::Response, + config: Arc, +) -> Result, AppError> { let status = StatusCode::from_u16(response.status().as_u16())?; + debug!("Processing response with status: {}", status); + + // Pre-allocate headers map + let mut response_headers = HeaderMap::with_capacity(response.headers().len()); + + // Batch process headers + for (name, value) in response.headers() { + if let (Ok(header_name), Ok(v)) = ( + http::HeaderName::from_bytes(name.as_ref()), + HeaderValue::from_bytes(value.as_bytes()), + ) { + response_headers.insert(header_name, v); + } + } - // Check if response is a stream - if response - .headers() + // Check for streaming response + if response.headers() .get(reqwest::header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .map_or(false, |ct| ct.contains("text/event-stream")) { - debug!("Processing streaming response"); - - // Convert headers - let mut response_headers = HeaderMap::new(); - for (name, value) in response.headers() { - if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) { - if let Ok(header_name) = http::HeaderName::from_bytes(name.as_ref()) { - response_headers.insert(header_name, v); + info!("Processing streaming response"); + debug!("Setting up stream with buffer size: {}", config.buffer_size); + + // Optimize streaming with larger chunks + let stream = response.bytes_stream().map(move |result| { + match result { + Ok(bytes) => { + let mut buffer = BytesMut::with_capacity(config.buffer_size); + buffer.extend_from_slice(&bytes); + Ok(buffer.freeze()) + } + Err(e) => { + error!("Stream error: {}", e); + Err(std::io::Error::new(std::io::ErrorKind::Other, e)) } - } - } - - // Set up streaming response - let stream = response.bytes_stream().map(|result| match result { - Ok(bytes) => { - debug!("Streaming chunk: {} 
bytes", bytes.len()); - Ok(bytes) - } - Err(e) => { - error!("Stream error: {}", e); - Err(std::io::Error::new(std::io::ErrorKind::Other, e)) } }); @@ -148,25 +163,14 @@ async fn process_response(response: reqwest::Response) -> Result, .unwrap()) } else { debug!("Processing regular response"); - - // Convert headers - let mut response_headers = HeaderMap::new(); - for (name, value) in response.headers() { - if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) { - if let Ok(header_name) = http::HeaderName::from_bytes(name.as_ref()) { - response_headers.insert(header_name, v); - } - } - } - - // Process regular response body + + // Use pre-allocated buffer for regular responses let body = response.bytes().await?; - - let mut builder = Response::builder().status(status); - for (name, value) in response_headers.iter() { - builder = builder.header(name, value); - } - - Ok(builder.body(Body::from(body)).unwrap()) + + // Re-attach the converted upstream headers: .extension() would stash them in + // the response extensions map and silently drop every real header (content-type, etc.) + let mut builder = Response::builder().status(status); + for (name, value) in response_headers.iter() { + builder = builder.header(name, value); + } + Ok(builder.body(Body::from(body)).unwrap()) } }