diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..560f1c4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +target/ +.git/ +.github/ +.env +**/*.rs.bk +Dockerfile +.dockerignore +.gitignore +README.md +*.log \ No newline at end of file diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 0000000..808a5b8 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,105 @@ +name: Docker Build and Push + +on: + push: + branches: + - 'main' # Only trigger push events on main + tags: [ "v*" ] + pull_request: # Keep PR triggers for all branches + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + docker: + # Skip this job if it's a push event on a non-main branch + if: github.event_name != 'push' || github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + # Generate Cargo.lock if it doesn't exist + - name: Generate Cargo.lock + run: | + if [ ! -f "Cargo.lock" ]; then + cargo generate-lockfile + fi + + - name: Extract version from Cargo.toml + id: version + run: | + VERSION=$(awk -F '"' '/^version = / {print $2}' Cargo.toml) + echo "CARGO_VERSION=${VERSION}" >> $GITHUB_ENV + echo "Version found: ${VERSION}" + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + # Always build to verify Dockerfile, but push only on main + - name: Build Docker image (non-main branch) + if: github.ref != 'refs/heads/main' + uses: docker/build-push-action@v5 + with: + context: . + push: false + platforms: linux/amd64,linux/arm64 + cache-from: type=gha + cache-to: type=gha,mode=max + + # Main branch handling - build and push + - name: Log in to GitHub Container Registry + if: github.ref == 'refs/heads/main' + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for GHCR + if: github.ref == 'refs/heads/main' + id: meta-ghcr + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=raw,value=latest + type=raw,value=${{ env.CARGO_VERSION }} + + - name: Build and push to GitHub Container Registry + if: github.ref == 'refs/heads/main' + uses: docker/build-push-action@v5 + with: + context: . + push: true + tags: ${{ steps.meta-ghcr.outputs.tags }} + labels: ${{ steps.meta-ghcr.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + platforms: linux/amd64,linux/arm64 + + - name: Login to Docker Hub + if: github.ref == 'refs/heads/main' + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and push to Docker Hub + if: github.ref == 'refs/heads/main' + uses: docker/build-push-action@v5 + with: + context: . + push: true + tags: | + magicapi1/magicapi-ai-gateway:latest + magicapi1/magicapi-ai-gateway:${{ env.CARGO_VERSION }} + cache-from: type=gha + cache-to: type=gha,mode=max + platforms: linux/amd64 \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index a61baea..bd28d42 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.1.7] - 2024-11-13 +### Added +- Managed deployment offering with testing gateway at gateway.magicapi.dev +- Thread-based performance optimizations for improved request handling +- Documentation for testing deployment environment +### Enhanced +- Significant performance improvements in request processing +- Build system optimizations +- CI/CD pipeline improvements +### Fixed +- Git build configuration issues +- Various minor bug fixes + ## [0.1.6] - 2024-11-13 ### Added - Support for Fireworks AI provider @@ -65,7 +78,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Error handling - Basic documentation -[Unreleased]: https://github.com/MagicAPI/ai-gateway/compare/v0.1.6...HEAD +[Unreleased]: https://github.com/MagicAPI/ai-gateway/compare/v0.1.7...HEAD +[0.1.7]: https://github.com/MagicAPI/ai-gateway/compare/v0.1.6...v0.1.7 [0.1.6]: https://github.com/MagicAPI/ai-gateway/compare/v0.1.5...v0.1.6 [0.1.5]: https://github.com/MagicAPI/ai-gateway/compare/v0.1.4...v0.1.5 [0.1.3]: https://github.com/MagicAPI/ai-gateway/compare/v0.1.0...v0.1.3 \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 0a3e8b9..69d680b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "magicapi-ai-gateway" -version = "0.1.6" +version = "0.1.7" edition = "2021" description = "A high-performance AI Gateway proxy for routing requests to various AI providers, offering seamless integration and management of multiple AI services" authors = ["MagicAPI Team "] @@ -12,6 +12,7 @@ readme = "README.md" keywords = ["ai", "gateway", "proxy", "openai", "llm"] categories = ["web-programming", "api-bindings", "asynchronous"] exclude = [ + ".env", ".cursorrules", ".github/**/*", ".cargo_vcs_info.json", @@ -20,22 +21,28 @@ exclude = [ # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[profile.release] +opt-level = 3 +lto = "fat" +codegen-units = 1 +panic = "abort" +strip = true +debug = false + [dependencies] -axum = { version = "0.7", features = ["http2"] } -tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5", features = ["cors"] } -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -reqwest = { version = "0.11", features = ["stream", "json"] } -futures = "0.3" +axum = { version = "0.7", features = ["http2", "tokio"] } +tokio = { version = "1.0", features = ["full", "parking_lot", "rt-multi-thread"] } +tower-http = { version = "0.5", features = ["cors", "compression-full"] } +tracing = { version = "0.1", features = ["attributes"] } +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } +reqwest = { version = "0.12.9", features = ["stream", "json", "rustls-tls", "http2", "gzip", "brotli"], default-features = false } http = "1.0" -tower = "0.4" -bytes = "1.0" +bytes = { version = "1.0", features = ["serde"] } dotenv = "0.15" -futures-util = "0.3" +futures-util = { version = "0.3", features = ["io"] } once_cell = "1.18" -hyper = { version = "1.0", features = ["full"] } async-trait = "0.1" thiserror = "1.0" serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" \ No newline at end of file +serde_json = "1.0" +num_cpus = "1.15" \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 87b23e5..5e5ed60 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # Build stage -FROM rust:1.82-slim-bookworm as builder +FROM --platform=linux/amd64 rust:1.82-slim-bookworm as builder # Install required dependencies RUN apt-get update && apt-get install -y \ @@ -9,22 +9,38 @@ RUN apt-get update && apt-get install -y \ # Create a new empty shell project WORKDIR /usr/src/app -COPY . . -# Build with release profile for maximum performance -RUN cargo build --release +# Copy only necessary files first +COPY Cargo.toml Cargo.lock ./ + +# Create a dummy main.rs to build dependencies +RUN mkdir src && \ + echo "fn main() {}" > src/main.rs && \ + cargo build --release --target x86_64-unknown-linux-gnu && \ + rm -rf src + +# Now copy the real source code +COPY src ./src + +# Build the application +RUN RUSTFLAGS='-C target-feature=+crt-static' cargo build --release --target x86_64-unknown-linux-gnu && \ + strip target/x86_64-unknown-linux-gnu/release/magicapi-ai-gateway # Runtime stage -FROM debian:bookworm-slim +FROM --platform=linux/amd64 debian:bookworm-slim + +# Add LABEL to identify the image +LABEL org.opencontainers.image.source="https://github.com/magicapi/ai-gateway" +LABEL org.opencontainers.image.description="MagicAPI AI Gateway" +LABEL org.opencontainers.image.version="latest" # Install runtime dependencies RUN apt-get update && apt-get install -y \ ca-certificates \ - libssl3 \ && rm -rf /var/lib/apt/lists/* # Copy the binary from builder -COPY --from=builder /usr/src/app/target/release/magicapi-ai-gateway /usr/local/bin/ +COPY --from=builder /usr/src/app/target/x86_64-unknown-linux-gnu/release/magicapi-ai-gateway /usr/local/bin/ # Set the startup command CMD ["magicapi-ai-gateway"] \ No newline at end of file diff --git a/README.md b/README.md index be9d69e..20985e4 100644 --- a/README.md +++ b/README.md @@ -249,10 +249,15 @@ Special thanks to all our contributors and the Rust community for making this pr 1. Build the Docker image: ```bash -docker build -t magicapi1/magicapi-ai-gateway:latest . +docker buildx build --platform linux/amd64 -t magicapi1/magicapi-ai-gateway:latest . --load ``` -2. Run the container: +2. Push the image to Docker Hub: +```bash +docker push magicapi1/magicapi-ai-gateway:latest +``` + +3. Run the container: ```bash docker run -p 3000:3000 \ -e RUST_LOG=info \ @@ -279,6 +284,7 @@ version: '3.8' services: gateway: build: . + platform: linux/amd64 ports: - "3000:3000" environment: @@ -295,6 +301,7 @@ version: '3.8' services: gateway: image: magicapi1/magicapi-ai-gateway:latest + platform: linux/amd64 ports: - "3000:3000" environment: @@ -327,11 +334,11 @@ git add Cargo.toml CHANGELOG.md git commit -m "chore: release v0.1.6" # Create a git tag -git tag -a v0.1.7 -m "Release v0.1.6" +git tag -a v0.1.7 -m "Release v0.1.7" # Push changes and tag -git push origin release/v0.1.6 -git push origin v0.1.6 +git push origin release/v0.1.7 +git push origin v0.1.7 ``` ### 3. Publishing to crates.io @@ -362,7 +369,7 @@ After publishing, verify: - The new version appears on [crates.io](https://crates.io/crates/magicapi-ai-gateway) - Documentation is updated on [docs.rs](https://docs.rs/magicapi-ai-gateway) - The GitHub release is visible (if using GitHub) -``` + This process follows Rust community best practices for releasing crates. Remember to: - Follow semantic versioning (MAJOR.MINOR.PATCH) @@ -370,4 +377,34 @@ This process follows Rust community best practices for releasing crates. Remembe - Document all significant changes - Keep your repository and crates.io package in sync -Would you like me to explain any part of this process in more detail? \ No newline at end of file +Would you like me to explain any part of this process in more detail? + +## Testing Deployment + +MagicAPI provides a testing deployment of the AI Gateway, hosted in our London data centre. This deployment is intended for testing and evaluation purposes only, and should not be used for production workloads. + +### Testing Gateway URL +``` +https://gateway.magicapi.dev +``` + +### Example Request to Testing Gateway +```bash +curl --location 'https://gateway.magicapi.dev/v1/chat/completions' \ + --header 'Authorization: Bearer YOUR_API_KEY' \ + --header 'Content-Type: application/json' \ + --header 'x-provider: groq' \ + --data '{ + "model": "llama-3.1-8b-instant", + "messages": [ + { + "role": "user", + "content": "Write a poem" + } + ], + "stream": true, + "max_tokens": 300 +}' +``` + +> **Note**: This deployment is provided for testing and evaluation purposes only. For production workloads, please deploy your own instance of the gateway or contact us for information about production-ready managed solutions. \ No newline at end of file diff --git a/k8/ai-gateway.yaml b/k8/ai-gateway.yaml new file mode 100644 index 0000000..b44933e --- /dev/null +++ b/k8/ai-gateway.yaml @@ -0,0 +1,116 @@ +--- +apiVersion: networking.k8s.io/v1 +kind: Ingress +metadata: + name: magicapi-gateway-ingress + namespace: magicapi + annotations: + cert-manager.io/issuer: prod-issuer + cert-manager.io/issuer-kind: OriginIssuer + cert-manager.io/issuer-group: cert-manager.k8s.cloudflare.com + external-dns.alpha.kubernetes.io/hostname: gateway.magicapi.dev + external-dns.alpha.kubernetes.io/cloudflare-proxied: 'true' + nginx.ingress.kubernetes.io/use-proxy-protocol: "true" +spec: + ingressClassName: kong + rules: + - host: gateway.magicapi.dev + http: + paths: + - backend: + service: + name: magicapi-gateway-svc + port: + number: 80 + path: / + pathType: Prefix + tls: + - hosts: + - gateway.magicapi.dev + secretName: magicapi-tls-gateway +--- +apiVersion: v1 +kind: Service +metadata: + name: magicapi-gateway-svc + namespace: magicapi + annotations: + external-dns.alpha.kubernetes.io/hostname: gateway.magicapi.dev +spec: + selector: + app: magicapi-gateway + type: ClusterIP + ports: + - port: 80 + name: http + targetPort: 3000 +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: magicapi-gateway + namespace: magicapi + labels: + app: magicapi-gateway +spec: + replicas: 1 + selector: + matchLabels: + app: magicapi-gateway + strategy: + rollingUpdate: + maxSurge: 10% + maxUnavailable: 0 + type: RollingUpdate + template: + metadata: + labels: + app: magicapi-gateway + spec: + nodeSelector: + kubernetes.io/arch: amd64 + imagePullSecrets: + - name: regcred + containers: + - name: magicapi-gateway + image: magicapi1/magicapi-ai-gateway:latest + env: + - name: RUST_LOG + value: "info" + ports: + - containerPort: 3000 + name: http + imagePullPolicy: Always + resources: + requests: + memory: '512Mi' + cpu: '512m' + limits: + memory: '1024Mi' + cpu: '1024m' + livenessProbe: + httpGet: + path: /health + port: 3000 + initialDelaySeconds: 3 + periodSeconds: 30 +--- +apiVersion: autoscaling/v2 +kind: HorizontalPodAutoscaler +metadata: + name: magicapi-gateway + namespace: magicapi +spec: + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: magicapi-gateway + minReplicas: 1 + maxReplicas: 50 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 60 diff --git a/src/config.rs b/src/config.rs index f83ae75..91e52e4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,21 +1,66 @@ -use dotenv::dotenv; use std::env; +use num_cpus; +use tracing::info; +use tracing::debug; pub struct AppConfig { pub port: u16, pub host: String, + pub worker_threads: usize, + pub max_connections: usize, + pub tcp_keepalive_interval: u64, + pub tcp_nodelay: bool, + pub buffer_size: usize, } impl AppConfig { pub fn new() -> Self { - dotenv().ok(); + info!("Loading environment configuration"); + dotenv::dotenv().ok(); + + // Optimize thread count based on CPU cores + let cpu_count = num_cpus::get(); + debug!("Detected {} CPU cores", cpu_count); + + let default_workers = if cpu_count <= 4 { + cpu_count * 2 + } else { + cpu_count + 4 + }; + debug!("Calculated default worker threads: {}", default_workers); - Self { + let config = Self { port: env::var("PORT") .unwrap_or_else(|_| "3000".to_string()) .parse() .expect("PORT must be a number"), host: env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()), - } + worker_threads: env::var("WORKER_THREADS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(default_workers), + max_connections: env::var("MAX_CONNECTIONS") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(10_000), + tcp_keepalive_interval: env::var("TCP_KEEPALIVE_INTERVAL") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(30), + tcp_nodelay: env::var("TCP_NODELAY") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(true), + buffer_size: env::var("BUFFER_SIZE") + .ok() + .and_then(|v| v.parse().ok()) + .unwrap_or(8 * 1024), // 8KB default + }; + + info!("Configuration loaded: port={}, host={}", config.port, config.host); + debug!("Advanced settings: workers={}, max_conn={}, buffer_size={}", + config.worker_threads, config.max_connections, config.buffer_size); + + config } } diff --git a/src/handlers.rs b/src/handlers.rs index 763c16e..157aff1 100644 --- a/src/handlers.rs +++ b/src/handlers.rs @@ -1,25 +1,24 @@ use crate::{config::AppConfig, proxy::proxy_request_to_provider}; use axum::{ body::Body, - extract::State, + extract::{State, ConnectInfo}, http::{HeaderMap, Request}, response::IntoResponse, Json, }; use serde_json::json; -use std::sync::Arc; -use tracing::{error, info}; +use std::{sync::Arc, net::SocketAddr}; +use tracing::{error, Instrument, debug}; pub async fn health_check() -> impl IntoResponse { - Json(json!({ - "status": "healthy", - "version": env!("CARGO_PKG_VERSION") - })) + debug!("Health check endpoint called"); + Json(json!({ "status": "healthy", "version": env!("CARGO_PKG_VERSION") })) } pub async fn proxy_request( State(config): State>, headers: HeaderMap, + ConnectInfo(addr): ConnectInfo, request: Request, ) -> impl IntoResponse { let provider = headers @@ -27,22 +26,30 @@ pub async fn proxy_request( .and_then(|h| h.to_str().ok()) .unwrap_or("openai"); - info!( + debug!( + "Received request for provider: {}, client: {}, path: {}", + provider, + addr, + request.uri().path() + ); + + let span = tracing::info_span!( + "proxy_request", provider = provider, method = %request.method(), path = %request.uri().path(), - "Incoming proxy request" + client = %addr ); - match proxy_request_to_provider(config, provider, request).await { - Ok(response) => response, - Err(e) => { - error!( - error = %e, - provider = provider, - "Proxy request failed" - ); - e.into_response() + async move { + match proxy_request_to_provider(config, provider, request).await { + Ok(response) => response, + Err(e) => { + error!(error = %e, "Proxy request failed"); + e.into_response() + } } } + .instrument(span) + .await } diff --git a/src/main.rs b/src/main.rs index feb6a99..fbf6f42 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,8 +4,9 @@ use axum::{ }; use std::sync::Arc; use tower_http::cors::{Any, CorsLayer}; -use tracing::{error, info}; +use tracing::{error, info, debug}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use std::time::Duration; mod config; mod error; @@ -17,65 +18,83 @@ use crate::config::AppConfig; #[tokio::main] async fn main() { - // Initialize tracing with more detailed format + // Initialize tracing + info!("Initializing tracing system"); tracing_subscriber::registry() .with(tracing_subscriber::EnvFilter::new( - std::env::var("RUST_LOG") - .unwrap_or_else(|_| "info,tower_http=debug,axum::rejection=trace".into()), + std::env::var("RUST_LOG").unwrap_or_else(|_| "info".into()), )) - .with( - tracing_subscriber::fmt::layer() - .with_file(true) - .with_line_number(true) - .with_thread_ids(true) - .with_thread_names(true), - ) + .with(tracing_subscriber::fmt::layer().compact()) .init(); // Load configuration + info!("Loading application configuration"); let config = Arc::new(AppConfig::new()); + debug!("Configuration loaded: port={}, host={}", config.port, config.host); - info!( - host = %config.port, - port = %config.port, - "Starting server with configuration" - ); + // Optimize tokio runtime + info!("Configuring tokio runtime with {} worker threads", config.worker_threads); + std::env::set_var("TOKIO_WORKER_THREADS", config.worker_threads.to_string()); + std::env::set_var("TOKIO_THREAD_STACK_SIZE", (2 * 1024 * 1024).to_string()); // Setup CORS + debug!("Setting up CORS layer with 1-hour max age"); let cors = CorsLayer::new() .allow_origin(Any) .allow_methods(Any) - .allow_headers(Any); - - info!("CORS configuration: allowing all origins, methods, and headers"); + .allow_headers(Any) + .max_age(Duration::from_secs(3600)); - // Create router + // Create router with optimized settings let app = Router::new() .route("/health", get(handlers::health_check)) .route("/v1/*path", any(handlers::proxy_request)) .with_state(config.clone()) - .layer(cors); + .layer(cors) + .into_make_service_with_connect_info::(); - info!("Router configured with health check and proxy endpoints"); - - // Start server + // Start server with optimized TCP settings let addr = std::net::SocketAddr::from(([0, 0, 0, 0], config.port)); + info!("Setting up TCP listener with non-blocking mode"); + let tcp_listener = std::net::TcpListener::bind(addr).expect("Failed to bind address"); + tcp_listener.set_nonblocking(true).expect("Failed to set non-blocking"); + + debug!("Converting to tokio TCP listener"); + let listener = tokio::net::TcpListener::from_std(tcp_listener) + .expect("Failed to create Tokio TCP listener"); + info!( - address = %addr, - "Starting server" + "AI Gateway listening on {}:{} with {} worker threads", + config.host, config.port, config.worker_threads ); - match tokio::net::TcpListener::bind(addr).await { - Ok(listener) => { - info!("Server successfully bound to address"); - if let Err(e) = axum::serve(listener, app).await { - error!(error = %e, "Server error occurred"); - std::process::exit(1); - } - } - Err(e) => { - error!(error = %e, "Failed to bind server to address"); + axum::serve(listener, app) + .with_graceful_shutdown(shutdown_signal()) + .await + .unwrap_or_else(|e| { + error!("Server error: {}", e); std::process::exit(1); - } + }); +} + +async fn shutdown_signal() { + info!("Registering shutdown signal handler"); + let ctrl_c = async { + tokio::signal::ctrl_c() + .await + .expect("Failed to install CTRL+C signal handler") + }; + + let terminate = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("Failed to install signal handler") + .recv() + .await; + }; + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, } + info!("Shutdown signal received, starting graceful shutdown"); } diff --git a/src/proxy/client.rs b/src/proxy/client.rs index d7931a7..6b226c5 100644 --- a/src/proxy/client.rs +++ b/src/proxy/client.rs @@ -1,12 +1,37 @@ use once_cell::sync::Lazy; use std::time::Duration; +use crate::config::AppConfig; +use tracing::info; +use tracing::debug; + +pub fn create_client(config: &AppConfig) -> reqwest::Client { + info!("Creating HTTP client with optimized settings"); + debug!( + "Client config: max_connections={}, keepalive={}s, nodelay={}", + config.max_connections, + config.tcp_keepalive_interval, + config.tcp_nodelay + ); -pub static CLIENT: Lazy = Lazy::new(|| { reqwest::Client::builder() - .pool_idle_timeout(Duration::from_secs(60)) - .pool_max_idle_per_host(32) - .tcp_keepalive(Duration::from_secs(60)) - .timeout(Duration::from_secs(60)) + .pool_max_idle_per_host(config.max_connections) + .pool_idle_timeout(Duration::from_secs(30)) + .http2_prior_knowledge() + .http2_keep_alive_interval(Duration::from_secs(config.tcp_keepalive_interval)) + .http2_keep_alive_timeout(Duration::from_secs(30)) + .http2_adaptive_window(true) + .tcp_keepalive(Duration::from_secs(config.tcp_keepalive_interval)) + .tcp_nodelay(config.tcp_nodelay) + .use_rustls_tls() + .timeout(Duration::from_secs(30)) + .connect_timeout(Duration::from_secs(10)) + .gzip(true) + .brotli(true) .build() .expect("Failed to create HTTP client") +} + +pub static CLIENT: Lazy = Lazy::new(|| { + let config = AppConfig::new(); + create_client(&config) }); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 9d523b2..749e7a7 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,11 +1,12 @@ use axum::{ - body::{self, Body, Bytes}, + body::{self, Body}, http::{HeaderMap, HeaderValue, Request, Response, StatusCode}, }; use futures_util::StreamExt; use reqwest::Method; use std::sync::Arc; use tracing::{debug, error, info}; +use bytes::BytesMut; use crate::{config::AppConfig, error::AppError, providers::create_provider}; @@ -13,7 +14,7 @@ mod client; pub use client::CLIENT; pub async fn proxy_request_to_provider( - _config: Arc, + config: Arc, provider_name: &str, original_request: Request, ) -> Result, AppError> { @@ -24,10 +25,10 @@ pub async fn proxy_request_to_provider( "Incoming request" ); - // Create provider instance + debug!("Creating provider instance for: {}", provider_name); let provider = create_provider(provider_name)?; - // Call before request hook + debug!("Executing before_request hook"); provider.before_request(&original_request).await?; let path = original_request.uri().path(); @@ -42,20 +43,21 @@ pub async fn proxy_request_to_provider( let url = format!("{}{}{}", provider.base_url(), modified_path, query); debug!( - provider = provider_name, + provider = provider.name(), url = %url, - "Preparing proxy request" + "Preparing proxy request to {} provider", provider.name() ); // Process headers let headers = provider.process_headers(original_request.headers())?; - // Create and send request + // Create and send request with optimized buffer handling let response = send_provider_request( original_request.method().clone(), url, headers, original_request.into_body(), + config.clone(), ) .await?; @@ -68,105 +70,120 @@ pub async fn proxy_request_to_provider( Ok(processed_response) } -// Helper function to send the actual request async fn send_provider_request( method: http::Method, url: String, headers: HeaderMap, body: Body, + config: Arc, ) -> Result, AppError> { + debug!("Preparing to send request: {} {}", method, url); + let body_bytes = body::to_bytes(body, usize::MAX).await?; + debug!("Request body size: {} bytes", body_bytes.len()); let client = &*CLIENT; - let method = - Method::from_bytes(method.as_str().as_bytes()).map_err(|_| AppError::InvalidMethod)?; + let method = Method::from_bytes(method.as_str().as_bytes()) + .map_err(|_| AppError::InvalidMethod)?; - // Convert http::HeaderMap to reqwest::HeaderMap - let mut reqwest_headers = reqwest::header::HeaderMap::new(); + // Pre-allocate headers map with known capacity + let mut reqwest_headers = reqwest::header::HeaderMap::with_capacity(headers.len()); + + // Batch process headers for (name, value) in headers.iter() { - if let Ok(v) = reqwest::header::HeaderValue::from_bytes(value.as_bytes()) { - // Convert the header name to a string first - if let Ok(name_str) = name.as_str().parse::() { - reqwest_headers.insert(name_str, v); - } + if let (Ok(name_str), Ok(v)) = ( + name.as_str().parse::(), + reqwest::header::HeaderValue::from_bytes(value.as_bytes()), + ) { + reqwest_headers.insert(name_str, v); } } let response = client .request(method, url) - .headers(reqwest_headers) // Now using the converted reqwest::HeaderMap + .headers(reqwest_headers) .body(body_bytes.to_vec()) .send() .await?; - process_response(response).await + process_response(response, config).await } -// Add this function after the send_provider_request function - -async fn process_response(response: reqwest::Response) -> Result, AppError> { +async fn process_response( + response: reqwest::Response, + config: Arc, +) -> Result, AppError> { let status = StatusCode::from_u16(response.status().as_u16())?; + debug!("Processing response with status: {}", status); + + // Pre-allocate headers map + let mut response_headers = HeaderMap::with_capacity(response.headers().len()); + + // Batch process headers + for (name, value) in response.headers() { + if let (Ok(header_name), Ok(v)) = ( + http::HeaderName::from_bytes(name.as_ref()), + HeaderValue::from_bytes(value.as_bytes()), + ) { + response_headers.insert(header_name, v); + } + } - // Check if response is a stream - if response - .headers() + // Check for streaming response + if response.headers() .get(reqwest::header::CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .map_or(false, |ct| ct.contains("text/event-stream")) { - debug!("Processing streaming response"); - - // Convert headers - let mut response_headers = HeaderMap::new(); - for (name, value) in response.headers() { - if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) { - if let Ok(header_name) = http::HeaderName::from_bytes(name.as_ref()) { - response_headers.insert(header_name, v); + info!("Processing streaming response"); + debug!("Setting up stream with buffer size: {}", config.buffer_size); + + let stream = response.bytes_stream().map(move |result| { + match result { + Ok(bytes) => { + let mut buffer = BytesMut::with_capacity(config.buffer_size); + buffer.extend_from_slice(&bytes); + Ok(buffer.freeze()) + } + Err(e) => { + error!("Stream error: {}", e); + Err(std::io::Error::new(std::io::ErrorKind::Other, e)) } - } - } - - // Set up streaming response - let stream = response.bytes_stream().map(|result| match result { - Ok(bytes) => { - debug!("Streaming chunk: {} bytes", bytes.len()); - Ok(bytes) - } - Err(e) => { - error!("Stream error: {}", e); - Err(std::io::Error::new(std::io::ErrorKind::Other, e)) } }); - Ok(Response::builder() + let mut response_builder = Response::builder() .status(status) .header("content-type", "text/event-stream") .header("cache-control", "no-cache") - .header("connection", "keep-alive") - .extension(response_headers) - .body(Body::from_stream(stream)) - .unwrap()) - } else { - debug!("Processing regular response"); + .header("connection", "keep-alive"); - // Convert headers - let mut response_headers = HeaderMap::new(); - for (name, value) in response.headers() { - if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) { - if let Ok(header_name) = http::HeaderName::from_bytes(name.as_ref()) { - response_headers.insert(header_name, v); - } + // Add all headers from response_headers + for (key, value) in response_headers { + if let Some(key) = key { + response_builder = response_builder.header(key, value); } } - // Process regular response body + Ok(response_builder + .body(Body::from_stream(stream)) + .unwrap()) + } else { + debug!("Processing regular response"); + let body = response.bytes().await?; - - let mut builder = Response::builder().status(status); - for (name, value) in response_headers.iter() { - builder = builder.header(name, value); + + let mut response_builder = Response::builder().status(status); + + // Add all headers from response_headers + for (key, value) in response_headers { + if let Some(key) = key { + response_builder = response_builder.header(key, value); + } } - Ok(builder.body(Body::from(body)).unwrap()) + Ok(response_builder + .body(Body::from(body)) + .unwrap()) } }