Skip to content

Commit

Permalink
source-http-ingest: add CORS support
Browse files Browse the repository at this point in the history
Supports CORS preflight requests on an opt-in basis. The goal is to allow the
users to capture data directly from browsers. A list of allowed origins was
added to the endpoint config. If left empty (the default), then cors will be
disabled entirely. A non-empty list will enable cors for the specific origins
listed. As a special case, `*` may be used to permit any origin.
  • Loading branch information
psFried committed Nov 27, 2024
1 parent 3b4bafd commit dfeeb53
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 6 deletions.
6 changes: 5 additions & 1 deletion source-http-ingest/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ utoipa-swagger-ui = { version = "7.1", features = ["axum"] }
time = { version = "0.3", features = ["formatting"] }
uuid = { version = "1.7", features = ["v4"] }
lazy_static = "1.4"
tower-http = { version = "0.5", features = ["decompression-full", "trace"] }
tower-http = { version = "0.5", features = [
"decompression-full",
"trace",
"cors",
] }
tower = "0.4"

[dev-dependencies]
Expand Down
29 changes: 29 additions & 0 deletions source-http-ingest/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,15 @@ pub struct EndpointConfig {
#[serde(default)]
#[schemars(default = "paths_schema_default", schema_with = "paths_schema")]
paths: Vec<String>,

/// List of allowed CORS origins. If empty, then CORS will be disabled. Otherwise, each item
/// in the list will be interpreted as a specific request origin that will be permitted by the
/// `Access-Control-Allow-Origin` header for preflight requests coming from that origin. As a special
/// case, the value `*` is permitted in order to allow all origins. The `*` should be used with extreme
/// caution, however. See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin
#[serde(default)]
#[schemars(default, schema_with = "cors_schema")]
allowed_cors_origins: Vec<String>,
}

/// Sets the default value that's used only in the JSON schema. This is _not_ the default that's used
Expand All @@ -63,6 +72,20 @@ fn paths_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema {
.unwrap()
}

fn cors_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema {
// This schema is a little more permissive than would otherwise be ideal.
// We'd like to use something like `oneOf: [{format: hostname}, {const: '*'}]`,
// but the UI does not handle that construct well.
serde_json::from_value(serde_json::json!({
"title": "CORS Allowed Origins",
"type": "array",
"items": {
"type": "string"
}
}))
.unwrap()
}

fn require_auth_token_schema(_gen: &mut gen::SchemaGenerator) -> schema::Schema {
serde_json::from_value(serde_json::json!({
"title": "Authentication token",
Expand Down Expand Up @@ -302,6 +325,11 @@ async fn do_validate(
// Check to make sure we can successfully create an openapi spec
server::openapi_spec(&config, &typed_bindings)
.context("cannot create openapi spec from bindings")?;

// Ensure that cors origins are valid
let _ = server::parse_cors_allowed_origins(&config.allowed_cors_origins)
.context("invalid allowedCorsOrigins value")?;

let response = Response {
validated: Some(Validated { bindings: output }),
..Default::default()
Expand Down Expand Up @@ -415,6 +443,7 @@ mod test {
let config = EndpointConfig {
require_auth_token: None,
paths: vec!["/foo".to_string(), "/bar/baz".to_string()],
allowed_cors_origins: Vec::new(),
};
let result = generate_discover_response(config).unwrap();
insta::assert_json_snapshot!(result);
Expand Down
46 changes: 41 additions & 5 deletions source-http-ingest/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use axum::{
routing, Router,
};
use doc::Annotation;
use http::status::StatusCode;
use http::{header::InvalidHeaderValue, status::StatusCode, HeaderValue};
use json::validator::Validator;
use models::RawValue;
use serde_json::Value;
use tower_http::decompression::RequestDecompressionLayer;
use tower_http::{cors, decompression::RequestDecompressionLayer};
use utoipa::openapi::{self, schema, security, OpenApi, OpenApiBuilder};
use utoipa_swagger_ui::SwaggerUi;

Expand All @@ -22,6 +22,27 @@ use std::{
use tokio::io;
use tokio::sync::Mutex;

pub fn parse_cors_allowed_origins(
cors_allow_origins: &[String],
) -> anyhow::Result<Option<cors::AllowOrigin>> {
if cors_allow_origins.is_empty() {
Ok(None)
} else if cors_allow_origins.iter().any(|o| o.trim() == "*") {
anyhow::ensure!(
cors_allow_origins.len() == 1,
"cannot specify multiple allowed cors origins if using '*' to allow all"
);
Ok(Some(cors::AllowOrigin::any()))
} else {
let list = cors_allow_origins
.iter()
.map(|origin| HeaderValue::from_str(origin.trim()))
.collect::<Result<Vec<HeaderValue>, InvalidHeaderValue>>()
.context("invalid cors allowed origin value")?;
Ok(Some(cors::AllowOrigin::list(list)))
}
}

pub async fn run_server(
endpoint_config: EndpointConfig,
bindings: Vec<Binding>,
Expand All @@ -30,10 +51,12 @@ pub async fn run_server(
) -> anyhow::Result<()> {
let openapi_spec =
openapi_spec(&endpoint_config, &bindings).context("creating openapi spec")?;
let cors_allow_origin = parse_cors_allowed_origins(&endpoint_config.allowed_cors_origins)
.expect("allowedCorsOrigins must be valid");

let handler = Handler::try_new(stdin, stdout, endpoint_config, bindings)?;

let router = Router::new()
let mut router = Router::new()
.merge(SwaggerUi::new("/swagger-ui").url("/api-doc/openapi.json", openapi_spec))
// The root path redirects to the swagger ui, so that a user who clicks a link to just
// the hostname will be redirected to a more useful page.
Expand Down Expand Up @@ -68,8 +91,20 @@ pub async fn run_server(
span.record("handler_time_ms", latency.as_millis());
},
),
)
.with_state(Arc::new(handler));
);

if let Some(allowed_origins) = cors_allow_origin {
let cors = tower_http::cors::CorsLayer::new()
.allow_origin(allowed_origins)
.allow_methods(tower_http::cors::AllowMethods::list([
http::Method::POST,
http::Method::PUT,
]))
.allow_headers(tower_http::cors::AllowHeaders::mirror_request());
router = router.layer(cors);
}

let router = router.with_state(Arc::new(handler));

let address = std::net::SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, listen_on_port()));
let listener = tokio::net::TcpListener::bind(address)
Expand Down Expand Up @@ -618,6 +653,7 @@ mod test {
let endpoint_config = EndpointConfig {
require_auth_token: Some("testToken".to_string()),
paths: Vec::new(),
allowed_cors_origins: Vec::new(),
};
let binding0 = Binding {
collection: serde_json::from_value(serde_json::json!({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ expression: schema
"title": "EndpointConfig",
"type": "object",
"properties": {
"allowedCorsOrigins": {
"title": "CORS Allowed Origins",
"description": "List of allowed CORS origins. If empty, then CORS will be disabled. Otherwise, each item in the list will be interpreted as a specific request origin that will be permitted by the `Access-Control-Allow-Origin` header for preflight requests coming from that origin. As a special case, the value `*` is permitted in order to allow all origins. The `*` should be used with extreme caution, however. See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin",
"default": [],
"type": "array",
"items": {
"type": "string"
}
},
"paths": {
"title": "URL paths",
"description": "List of URL paths to accept requests at.\n\nDiscovery will return a separate collection for each given path. Paths must be provided without any percent encoding, and should not include any query parameters or fragment.",
Expand Down

0 comments on commit dfeeb53

Please sign in to comment.