Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

source-http-ingest: add CORS support #2173

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion source-http-ingest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ utoipa-swagger-ui = { version = "7.1", features = ["axum"] }
time = { version = "0.3", features = ["formatting"] }
uuid = { version = "1.7", features = ["v4"] }
lazy_static = "1.4"
tower-http = { version = "0.5", features = ["decompression-full", "trace"] }
tower-http = { version = "0.5", features = [
"decompression-full",
"trace",
"cors",
] }
tower = "0.4"

[dev-dependencies]
Expand Down
34 changes: 33 additions & 1 deletion source-http-ingest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ pub struct EndpointConfig {
#[serde(default)]
#[schemars(default = "paths_schema_default", schema_with = "paths_schema")]
paths: Vec<String>,

/// List of allowed CORS origins. If empty, then CORS will be disabled. Otherwise, each item
/// in the list will be interpreted as a specific request origin that will be permitted by the
/// `Access-Control-Allow-Origin` header for preflight requests coming from that origin. As a special
/// case, the value `*` is permitted in order to allow all origins. The `*` should be used with extreme
/// caution, however. See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
#[serde(default)]
#[schemars(default, schema_with = "cors_schema")]
allowed_cors_origins: Vec<String>,
}

/// Sets the default value that's used only in the JSON schema. This is _not_ the default that's used
Expand All @@ -58,7 +67,23 @@ fn paths_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema {
"items": {
"type": "string",
"pattern": "/.+",
}
},
"order": 1
}))
.unwrap()
}

fn cors_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema {
// This schema is a little more permissive than would otherwise be ideal.
// We'd like to use something like `oneOf: [{format: hostname}, {const: '*'}]`,
// but the UI does not handle that construct well.
serde_json::from_value(serde_json::json!({
"title": "CORS Allowed Origins",
"type": "array",
"items": {
"type": "string"
},
"order": 3
}))
.unwrap()
}
Expand All @@ -68,6 +93,7 @@ fn require_auth_token_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema
"title": "Authentication token",
"type": ["string", "null"],
"secret": true,
"order": 2
}))
.unwrap()
}
Expand Down Expand Up @@ -302,6 +328,11 @@ async fn do_validate(
// Check to make sure we can successfully create an openapi spec
server::openapi_spec(&config, &typed_bindings)
.context("cannot create openapi spec from bindings")?;

// Ensure that cors origins are valid
let _ = server::parse_cors_allowed_origins(&config.allowed_cors_origins)
.context("invalid allowedCorsOrigins value")?;

let response = Response {
validated: Some(Validated { bindings: output }),
..Default::default()
Expand Down Expand Up @@ -415,6 +446,7 @@ mod test {
let config = EndpointConfig {
require_auth_token: None,
paths: vec!["/foo".to_string(), "/bar/baz".to_string()],
allowed_cors_origins: Vec::new(),
};
let result = generate_discover_response(config).unwrap();
insta::assert_json_snapshot!(result);
Expand Down
46 changes: 41 additions & 5 deletions source-http-ingest/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use axum::{
routing, Router,
};
use doc::Annotation;
use http::status::StatusCode;
use http::{header::InvalidHeaderValue, status::StatusCode, HeaderValue};
use json::validator::Validator;
use models::RawValue;
use serde_json::Value;
use tower_http::decompression::RequestDecompressionLayer;
use tower_http::{cors, decompression::RequestDecompressionLayer};
use utoipa::openapi::{self, schema, security, OpenApi, OpenApiBuilder};
use utoipa_swagger_ui::SwaggerUi;

Expand All @@ -22,6 +22,27 @@ use std::{
use tokio::io;
use tokio::sync::Mutex;

pub fn parse_cors_allowed_origins(
cors_allow_origins: &[String],
) -> anyhow::Result<Option<cors::AllowOrigin>> {
if cors_allow_origins.is_empty() {
Ok(None)
} else if cors_allow_origins.iter().any(|o| o.trim() == "*") {
anyhow::ensure!(
cors_allow_origins.len() == 1,
"cannot specify multiple allowed cors origins if using '*' to allow all"
);
Ok(Some(cors::AllowOrigin::any()))
} else {
let list = cors_allow_origins
.iter()
.map(|origin| HeaderValue::from_str(origin.trim()))
.collect::<Result<Vec<HeaderValue>, InvalidHeaderValue>>()
.context("invalid cors allowed origin value")?;
Ok(Some(cors::AllowOrigin::list(list)))
}
}

pub async fn run_server(
endpoint_config: EndpointConfig,
bindings: Vec<Binding>,
Expand All @@ -30,10 +51,12 @@ pub async fn run_server(
) -> anyhow::Result<()> {
let openapi_spec =
openapi_spec(&endpoint_config, &bindings).context("creating openapi spec")?;
let cors_allow_origin = parse_cors_allowed_origins(&endpoint_config.allowed_cors_origins)
.expect("allowedCorsOrigins must be valid");

let handler = Handler::try_new(stdin, stdout, endpoint_config, bindings)?;

let router = Router::new()
let mut router = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-doc/openapi.json", openapi_spec))
// The root path redirects to the swagger ui, so that a user who clicks a link to just
// the hostname will be redirected to a more useful page.
Expand Down Expand Up @@ -68,8 +91,20 @@ pub async fn run_server(
span.record("handler_time_ms", latency.as_millis());
},
),
)
.with_state(Arc::new(handler));
);

if let Some(allowed_origins) = cors_allow_origin {
let cors = tower_http::cors::CorsLayer::new()
.allow_origin(allowed_origins)
.allow_methods(tower_http::cors::AllowMethods::list([
http::Method::POST,
http::Method::PUT,
]))
.allow_headers(tower_http::cors::AllowHeaders::mirror_request());
router = router.layer(cors);
}

let router = router.with_state(Arc::new(handler));

let address = std::net::SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, listen_on_port()));
let listener = tokio::net::TcpListener::bind(address)
Expand Down Expand Up @@ -618,6 +653,7 @@ mod test {
let endpoint_config = EndpointConfig {
require_auth_token: Some("testToken".to_string()),
paths: Vec::new(),
allowed_cors_origins: Vec::new(),
};
let binding0 = Binding {
collection: serde_json::from_value(serde_json::json!({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ expression: schema
"title": "EndpointConfig",
"type": "object",
"properties": {
"allowedCorsOrigins": {
"title": "CORS Allowed Origins",
"description": "List of allowed CORS origins. If empty, then CORS will be disabled. Otherwise, each item in the list will be interpreted as a specific request origin that will be permitted by the `Access-Control-Allow-Origin` header for preflight requests coming from that origin. As a special case, the value `*` is permitted in order to allow all origins. The `*` should be used with extreme caution, however. See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin",
"default": [],
"type": "array",
"items": {
"type": "string"
},
"order": 3
},
"paths": {
"title": "URL paths",
"description": "List of URL paths to accept requests at.\n\nDiscovery will return a separate collection for each given path. Paths must be provided without any percent encoding, and should not include any query parameters or fragment.",
Expand All @@ -17,7 +27,8 @@ expression: schema
"items": {
"type": "string",
"pattern": "/.+"
}
},
"order": 1
},
"requireAuthToken": {
"title": "Authentication token",
Expand All @@ -27,6 +38,7 @@ expression: schema
"string",
"null"
],
"order": 2,
"secret": true
}
}
Expand Down
Loading