From 839fa434af7a0d44d5fd7b3fc49cf179641147c9 Mon Sep 17 00:00:00 2001 From: Shane Osbourne Date: Sun, 8 Sep 2024 15:54:25 +0100 Subject: [PATCH 1/2] initial impl --- crates/bsnext_core/src/common_layers.rs | 63 ------------ crates/bsnext_core/src/handler_stack.rs | 17 ++-- crates/bsnext_core/src/handlers/proxy.rs | 9 +- crates/bsnext_core/src/lib.rs | 2 +- crates/bsnext_core/src/optional_layers.rs | 95 +++++++++++++++++++ ...handler_stack__test__handler_stack_01.snap | 1 + ...handler_stack__test__handler_stack_02.snap | 2 + ...handler_stack__test__handler_stack_03.snap | 1 + crates/bsnext_core/tests/manual_service.rs | 79 +++++++++++++++ .../manual_service__manual_service_impl.snap | 15 +++ crates/bsnext_input/src/input_test/mod.rs | 73 +++++++++++++- .../bsnext_input__input_test__com_yaml.snap | 36 +++++++ ...ut__input_test__deserialize_3_headers.snap | 1 + ...t_test__deserialize_3_headers_control.snap | 1 + ...test__deserialize_compressions_gzip-2.snap | 9 ++ ...t_test__deserialize_compressions_gzip.snap | 9 ++ crates/bsnext_input/src/route.rs | 27 ++++++ .../bsnext_input__route__com_yaml.snap | 36 +++++++ crates/bsnext_resp/src/inject_opts.rs | 12 ++- ..._kind__start_from_paths__test__test-2.snap | 1 + ...rt_kind__start_from_paths__test__test.snap | 1 + 21 files changed, 414 insertions(+), 76 deletions(-) delete mode 100644 crates/bsnext_core/src/common_layers.rs create mode 100644 crates/bsnext_core/src/optional_layers.rs create mode 100644 crates/bsnext_core/tests/manual_service.rs create mode 100644 crates/bsnext_core/tests/snapshots/manual_service__manual_service_impl.snap create mode 100644 crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__com_yaml.snap create mode 100644 crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip-2.snap create mode 100644 crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip.snap create mode 100644 crates/bsnext_input/src/snapshots/bsnext_input__route__com_yaml.snap diff --git a/crates/bsnext_core/src/common_layers.rs b/crates/bsnext_core/src/common_layers.rs deleted file mode 100644 index a163d08..0000000 --- a/crates/bsnext_core/src/common_layers.rs +++ /dev/null @@ -1,63 +0,0 @@ -use axum::extract::Request; -use axum::middleware::Next; -use axum::routing::MethodRouter; -use axum::{middleware, Extension}; -use bsnext_input::route::{CorsOpts, DelayKind, DelayOpts, Opts}; -use bsnext_resp::{response_modifications_layer, InjectHandling}; -use http::{HeaderName, HeaderValue}; -use std::convert::Infallible; -use std::time::Duration; -use tokio::time::sleep; -use tower_http::cors::CorsLayer; -use tower_http::set_header; - -pub fn add_route_layers(app: MethodRouter, opts: &Opts) -> MethodRouter { - let mut app = app; - - if opts - .cors - .as_ref() - .is_some_and(|v| *v == CorsOpts::Cors(true)) - { - app = app.layer(CorsLayer::permissive()); - } - - if let Some(DelayOpts::Delay(DelayKind::Ms(ms))) = opts.delay.as_ref() { - let ms = *ms; - app = app.layer(middleware::from_fn( - move |req: Request, next: Next| async move { - let res = next.run(req).await; - sleep(Duration::from_millis(ms)).await; - Ok::<_, Infallible>(res) - }, - )); - } - - if let Some(headers) = opts.headers.as_ref() { - for (k, v) in headers { - let hn = HeaderName::from_bytes(k.as_bytes()); - let hv = HeaderValue::from_bytes(v.as_bytes()); - match (hn, hv) { - (Ok(n), Ok(v)) => { - app = app.layer(set_header::SetResponseHeaderLayer::overriding(n, v)); - } - (Ok(_), Err(_e)) => { - tracing::error!("invalid header value `{}`", v) - } - (Err(_e), Ok(_)) => { - tracing::error!("invalid header name `{}`", k) - } - (Err(_e), Err(_e2)) => { - tracing::error!("invalid header name AND value `{}:{}`", k, v) - } - } - } - } - - let injections = opts.inject.injections(); - app = app - .layer::<_, Infallible>(middleware::from_fn(response_modifications_layer)) - .layer(Extension(InjectHandling { items: injections })); - - app -} diff --git a/crates/bsnext_core/src/handler_stack.rs b/crates/bsnext_core/src/handler_stack.rs index bf1f9ef..4742d75 100644 --- a/crates/bsnext_core/src/handler_stack.rs +++ b/crates/bsnext_core/src/handler_stack.rs @@ -1,5 +1,5 @@ -use crate::common_layers::add_route_layers; use crate::handlers::proxy::{proxy_handler, ProxyConfig}; +use crate::optional_layers::optional_layers; use crate::raw_loader::serve_raw_one; use crate::serve_dir::try_many_services_dir; use axum::handler::Handler; @@ -8,6 +8,7 @@ use axum::routing::{any, any_service, get_service, MethodRouter}; use axum::{Extension, Router}; use bsnext_input::route::{DirRoute, FallbackRoute, Opts, ProxyRoute, RawRoute, Route, RouteKind}; use std::collections::HashMap; +use tower::Layer; use tower_http::services::{ServeDir, ServeFile}; #[derive(Debug, PartialEq)] @@ -168,7 +169,7 @@ pub fn fallback_to_layered_method_router(route: FallbackRoute) -> MethodRouter { match route.kind { RouteKind::Raw(raw_route) => { let svc = any_service(serve_raw_one.with_state(raw_route)); - add_route_layers(svc, &route.opts) + optional_layers(svc, &route.opts) } RouteKind::Proxy(_new_proxy_route) => { // todo(alpha): make a decision proxy as a fallback @@ -179,7 +180,7 @@ pub fn fallback_to_layered_method_router(route: FallbackRoute) -> MethodRouter { let item = DirRouteOpts::new(dir, route.opts, None); let serve_dir_service = item.as_serve_file(); let service = get_service(serve_dir_service); - let layered = add_route_layers(service, &item.opts); + let layered = optional_layers(service, &item.opts); layered } } @@ -196,7 +197,9 @@ pub fn stack_to_router(path: &str, stack: HandlerStack) -> Router { HandlerStack::None => unreachable!(), HandlerStack::Raw { raw, opts } => { let svc = any_service(serve_raw_one.with_state(raw)); - Router::new().route_service(path, add_route_layers(svc, &opts)) + let out = optional_layers(svc, &opts); + + Router::new().route_service(path, out) } HandlerStack::Dirs(dirs) => { Router::new().nest_service(path, serve_dir_layer(&dirs, Router::new())) @@ -210,7 +213,7 @@ pub fn stack_to_router(path: &str, stack: HandlerStack) -> Router { let proxy_with_decompression = proxy_handler.layer(Extension(proxy_config.clone())); let as_service = any(proxy_with_decompression); - Router::new().nest_service(path, add_route_layers(as_service, &opts)) + Router::new().nest_service(path, optional_layers(as_service, &opts)) } HandlerStack::DirsProxy(dir_list, proxy) => { let r2 = stack_to_router( @@ -233,7 +236,7 @@ fn serve_dir_layer(dir_list_with_opts: &[DirRouteOpts], initial: Router) -> Rout None => { let serve_dir_service = dir_route.as_serve_dir(); let service = get_service(serve_dir_service); - let layered = add_route_layers(service, &dir_route.opts); + let layered = optional_layers(service, &dir_route.opts); layered } Some(fallback) => { @@ -243,7 +246,7 @@ fn serve_dir_layer(dir_list_with_opts: &[DirRouteOpts], initial: Router) -> Rout .fallback(stack) .call_fallback_on_method_not_allowed(true); let service = any_service(serve_dir_service); - let layered = add_route_layers(service, &dir_route.opts); + let layered = optional_layers(service, &dir_route.opts); layered } }) diff --git a/crates/bsnext_core/src/handlers/proxy.rs b/crates/bsnext_core/src/handlers/proxy.rs index 9d6dbca..74418ae 100644 --- a/crates/bsnext_core/src/handlers/proxy.rs +++ b/crates/bsnext_core/src/handlers/proxy.rs @@ -83,7 +83,14 @@ pub async fn proxy_handler( // decompress requests if needed if let Some(h) = req.extensions().get::() { let req_accepted = h.items.iter().any(|item| item.accept_req(&req)); - tracing::trace!(req.accepted = req_accepted); + tracing::trace!( + req.accepted = req_accepted, + req.accept.header = req + .headers() + .get("accept") + .map(|h| h.to_str().unwrap_or("")), + "will accept request + decompress" + ); if req_accepted { let sv2 = any(serve_raw_one.layer(DecompressionLayer::new())); return Ok(sv2.oneshot(req).await.into_response()); diff --git a/crates/bsnext_core/src/lib.rs b/crates/bsnext_core/src/lib.rs index 0aaa480..e842677 100644 --- a/crates/bsnext_core/src/lib.rs +++ b/crates/bsnext_core/src/lib.rs @@ -1,12 +1,12 @@ pub mod server; pub mod servers_supervisor; -pub mod common_layers; pub mod dir_loader; mod handler_stack; pub mod handlers; pub mod meta; pub mod not_found; +pub mod optional_layers; pub mod panic_handler; pub mod proxy_loader; pub mod raw_loader; diff --git a/crates/bsnext_core/src/optional_layers.rs b/crates/bsnext_core/src/optional_layers.rs new file mode 100644 index 0000000..30e58ca --- /dev/null +++ b/crates/bsnext_core/src/optional_layers.rs @@ -0,0 +1,95 @@ +use axum::extract::{Request, State}; +use axum::handler::Handler; +use axum::middleware::{map_response_with_state, Next}; +use axum::response::{IntoResponse, Response}; +use axum::routing::MethodRouter; +use axum::{middleware, Extension}; +use axum_extra::middleware::option_layer; +use bsnext_input::route::{CorsOpts, DelayKind, DelayOpts, Opts}; +use bsnext_resp::{response_modifications_layer, InjectHandling}; +use http::{HeaderName, HeaderValue}; +use std::collections::BTreeMap; +use std::convert::Infallible; +use std::time::Duration; +use tokio::time::sleep; +use tower::{Layer, ServiceBuilder}; +use tower_http::compression::CompressionLayer; +use tower_http::cors::CorsLayer; + +pub fn optional_layers(app: MethodRouter, opts: &Opts) -> MethodRouter { + let app = app; + let cors_enabled_layer = opts + .cors + .as_ref() + .filter(|v| **v == CorsOpts::Cors(true)) + .map(|_| CorsLayer::permissive()); + + let delay_enabled_layer = opts + .delay + .as_ref() + .map(|delay| middleware::from_fn_with_state(delay.clone(), delay_mw)); + + let injections = opts.inject.as_injections(); + let inject_layer = Some(injections.items.len()) + .filter(|inj| *inj > 0) + .map(|_| middleware::from_fn(response_modifications_layer)); + + let set_response_headers_layer = opts + .headers + .as_ref() + .map(|headers| map_response_with_state(headers.clone(), set_resp_headers)); + + let optional_stack = ServiceBuilder::new() + .layer(CompressionLayer::new()) + .layer(option_layer(inject_layer)) + .layer(option_layer(set_response_headers_layer)) + .layer(option_layer(cors_enabled_layer)) + .layer(option_layer(delay_enabled_layer)); + + app.layer::<_, Infallible>(optional_stack) + .layer(Extension(InjectHandling { + items: injections.items, + })) +} + +async fn delay_mw( + State(delay_opts): State, + req: Request, + next: Next, +) -> impl IntoResponse { + match delay_opts { + DelayOpts::Delay(DelayKind::Ms(ms)) => { + let res = next.run(req).await; + sleep(Duration::from_millis(ms)).await; + Ok::<_, Infallible>(res) + } + } +} + +async fn set_resp_headers( + State(header_map): State>, + mut response: Response, +) -> Response { + let headers = response.headers_mut(); + for (k, v) in header_map { + let hn = HeaderName::from_bytes(k.as_bytes()); + let hv = HeaderValue::from_bytes(v.as_bytes()); + match (hn, hv) { + (Ok(k), Ok(v)) => { + tracing::debug!("did insert header `{}`: `{:?}`", k, v); + headers.insert(k, v); + } + (Ok(n), Err(_e)) => { + tracing::error!("invalid header value: `{}` for name: `{}`", v, n) + } + (Err(_e), Ok(v)) => { + tracing::error!("invalid header name `{}`", k) + } + (Err(_e), Err(_e2)) => { + tracing::error!("invalid header name AND value `{}:{}`", k, v) + } + } + } + + response +} diff --git a/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_01.snap b/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_01.snap index 9373fde..40fc287 100644 --- a/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_01.snap +++ b/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_01.snap @@ -16,5 +16,6 @@ Raw { true, ), headers: None, + compression: None, }, } diff --git a/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_02.snap b/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_02.snap index d057f98..02c7077 100644 --- a/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_02.snap +++ b/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_02.snap @@ -19,6 +19,7 @@ DirsProxy( true, ), headers: None, + compression: None, }, fallback_route: None, }, @@ -37,6 +38,7 @@ DirsProxy( true, ), headers: None, + compression: None, }, fallback_route: None, }, diff --git a/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_03.snap b/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_03.snap index 7cab765..32221e9 100644 --- a/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_03.snap +++ b/crates/bsnext_core/src/snapshots/bsnext_core__handler_stack__test__handler_stack_03.snap @@ -20,5 +20,6 @@ Raw { true, ), headers: None, + compression: None, }, } diff --git a/crates/bsnext_core/tests/manual_service.rs b/crates/bsnext_core/tests/manual_service.rs new file mode 100644 index 0000000..b8fa2ea --- /dev/null +++ b/crates/bsnext_core/tests/manual_service.rs @@ -0,0 +1,79 @@ +use axum::body::Body; +use axum::extract::Request; +use axum::response::Response; +use axum::Router; +use futures_util::future::BoxFuture; +use http::{HeaderName, HeaderValue}; +use insta::assert_debug_snapshot; +use std::collections::BTreeMap; +use std::task::{Context, Poll}; +use tower::layer::layer_fn; +use tower::{Layer, Service, ServiceBuilder, ServiceExt}; + +#[tokio::test] +async fn test_manual_service_impl() -> anyhow::Result<()> { + let app = + Router::new().layer( + ServiceBuilder::new().service(layer_fn(|service| MyMiddleware { + inner: service, + headers: [("a".to_string(), "b".to_string())].into(), + })), + ); + let req = Request::get("/").body(Body::empty())?; + let s = app.oneshot(req).await?; + assert_debug_snapshot!(s); + Ok(()) +} + +#[derive(Clone)] +struct MyLayer { + headers: BTreeMap, +} + +impl Layer for MyLayer { + type Service = MyMiddleware; + + fn layer(&self, inner: S) -> Self::Service { + MyMiddleware { + inner, + headers: self.headers.clone(), + } + } +} + +#[derive(Clone)] +struct MyMiddleware { + inner: S, + headers: BTreeMap, +} + +impl Service for MyMiddleware +where + S: Service + Send + 'static, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, request: Request) -> Self::Future { + let future = self.inner.call(request); + let headers = self.headers.clone(); + Box::pin(async move { + let mut response: Response = future.await?; + let header_map = response.headers_mut(); + for (k, v) in headers { + let hn = HeaderName::from_bytes(k.as_bytes()); + let hv = HeaderValue::from_bytes(v.as_bytes()); + if let (Ok(k), Ok(v)) = (hn, hv) { + header_map.insert(k, v); + } + } + Ok(response) + }) + } +} diff --git a/crates/bsnext_core/tests/snapshots/manual_service__manual_service_impl.snap b/crates/bsnext_core/tests/snapshots/manual_service__manual_service_impl.snap new file mode 100644 index 0000000..1c0b6f1 --- /dev/null +++ b/crates/bsnext_core/tests/snapshots/manual_service__manual_service_impl.snap @@ -0,0 +1,15 @@ +--- +source: crates/bsnext_core/tests/manual_service.rs +expression: s +--- +Response { + status: 404, + version: HTTP/1.1, + headers: { + "content-length": "0", + "a": "b", + }, + body: Body( + UnsyncBoxBody, + ), +} diff --git a/crates/bsnext_input/src/input_test/mod.rs b/crates/bsnext_input/src/input_test/mod.rs index ccb07a3..51fce9f 100644 --- a/crates/bsnext_input/src/input_test/mod.rs +++ b/crates/bsnext_input/src/input_test/mod.rs @@ -1,8 +1,10 @@ use crate::route::{ - CorsOpts, DebounceDuration, DelayKind, DelayOpts, FilterKind, Route, Spec, SpecOpts, Watcher, + CompressionOpts, CorsOpts, DebounceDuration, DelayKind, DelayOpts, FilterKind, Route, Spec, + SpecOpts, Watcher, }; use crate::watch_opts::WatchOpts; use crate::Input; +use insta::assert_debug_snapshot; #[test] fn test_deserialize() { @@ -62,6 +64,75 @@ fn test_deserialize_cors_false() { assert_eq!(opts, CorsOpts::Cors(false)); } +#[test] +fn test_deserialize_compressions_absent() { + #[derive(serde::Deserialize, serde::Serialize, Debug)] + struct Config { + pub items: Vec, + } + + let input = r#" + items: + - path: /hello.js + raw: "hello" + "#; + let c: Config = serde_yaml::from_str(input).unwrap(); + let first = c.items.get(0).unwrap(); + assert_eq!(first.opts.compression, None); +} + +#[test] +fn test_deserialize_compressions_true() { + #[derive(serde::Deserialize, serde::Serialize, Debug)] + struct Config { + pub items: Vec, + } + + let input = r#" + items: + - path: /hello.js + raw: "hello" + compression: true + "#; + let c: Config = serde_yaml::from_str(input).unwrap(); + let first = c.items.get(0).unwrap(); + assert_eq!(first.opts.compression, Some(CompressionOpts::Bool(true))); +} +#[test] +fn test_deserialize_compressions_gzip() { + let input = r#" + - path: /hello.js + raw: "hello" + compression: gzip + - path: /hello2.js + raw: "hello" + compression: br + "#; + let c: Vec = serde_yaml::from_str(input).unwrap(); + assert_debug_snapshot!(c.get(0).unwrap().opts.compression); + assert_debug_snapshot!(c.get(1).unwrap().opts.compression); +} + +#[test] +fn test_com_yaml() -> anyhow::Result<()> { + #[derive(Debug, PartialEq, Hash, Clone, serde::Deserialize, serde::Serialize)] + struct V { + compression: CompressionOpts, + } + + let input = r#" + - compression: true + - compression: false + - compression: br + - compression: gzip + - compression: zstd + - compression: deflate + "#; + let v: Vec = serde_yaml::from_str(input)?; + assert_debug_snapshot!(v); + Ok(()) +} + #[test] fn test_deserialize_3_headers_control() { #[derive(serde::Deserialize, serde::Serialize, Debug)] diff --git a/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__com_yaml.snap b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__com_yaml.snap new file mode 100644 index 0000000..fc4fdf8 --- /dev/null +++ b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__com_yaml.snap @@ -0,0 +1,36 @@ +--- +source: crates/bsnext_input/src/input_test/mod.rs +expression: v +--- +[ + V { + compression: Bool( + true, + ), + }, + V { + compression: Bool( + false, + ), + }, + V { + compression: CompType( + Br, + ), + }, + V { + compression: CompType( + Gzip, + ), + }, + V { + compression: CompType( + Zstd, + ), + }, + V { + compression: CompType( + Deflate, + ), + }, +] diff --git a/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers.snap b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers.snap index 5310630..1491ddc 100644 --- a/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers.snap +++ b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers.snap @@ -30,6 +30,7 @@ Config { "a": "b", }, ), + compression: None, }, fallback: None, }, diff --git a/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers_control.snap b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers_control.snap index 2206687..745d5f8 100644 --- a/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers_control.snap +++ b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_3_headers_control.snap @@ -26,6 +26,7 @@ Config { true, ), headers: None, + compression: None, }, fallback: None, }, diff --git a/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip-2.snap b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip-2.snap new file mode 100644 index 0000000..614dd57 --- /dev/null +++ b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip-2.snap @@ -0,0 +1,9 @@ +--- +source: crates/bsnext_input/src/input_test/mod.rs +expression: c.get(1).unwrap().opts.compression +--- +Some( + CompType( + Br, + ), +) diff --git a/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip.snap b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip.snap new file mode 100644 index 0000000..e3a228e --- /dev/null +++ b/crates/bsnext_input/src/input_test/snapshots/bsnext_input__input_test__deserialize_compressions_gzip.snap @@ -0,0 +1,9 @@ +--- +source: crates/bsnext_input/src/input_test/mod.rs +expression: c.get(0).unwrap().opts.compression +--- +Some( + CompType( + Gzip, + ), +) diff --git a/crates/bsnext_input/src/route.rs b/crates/bsnext_input/src/route.rs index 2a4a452..5237ae9 100644 --- a/crates/bsnext_input/src/route.rs +++ b/crates/bsnext_input/src/route.rs @@ -35,6 +35,7 @@ pub struct Opts { #[serde(default)] pub inject: InjectOpts, pub headers: Option>, + pub compression: Option, } impl Default for Route { @@ -48,6 +49,7 @@ impl Default for Route { delay: None, watch: Default::default(), inject: Default::default(), + compression: Default::default(), }, fallback: Default::default(), } @@ -143,6 +145,31 @@ pub enum CorsOpts { Cors(bool), } +#[derive(Debug, PartialEq, Hash, Clone, serde::Deserialize, serde::Serialize)] +#[serde(untagged)] +pub enum CompressionOpts { + Bool(bool), + CompType(CompType), +} + +#[derive(Debug, PartialEq, Hash, Clone, serde::Deserialize, serde::Serialize)] +pub enum CompType { + #[serde(rename = "gzip")] + Gzip, + #[serde(rename = "br")] + Br, + #[serde(rename = "deflate")] + Deflate, + #[serde(rename = "zstd")] + Zstd, +} + +impl Default for CompressionOpts { + fn default() -> Self { + Self::Bool(false) + } +} + #[derive(Debug, PartialEq, Hash, Clone, serde::Deserialize, serde::Serialize)] pub enum DelayOpts { #[serde(rename = "delay")] diff --git a/crates/bsnext_input/src/snapshots/bsnext_input__route__com_yaml.snap b/crates/bsnext_input/src/snapshots/bsnext_input__route__com_yaml.snap new file mode 100644 index 0000000..5101f34 --- /dev/null +++ b/crates/bsnext_input/src/snapshots/bsnext_input__route__com_yaml.snap @@ -0,0 +1,36 @@ +--- +source: crates/bsnext_input/src/route.rs +expression: v +--- +[ + V { + compression: Bool( + true, + ), + }, + V { + compression: Bool( + false, + ), + }, + V { + compression: CompType( + Br, + ), + }, + V { + compression: CompType( + Gzip, + ), + }, + V { + compression: CompType( + Zstd, + ), + }, + V { + compression: CompType( + Deflate, + ), + }, +] diff --git a/crates/bsnext_resp/src/inject_opts.rs b/crates/bsnext_resp/src/inject_opts.rs index aa6c1c8..ddb324e 100644 --- a/crates/bsnext_resp/src/inject_opts.rs +++ b/crates/bsnext_resp/src/inject_opts.rs @@ -14,9 +14,14 @@ pub enum InjectOpts { Items(Vec), } +#[derive(Debug, PartialEq)] +pub struct Injections { + pub items: Vec, +} + impl InjectOpts { - pub fn injections(&self) -> Vec { - match self { + pub fn as_injections(&self) -> Injections { + let items = match self { InjectOpts::Bool(true) => { vec![InjectionItem { inner: Injection::BsLive(BuiltinStringDef { @@ -32,7 +37,8 @@ impl InjectOpts { // todo: is this too expensive? InjectOpts::Items(items) => items.to_owned(), InjectOpts::Item(item) => vec![item.to_owned()], - } + }; + Injections { items } } } diff --git a/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test-2.snap b/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test-2.snap index f3f1784..5659050 100644 --- a/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test-2.snap +++ b/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test-2.snap @@ -11,5 +11,6 @@ servers: watch: true inject: true headers: ~ + compression: ~ fallback: ~ watchers: [] diff --git a/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test.snap b/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test.snap index 7e57bd8..f1c3cfa 100644 --- a/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test.snap +++ b/crates/bsnext_system/src/start_kind/snapshots/bsnext_system__start_kind__start_from_paths__test__test.snap @@ -27,6 +27,7 @@ Input { true, ), headers: None, + compression: None, }, fallback: None, }, From cef4281138c3b2f1de93622f303485e29325b87b Mon Sep 17 00:00:00 2001 From: Shane Osbourne Date: Sun, 8 Sep 2024 16:34:51 +0100 Subject: [PATCH 2/2] added playwright tests for compression layer --- crates/bsnext_core/src/optional_layers.rs | 60 ++++++++++++++++++++--- examples/react-router/bslive.yaml | 13 +++++ tests/examples.spec.ts | 20 ++++++++ tests/utils.ts | 14 +++++- 4 files changed, 99 insertions(+), 8 deletions(-) diff --git a/crates/bsnext_core/src/optional_layers.rs b/crates/bsnext_core/src/optional_layers.rs index 30e58ca..9799e34 100644 --- a/crates/bsnext_core/src/optional_layers.rs +++ b/crates/bsnext_core/src/optional_layers.rs @@ -5,7 +5,7 @@ use axum::response::{IntoResponse, Response}; use axum::routing::MethodRouter; use axum::{middleware, Extension}; use axum_extra::middleware::option_layer; -use bsnext_input::route::{CorsOpts, DelayKind, DelayOpts, Opts}; +use bsnext_input::route::{CompType, CompressionOpts, CorsOpts, DelayKind, DelayOpts, Opts}; use bsnext_resp::{response_modifications_layer, InjectHandling}; use http::{HeaderName, HeaderValue}; use std::collections::BTreeMap; @@ -17,13 +17,15 @@ use tower_http::compression::CompressionLayer; use tower_http::cors::CorsLayer; pub fn optional_layers(app: MethodRouter, opts: &Opts) -> MethodRouter { - let app = app; + let mut app = app; let cors_enabled_layer = opts .cors .as_ref() .filter(|v| **v == CorsOpts::Cors(true)) .map(|_| CorsLayer::permissive()); + let compression_layer = opts.compression.as_ref().and_then(comp_opts_to_layer); + let delay_enabled_layer = opts .delay .as_ref() @@ -40,16 +42,23 @@ pub fn optional_layers(app: MethodRouter, opts: &Opts) -> MethodRouter { .map(|headers| map_response_with_state(headers.clone(), set_resp_headers)); let optional_stack = ServiceBuilder::new() - .layer(CompressionLayer::new()) .layer(option_layer(inject_layer)) .layer(option_layer(set_response_headers_layer)) .layer(option_layer(cors_enabled_layer)) .layer(option_layer(delay_enabled_layer)); - app.layer::<_, Infallible>(optional_stack) - .layer(Extension(InjectHandling { - items: injections.items, - })) + app = app.layer(optional_stack); + + // The compression layer has a different type, so needs to apply outside the optional stack + // this essentially wrapping everything. + // I'm sure there's a cleaner way... + if let Some(cl) = compression_layer { + app = app.layer(cl); + } + + app.layer(Extension(InjectHandling { + items: injections.items, + })) } async fn delay_mw( @@ -93,3 +102,40 @@ async fn set_resp_headers( response } + +fn comp_opts_to_layer(comp: &CompressionOpts) -> Option { + match comp { + CompressionOpts::Bool(false) => None, + CompressionOpts::Bool(true) => Some(CompressionLayer::new()), + CompressionOpts::CompType(comp_type) => match comp_type { + CompType::Gzip => Some( + CompressionLayer::new() + .gzip(true) + .no_br() + .no_deflate() + .no_zstd(), + ), + CompType::Br => Some( + CompressionLayer::new() + .br(true) + .no_gzip() + .no_deflate() + .no_zstd(), + ), + CompType::Deflate => Some( + CompressionLayer::new() + .deflate(true) + .no_gzip() + .no_br() + .no_zstd(), + ), + CompType::Zstd => Some( + CompressionLayer::new() + .zstd(true) + .no_gzip() + .no_deflate() + .no_br(), + ), + }, + } +} diff --git a/examples/react-router/bslive.yaml b/examples/react-router/bslive.yaml index 0dca9ab..b8d5204 100644 --- a/examples/react-router/bslive.yaml +++ b/examples/react-router/bslive.yaml @@ -8,6 +8,19 @@ servers: dir: examples/react-router/dist/index.html # and this route just shows an example of a route + delay for testing + - path: /abc + json: [ 1, 2, 3 ] + delay: + ms: 1000 + + ## This server is just like the one above, but it adds `compression: true` + - name: 'react-router-with-compression' + routes: + - path: / + dir: examples/react-router/dist + compression: true + fallback: + dir: examples/react-router/dist/index.html - path: /abc json: [ 1, 2, 3 ] delay: diff --git a/tests/examples.spec.ts b/tests/examples.spec.ts index 8685bb3..29932b8 100644 --- a/tests/examples.spec.ts +++ b/tests/examples.spec.ts @@ -123,4 +123,24 @@ test.describe('examples/react-router/bslive.yaml', { await page.goto(bs.path('/'), {waitUntil: 'networkidle'}) await expect(page.locator('#root')).toContainText('API response from /abc[1,2,3]'); }); + test('supports compressed responses', async ({page, bs}) => { + const load = page.goto(bs.named('react-router-with-compression', '/'), {waitUntil: 'networkidle'}) + const requestPromise = page.waitForResponse((req) => { + const url = new URL(req.url()); + return url.pathname.includes('assets/index') + && url.pathname.endsWith('.js') + }, {timeout: 2000}); + const [_, jsfile] = await Promise.all([load, requestPromise]); + expect(jsfile?.headers()['content-encoding']).toBe('zstd') + }); + test('does not compress by default', async ({page, bs}) => { + const load = page.goto(bs.named('react-router', '/'), {waitUntil: 'networkidle'}) + const requestPromise = page.waitForResponse((req) => { + const url = new URL(req.url()); + return url.pathname.includes('assets/index') + && url.pathname.endsWith('.js') + }, {timeout: 2000}); + const [_, jsfile] = await Promise.all([load, requestPromise]); + expect(jsfile?.headers()['content-encoding']).toBeUndefined() + }); }) diff --git a/tests/utils.ts b/tests/utils.ts index 8aab176..4fde22a 100644 --- a/tests/utils.ts +++ b/tests/utils.ts @@ -44,6 +44,7 @@ export const test = base.extend<{ servers: { url: string }[], child: any; path: (path: string) => string; + named: (name: string, path: string) => string; stdout: string[]; touch: (path: string) => void; // next: (args: NextArgs) => Promise; @@ -117,7 +118,7 @@ export const test = base.extend<{ }) const data = await servers_changed_msg; const servers = data.servers.map(s => { - return {url: 'http://' + s.socket_addr} + return {url: 'http://' + s.socket_addr, identity: s.identity} }); await use({ @@ -130,6 +131,17 @@ export const test = base.extend<{ const url = new URL(path, servers[0].url); return url.toString() }, + named(server_name: string, path: string) { + const server = servers.find(x => { + if (x.identity.kind === "Named") { + return x.identity.payload.name === server_name + } + return false + }); + if (!server) throw new Error('server not found with name: ' + server_name); + const url = new URL(path, server.url); + return url.toString() + }, stdout, touch: (path: string) => { touchFile(join(cwd, path));