diff --git a/source-http-ingest/Cargo.toml b/source-http-ingest/Cargo.toml index 52d7772994..84aabd53eb 100644 --- a/source-http-ingest/Cargo.toml +++ b/source-http-ingest/Cargo.toml @@ -34,7 +34,11 @@ utoipa-swagger-ui = { version = "7.1", features = ["axum"] } time = { version = "0.3", features = ["formatting"] } uuid = { version = "1.7", features = ["v4"] } lazy_static = "1.4" -tower-http = { version = "0.5", features = ["decompression-full", "trace"] } +tower-http = { version = "0.5", features = [ + "decompression-full", + "trace", + "cors", +] } tower = "0.4" [dev-dependencies] diff --git a/source-http-ingest/src/lib.rs b/source-http-ingest/src/lib.rs index 8fa27bfc28..ce5eef2a2a 100644 --- a/source-http-ingest/src/lib.rs +++ b/source-http-ingest/src/lib.rs @@ -42,6 +42,15 @@ pub struct EndpointConfig { #[serde(default)] #[schemars(default = "paths_schema_default", schema_with = "paths_schema")] paths: Vec, + + /// List of allowed CORS origins. If empty, then CORS will be disabled. Otherwise, each item + /// in the list will be interpreted as a specific request origin that will be permitted by the + /// `Access-Control-Allow-Origin` header for preflight requests coming from that origin. As a special + /// case, the value `*` is permitted in order to allow all origins. The `*` should be used with extreme + /// caution, however. See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin + #[serde(default)] + #[schemars(default, schema_with = "cors_schema")] + allowed_cors_origins: Vec, } /// Sets the default value that's used only in the JSON schema. This is _not_ the default that's used @@ -58,7 +67,23 @@ fn paths_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema { "items": { "type": "string", "pattern": "/.+", - } + }, + "order": 1 + })) + .unwrap() +} + +fn cors_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema { + // This schema is a little more permissive than would otherwise be ideal. + // We'd like to use something like `oneOf: [{format: hostname}, {const: '*'}]`, + // but the UI does not handle that construct well. + serde_json::from_value(serde_json::json!({ + "title": "CORS Allowed Origins", + "type": "array", + "items": { + "type": "string" + }, + "order": 3 })) .unwrap() } @@ -68,6 +93,7 @@ fn require_auth_token_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema "title": "Authentication token", "type": ["string", "null"], "secret": true, + "order": 2 })) .unwrap() } @@ -302,6 +328,11 @@ async fn do_validate( // Check to make sure we can successfully create an openapi spec server::openapi_spec(&config, &typed_bindings) .context("cannot create openapi spec from bindings")?; + + // Ensure that cors origins are valid + let _ = server::parse_cors_allowed_origins(&config.allowed_cors_origins) + .context("invalid allowedCorsOrigins value")?; + let response = Response { validated: Some(Validated { bindings: output }), ..Default::default() @@ -415,6 +446,7 @@ mod test { let config = EndpointConfig { require_auth_token: None, paths: vec!["/foo".to_string(), "/bar/baz".to_string()], + allowed_cors_origins: Vec::new(), }; let result = generate_discover_response(config).unwrap(); insta::assert_json_snapshot!(result); diff --git a/source-http-ingest/src/server.rs b/source-http-ingest/src/server.rs index ed61b7233d..db9436d643 100644 --- a/source-http-ingest/src/server.rs +++ b/source-http-ingest/src/server.rs @@ -7,11 +7,11 @@ use axum::{ routing, Router, }; use doc::Annotation; -use http::status::StatusCode; +use http::{header::InvalidHeaderValue, status::StatusCode, HeaderValue}; use json::validator::Validator; use models::RawValue; use serde_json::Value; -use tower_http::decompression::RequestDecompressionLayer; +use tower_http::{cors, decompression::RequestDecompressionLayer}; use utoipa::openapi::{self, schema, security, OpenApi, OpenApiBuilder}; use utoipa_swagger_ui::SwaggerUi; @@ -22,6 +22,27 @@ use std::{ use tokio::io; use tokio::sync::Mutex; +pub fn parse_cors_allowed_origins( + cors_allow_origins: &[String], +) -> anyhow::Result> { + if cors_allow_origins.is_empty() { + Ok(None) + } else if cors_allow_origins.iter().any(|o| o.trim() == "*") { + anyhow::ensure!( + cors_allow_origins.len() == 1, + "cannot specify multiple allowed cors origins if using '*' to allow all" + ); + Ok(Some(cors::AllowOrigin::any())) + } else { + let list = cors_allow_origins + .iter() + .map(|origin| HeaderValue::from_str(origin.trim())) + .collect::, InvalidHeaderValue>>() + .context("invalid cors allowed origin value")?; + Ok(Some(cors::AllowOrigin::list(list))) + } +} + pub async fn run_server( endpoint_config: EndpointConfig, bindings: Vec, @@ -30,10 +51,12 @@ pub async fn run_server( ) -> anyhow::Result<()> { let openapi_spec = openapi_spec(&endpoint_config, &bindings).context("creating openapi spec")?; + let cors_allow_origin = parse_cors_allowed_origins(&endpoint_config.allowed_cors_origins) + .expect("allowedCorsOrigins must be valid"); let handler = Handler::try_new(stdin, stdout, endpoint_config, bindings)?; - let router = Router::new() + let mut router = Router::new() .merge(SwaggerUi::new("/swagger-ui").url("/api-doc/openapi.json", openapi_spec)) // The root path redirects to the swagger ui, so that a user who clicks a link to just // the hostname will be redirected to a more useful page. @@ -68,8 +91,20 @@ pub async fn run_server( span.record("handler_time_ms", latency.as_millis()); }, ), - ) - .with_state(Arc::new(handler)); + ); + + if let Some(allowed_origins) = cors_allow_origin { + let cors = tower_http::cors::CorsLayer::new() + .allow_origin(allowed_origins) + .allow_methods(tower_http::cors::AllowMethods::list([ + http::Method::POST, + http::Method::PUT, + ])) + .allow_headers(tower_http::cors::AllowHeaders::mirror_request()); + router = router.layer(cors); + } + + let router = router.with_state(Arc::new(handler)); let address = std::net::SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, listen_on_port())); let listener = tokio::net::TcpListener::bind(address) @@ -618,6 +653,7 @@ mod test { let endpoint_config = EndpointConfig { require_auth_token: Some("testToken".to_string()), paths: Vec::new(), + allowed_cors_origins: Vec::new(), }; let binding0 = Binding { collection: serde_json::from_value(serde_json::json!({ diff --git a/source-http-ingest/src/snapshots/source_http_ingest__test__endpoint_config_schema.snap b/source-http-ingest/src/snapshots/source_http_ingest__test__endpoint_config_schema.snap index 26f1da82ba..2306b1be8f 100644 --- a/source-http-ingest/src/snapshots/source_http_ingest__test__endpoint_config_schema.snap +++ b/source-http-ingest/src/snapshots/source_http_ingest__test__endpoint_config_schema.snap @@ -7,6 +7,16 @@ expression: schema "title": "EndpointConfig", "type": "object", "properties": { + "allowedCorsOrigins": { + "title": "CORS Allowed Origins", + "description": "List of allowed CORS origins. If empty, then CORS will be disabled. Otherwise, each item in the list will be interpreted as a specific request origin that will be permitted by the `Access-Control-Allow-Origin` header for preflight requests coming from that origin. As a special case, the value `*` is permitted in order to allow all origins. The `*` should be used with extreme caution, however. See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin", + "default": [], + "type": "array", + "items": { + "type": "string" + }, + "order": 3 + }, "paths": { "title": "URL paths", "description": "List of URL paths to accept requests at.\n\nDiscovery will return a separate collection for each given path. Paths must be provided without any percent encoding, and should not include any query parameters or fragment.", @@ -17,7 +27,8 @@ expression: schema "items": { "type": "string", "pattern": "/.+" - } + }, + "order": 1 }, "requireAuthToken": { "title": "Authentication token", @@ -27,6 +38,7 @@ expression: schema "string", "null" ], + "order": 2, "secret": true } }