From 5a7c46fd3e5d88c9cb54e4323722adda9adac6c8 Mon Sep 17 00:00:00 2001 From: pmendes Date: Wed, 18 Sep 2024 18:22:48 +0100 Subject: [PATCH] Move authentication logic out of the daphne crate --- Cargo.lock | 1 + Cargo.toml | 3 +- Makefile | 7 +- crates/dapf/src/acceptance/mod.rs | 3 +- crates/daphne-server/Cargo.toml | 3 +- .../examples/configuration-helper.toml | 7 +- .../examples/configuration-leader.toml | 8 +- crates/daphne-server/src/lib.rs | 81 +++- crates/daphne-server/src/roles/aggregator.rs | 115 +---- crates/daphne-server/src/roles/helper.rs | 3 +- crates/daphne-server/src/roles/leader.rs | 139 +++--- crates/daphne-server/src/roles/mod.rs | 101 ++++- crates/daphne-server/src/router/aggregator.rs | 18 +- crates/daphne-server/src/router/helper.rs | 3 +- crates/daphne-server/src/router/leader.rs | 15 +- crates/daphne-server/src/router/mod.rs | 406 ++++++++++++++---- .../src/storage_proxy_connection/kv/mod.rs | 52 ++- .../src/storage_proxy_connection/mod.rs | 2 +- crates/daphne-server/tests/e2e/e2e.rs | 29 +- crates/daphne-server/tests/e2e/test_runner.rs | 25 +- crates/daphne-service-utils/src/auth.rs | 124 ------ .../daphne-service-utils/src/bearer_token.rs | 57 +++ crates/daphne-service-utils/src/config.rs | 23 +- crates/daphne-service-utils/src/lib.rs | 3 +- crates/daphne/src/auth.rs | 170 -------- crates/daphne/src/error/aborts.rs | 2 +- crates/daphne/src/lib.rs | 11 +- crates/daphne/src/roles/aggregator.rs | 34 +- crates/daphne/src/roles/helper.rs | 37 +- crates/daphne/src/roles/leader/mod.rs | 53 +-- crates/daphne/src/roles/mod.rs | 303 +++---------- crates/daphne/src/taskprov.rs | 14 +- crates/daphne/src/testing/mod.rs | 106 +---- 33 files changed, 867 insertions(+), 1091 deletions(-) delete mode 100644 crates/daphne-service-utils/src/auth.rs create mode 100644 crates/daphne-service-utils/src/bearer_token.rs delete mode 100644 crates/daphne/src/auth.rs diff --git a/Cargo.lock b/Cargo.lock index 9f61bad11..e5be718b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -955,6 +955,7 @@ dependencies = [ "daphne", "daphne-service-utils", "dhat", + "either", "futures", "hex", "hpke-rs", diff --git a/Cargo.toml b/Cargo.toml index 0c715bb2e..64b446ebf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,6 +46,7 @@ constcat = "0.5.0" criterion = { version = "0.5.1", features = ["async_tokio"] } deepsize = { version = "0.2.0" } dhat = "0.3.3" +either = "1.13.0" futures = "0.3.30" getrandom = "0.2.15" headers = "0.4" @@ -86,9 +87,9 @@ tracing = "0.1.40" tracing-core = "0.1.32" tracing-subscriber = "0.3.18" url = { version = "2.5.2", features = ["serde"] } +wasm-streams = "0.4" webpki = "0.22.4" worker = { version = "0.3.3", features = ["http"] } -wasm-streams = "0.4" x509-parser = "0.15.1" [workspace.dependencies.sentry] diff --git a/Makefile b/Makefile index 8516dfc2f..8688938d7 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,12 @@ s: storage-proxy e2e: /tmp/private-key /tmp/certificate export HPKE_SIGNING_KEY="$$(cat /tmp/private-key)"; \ export E2E_TEST_HPKE_SIGNING_CERTIFICATE="$$(cat /tmp/certificate)"; \ - docker compose -f ./crates/daphne-server/docker-compose-e2e.yaml up --build --abort-on-container-exit --exit-code-from test + docker compose -f ./crates/daphne-server/docker-compose-e2e.yaml up \ + --no-attach leader_storage \ + --no-attach helper_storage \ + --build \ + --abort-on-container-exit \ + --exit-code-from test build_interop: docker build . 
-f ./interop/Dockerfile.interop_helper --tag daphne-interop diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index d2852b1c1..a2ecd0ad9 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -24,7 +24,6 @@ use crate::{ use anyhow::{anyhow, bail, Context, Result}; use async_trait::async_trait; use daphne::{ - auth::BearerToken, constants::DapMediaType, error::aborts::ProblemDetails, hpke::{HpkeConfig, HpkeKemId, HpkeReceiverConfig}, @@ -40,7 +39,7 @@ use daphne::{ DapQueryConfig, DapTaskConfig, DapTaskParameters, DapVersion, EarlyReportStateConsumed, EarlyReportStateInitialized, ReplayProtection, }; -use daphne_service_utils::http_headers; +use daphne_service_utils::{bearer_token::BearerToken, http_headers}; use futures::{future::OptionFuture, StreamExt, TryStreamExt}; use prio::codec::{Decode, ParameterizedEncode}; use prometheus::{Encoder, HistogramVec, IntCounterVec, IntGaugeVec, TextEncoder}; diff --git a/crates/daphne-server/Cargo.toml b/crates/daphne-server/Cargo.toml index efec36477..5df2e41a0 100644 --- a/crates/daphne-server/Cargo.toml +++ b/crates/daphne-server/Cargo.toml @@ -16,6 +16,7 @@ description = "Workers backend for Daphne" axum = "0.6.0" # held back to use http 0.2 daphne = { path = "../daphne" } daphne-service-utils = { path = "../daphne-service-utils", features = ["durable_requests"] } +either.workspace = true futures.workspace = true hex.workspace = true http = "0.2" # held back to use http 0.2 @@ -43,7 +44,7 @@ assert_matches.workspace = true clap.workspace = true config.workspace = true daphne = { path = "../daphne", features = ["test-utils"] } -daphne-service-utils = { path = "../daphne-service-utils", features = ["prometheus"] } +daphne-service-utils = { path = "../daphne-service-utils", features = ["prometheus", "test-utils"] } dhat.workspace = true hpke-rs.workspace = true paste.workspace = true diff --git a/crates/daphne-server/examples/configuration-helper.toml b/crates/daphne-server/examples/configuration-helper.toml index 56673c4ec..7385737eb 100644 --- a/crates/daphne-server/examples/configuration-helper.toml +++ b/crates/daphne-server/examples/configuration-helper.toml @@ -22,13 +22,8 @@ allow_taskprov = true default_num_agg_span_shards = 4 [service.taskprov] +peer_auth.leader.expected_token = "I-am-the-leader" # SECRET vdaf_verify_key_init = "b029a72fa327931a5cb643dcadcaafa098fcbfac07d990cb9e7c9a8675fafb18" # SECRET -leader_auth = """{ - "bearer_token": "I-am-the-leader" -}""" # SECRET -collector_auth = """{ - "bearer_token": "I-am-the-collector" -}""" # SECRET hpke_collector_config = """{ "id": 23, "kem_id": "p256_hkdf_sha256", diff --git a/crates/daphne-server/examples/configuration-leader.toml b/crates/daphne-server/examples/configuration-leader.toml index b9b525761..469bd3515 100644 --- a/crates/daphne-server/examples/configuration-leader.toml +++ b/crates/daphne-server/examples/configuration-leader.toml @@ -22,13 +22,9 @@ allow_taskprov = true default_num_agg_span_shards = 4 [service.taskprov] +peer_auth.collector.expected_token = "I-am-the-collector" # SECRET +self_bearer_token = "I-am-the-leader" # SECRET vdaf_verify_key_init = "b029a72fa327931a5cb643dcadcaafa098fcbfac07d990cb9e7c9a8675fafb18" # SECRET -leader_auth = """{ - "bearer_token": "I-am-the-leader" -}""" # SECRET -collector_auth = """{ - "bearer_token": "I-am-the-collector" -}""" # SECRET hpke_collector_config = """{ "id": 23, "kem_id": "p256_hkdf_sha256", diff --git a/crates/daphne-server/src/lib.rs 
b/crates/daphne-server/src/lib.rs
index d992c9655..bce49cf95 100644
--- a/crates/daphne-server/src/lib.rs
+++ b/crates/daphne-server/src/lib.rs
@@ -5,12 +5,19 @@
 use std::sync::Arc;
 
 use daphne::{
     audit_log::{AuditLog, NoopAuditLog},
-    auth::BearerToken,
-    roles::leader::in_memory_leader::InMemoryLeaderState,
-    DapError,
+    fatal_error,
+    messages::{Base64Encode, TaskId},
+    roles::{leader::in_memory_leader::InMemoryLeaderState, DapAggregator},
+    DapError, DapSender,
 };
-use daphne_service_utils::{config::DaphneServiceConfig, metrics::DaphneServiceMetrics};
+use daphne_service_utils::{
+    bearer_token::BearerToken,
+    config::{DaphneServiceConfig, PeerBearerToken},
+    metrics::DaphneServiceMetrics,
+};
+use either::Either::{self, Left, Right};
 use futures::lock::Mutex;
+use roles::BearerTokens;
 use serde::{Deserialize, Serialize};
 use storage_proxy_connection::{kv, Do, Kv};
 use tokio::sync::RwLock;
@@ -38,7 +45,11 @@ mod storage_proxy_connection;
 /// use url::Url;
 /// use daphne::{DapGlobalConfig, hpke::HpkeKemId, DapVersion};
 /// use daphne_server::{App, router, StorageProxyConfig};
-/// use daphne_service_utils::{config::DaphneServiceConfig, DapRole, metrics::DaphnePromServiceMetrics};
+/// use daphne_service_utils::{
+///     config::{DaphneServiceConfig, PeerAuth},
+///     DapRole,
+///     metrics::DaphnePromServiceMetrics
+/// };
 ///
 /// let storage_proxy_settings = StorageProxyConfig {
 ///     url: Url::parse("http://example.com").unwrap(),
@@ -51,7 +62,7 @@ mod storage_proxy_connection;
 ///     min_batch_interval_start: 259_200,
 ///     max_batch_interval_end: 259_200,
 ///     supported_hpke_kems: vec![HpkeKemId::X25519HkdfSha256],
-///     allow_taskprov: true,
+///     allow_taskprov: false,
 ///     default_num_agg_span_shards: NonZeroUsize::new(2).unwrap(),
 /// };
 /// let service_config = DaphneServiceConfig {
@@ -92,6 +103,7 @@ pub struct StorageProxyConfig {
     pub auth_token: BearerToken,
 }
 
+#[axum::async_trait]
 impl router::DaphneService for App {
     fn server_metrics(&self) -> &dyn DaphneServiceMetrics {
         &*self.metrics
@@ -100,6 +112,59 @@ impl router::DaphneService for App {
     fn signing_key(&self) -> Option<&p256::ecdsa::SigningKey> {
         self.service_config.signing_key.as_ref()
     }
+
+    async fn check_bearer_token(
+        &self,
+        presented_token: &BearerToken,
+        sender: DapSender,
+        task_id: TaskId,
+        is_taskprov: bool,
+    ) -> Result<(), Either<String, DapError>> {
+        let reject = |extra_args| {
+            Err(Left(format!(
+                "the indicated bearer token is incorrect for the {sender:?} {extra_args}",
+            )))
+        };
+        if let Some(taskprov) = self
+            .service_config
+            .taskprov
+            .as_ref()
+            // we only use taskprov auth if it's allowed by config and if the request is using taskprov
+            .filter(|_| self.service_config.global.allow_taskprov && is_taskprov)
+        {
+            match (&taskprov.peer_auth, sender) {
+                (PeerBearerToken::Leader { expected_token }, DapSender::Leader)
+                | (PeerBearerToken::Collector { expected_token }, DapSender::Collector)
+                    if expected_token == presented_token =>
+                {
+                    Ok(())
+                }
+                _ => reject(format_args!("using taskprov")),
+            }
+        } else if self
+            .bearer_tokens()
+            .matches(sender, task_id, presented_token)
+            .await
+            .map_err(|e| {
+                Right(fatal_error!(
+                    err = ?e,
+                    "internal error occurred while running authentication"
+                ))
+            })?
+        {
+            Ok(())
+        } else {
+            reject(format_args!("with task_id {}", task_id.to_base64url()))
+        }
+    }
+
+    async fn is_taskprov(&self, req: &daphne::DapRequest) -> Result<bool, DapError> {
+        Ok(req.taskprov.is_some()
+            || self
+                .get_task_config_for(&req.task_id.unwrap())
+                .await?
+                .is_some_and(|task_config| task_config.method_is_taskprov()))
+    }
 }
 
 impl App {
@@ -137,4 +202,8 @@ impl App {
     pub(crate) fn kv(&self) -> Kv<'_> {
         Kv::new(&self.storage_proxy_config, &self.http, &self.cache)
     }
+
+    pub(crate) fn bearer_tokens(&self) -> BearerTokens<'_> {
+        BearerTokens::from(Kv::new(&self.storage_proxy_config, &self.http, &self.cache))
+    }
 }
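The check_bearer_token implementation above separates user-facing rejections (Either::Left) from internal faults (Either::Right). A minimal consumption sketch, not part of the patch: svc, token and task_id are hypothetical values, and the import paths assume the modules introduced by this change.

use daphne::{messages::TaskId, DapSender};
use daphne_server::router::DaphneService;
use daphne_service_utils::bearer_token::BearerToken;
use either::Either::{Left, Right};

// Map the Either-based result onto HTTP-style outcomes, the same split the
// authenticated extractor later in this patch performs.
async fn authorize(
    svc: &(impl DaphneService + Sync),
    token: &BearerToken,
    task_id: TaskId,
) -> u16 {
    match svc
        .check_bearer_token(token, DapSender::Leader, task_id, false)
        .await
    {
        Ok(()) => 200,                // token accepted
        Err(Left(_reason)) => 401,    // the reason is safe to show the client
        Err(Right(_internal)) => 500, // internal error: log it, don't leak it
    }
}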
diff --git a/crates/daphne-server/src/roles/aggregator.rs b/crates/daphne-server/src/roles/aggregator.rs
index f748dee58..daee9dbd1 100644
--- a/crates/daphne-server/src/roles/aggregator.rs
+++ b/crates/daphne-server/src/roles/aggregator.rs
@@ -1,12 +1,11 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use std::{borrow::Cow, future::ready, num::NonZeroUsize, ops::Range, time::SystemTime};
+use std::{future::ready, num::NonZeroUsize, ops::Range, time::SystemTime};
 
 use axum::async_trait;
 use daphne::{
     audit_log::AuditLog,
-    auth::{BearerToken, BearerTokenProvider},
     error::DapAbort,
     fatal_error,
     hpke::{HpkeConfig, HpkeDecrypter, HpkeProvider},
@@ -16,11 +15,8 @@ use daphne::{
     taskprov, DapAggregateShare, DapAggregateSpan, DapAggregationParam, DapError, DapGlobalConfig,
     DapRequest, DapTaskConfig, DapVersion, EarlyReportStateConsumed, EarlyReportStateInitialized,
 };
-use daphne_service_utils::{
-    auth::DaphneAuth,
-    durable_requests::bindings::{
-        self, AggregateStoreMergeOptions, AggregateStoreMergeReq, AggregateStoreMergeResp,
-    },
+use daphne_service_utils::durable_requests::bindings::{
+    self, AggregateStoreMergeOptions, AggregateStoreMergeReq, AggregateStoreMergeResp,
 };
 use futures::{future::try_join_all, StreamExt, TryStreamExt};
 use mappable_rc::Marc;
@@ -32,7 +28,7 @@ use crate::{
 };
 
 #[async_trait]
-impl DapAggregator<DaphneAuth> for crate::App {
+impl DapAggregator for crate::App {
     #[tracing::instrument(skip(self, task_config, agg_share_span))]
     async fn try_put_agg_share_span(
         &self,
@@ -151,54 +147,6 @@ impl DapAggregator<DaphneAuth> for crate::App {
     where
         Self: 'a;
 
-    async fn unauthorized_reason(
-        &self,
-        task_config: &DapTaskConfig,
-        req: &DapRequest<DaphneAuth>,
-    ) -> Result<Option<String>, DapError> {
-        let mut authorized = false;
-
-        let Some(ref sender_auth) = req.sender_auth else {
-            return Ok(Some("Missing authorization.".into()));
-        };
-
-        // If a bearer token is present, verify that it can be used to authorize the request.
-        if sender_auth.bearer_token.is_some() {
-            if let Some(unauthorized_reason) =
-                self.bearer_token_authorized(task_config, req).await?
-            {
-                return Ok(Some(unauthorized_reason));
-            }
-            authorized = true;
-        }
-
-        // If a TLS client certificate is present verify that it is valid.
-        if let Some(ref cf_tls_client_auth) = sender_auth.cf_tls_client_auth {
-            // TODO(cjpatton) Add support for TLS client authentication for non-Taskprov tasks.
-            let Some(ref _taskprov_config) = self.service_config.taskprov else {
-                return Ok(Some(
-                    "TLS client authentication is currently only supported with Taskprov.".into(),
-                ));
-            };
-
-            // Check that that the certificate is valid. This is indicated by literal "SUCCESS".
-            if cf_tls_client_auth.verified != "SUCCESS" {
-                return Ok(Some(format!(
-                    "Invalid TLS certificate ({}).",
-                    cf_tls_client_auth.verified
-                )));
-            }
-
-            authorized = true;
-        }
-
-        if authorized {
-            Ok(None)
-        } else {
-            Ok(Some("No suitable authorization method was found.".into()))
-        }
-    }
-
     async fn get_global_config(&self) -> Result<DapGlobalConfig, DapError> {
         let mut global_config = self.service_config.global.clone();
 
@@ -278,7 +226,7 @@ impl DapAggregator<DaphneAuth> for crate::App {
 
     async fn taskprov_put(
         &self,
-        req: &DapRequest<DaphneAuth>,
+        req: &DapRequest,
         task_config: DapTaskConfig,
     ) -> Result<(), DapError> {
         let task_id = req.task_id().map_err(DapError::Abort)?;
@@ -524,56 +472,3 @@ impl HpkeDecrypter for crate::App {
             .ok_or(DapError::Transition(TransitionFailure::HpkeUnknownConfigId))?
     }
 }
-
-#[async_trait]
-impl BearerTokenProvider for crate::App {
-    type WrappedBearerToken<'a> = Cow<'a, BearerToken>
-        where Self: 'a;
-
-    async fn get_leader_bearer_token_for<'s>(
-        &'s self,
-        task_id: &'s TaskId,
-        task_config: &DapTaskConfig,
-    ) -> std::result::Result<Option<Self::WrappedBearerToken<'s>>, DapError> {
-        if self.service_config.global.allow_taskprov && task_config.method_is_taskprov() {
-            if let Some(bearer_token) = self
-                .service_config
-                .taskprov
-                .as_ref()
-                .and_then(|c| c.leader_auth.bearer_token.as_ref())
-            {
-                return Ok(Some(Cow::Borrowed(bearer_token)));
-            }
-        }
-
-        self.kv()
-            .get_cloned::<kv::prefix::LeaderBearerToken>(task_id, &KvGetOptions::default())
-            .await
-            .map_err(|e| fatal_error!(err = ?e, "failed to get the leader bearer token"))
-            .map(|r| r.map(Cow::Owned))
-    }
-
-    async fn get_collector_bearer_token_for<'s>(
-        &'s self,
-        task_id: &'s TaskId,
-        task_config: &DapTaskConfig,
-    ) -> std::result::Result<Option<Self::WrappedBearerToken<'s>>, DapError> {
-        if self.service_config.global.allow_taskprov && task_config.method_is_taskprov() {
-            if let Some(bearer_token) = self.service_config.taskprov.as_ref().and_then(|c| {
-                c.collector_auth
-                    .as_ref()
-                    .expect("collector auth method not set")
-                    .bearer_token
-                    .as_ref()
-            }) {
-                return Ok(Some(Cow::Borrowed(bearer_token)));
-            }
-        }
-
-        self.kv()
-            .get_cloned::<kv::prefix::CollectorBearerToken>(task_id, &KvGetOptions::default())
-            .await
-            .map_err(|e| fatal_error!(err = ?e, "failed to get the collector bearer token"))
-            .map(|r| r.map(Cow::Owned))
-    }
-}
diff --git a/crates/daphne-server/src/roles/helper.rs b/crates/daphne-server/src/roles/helper.rs
index 5d07a168c..f312c7441 100644
--- a/crates/daphne-server/src/roles/helper.rs
+++ b/crates/daphne-server/src/roles/helper.rs
@@ -3,7 +3,6 @@
 
 use axum::async_trait;
 use daphne::roles::DapHelper;
-use daphne_service_utils::auth::DaphneAuth;
 
 #[async_trait]
-impl DapHelper<DaphneAuth> for crate::App {}
+impl DapHelper for crate::App {}
diff --git a/crates/daphne-server/src/roles/leader.rs b/crates/daphne-server/src/roles/leader.rs
index 66b58b983..0f6497d50 100644
--- a/crates/daphne-server/src/roles/leader.rs
+++ b/crates/daphne-server/src/roles/leader.rs
@@ -3,46 +3,24 @@
 
 #![allow(unused_variables)]
 
-use std::time::Instant;
+use std::{borrow::Cow, time::Instant};
 
 use axum::{async_trait, http::Method};
 use daphne::{
-    auth::BearerTokenProvider,
     constants::DapMediaType,
     error::DapAbort,
     fatal_error,
     messages::{BatchId, BatchSelector, Collection, CollectionJobId, Report, TaskId},
-    roles::{leader::WorkItem, DapAggregator, DapAuthorizedSender, DapLeader},
-    DapAggregationParam, DapCollectionJob, DapError, DapRequest, DapResponse, DapTaskConfig,
+    roles::{leader::WorkItem, DapAggregator, DapLeader},
+    DapAggregationParam, DapCollectionJob, DapError, DapRequest, DapResponse,
 };
-use daphne_service_utils::{auth::DaphneAuth, http_headers};
+use daphne_service_utils::http_headers;
+use http::StatusCode;
 use tracing::{error, info};
 use url::Url;
 
 #[async_trait]
-impl DapAuthorizedSender<DaphneAuth> for crate::App {
-    async fn authorize(
-        &self,
-        task_id: &TaskId,
-        task_config: &DapTaskConfig,
-        media_type: &DapMediaType,
-        _payload: &[u8],
-    ) -> Result<DaphneAuth, DapError> {
-        Ok(DaphneAuth {
-            bearer_token: Some(
-                self.authorize_with_bearer_token(task_id, task_config, media_type)
-                    .await?
-                    .into_owned(),
-            ),
-            // TODO Consider adding support for authorizing the request with TLS client
-            // certificates: https://developers.cloudflare.com/workers/runtime-apis/mtls/
-            cf_tls_client_auth: None,
-        })
-    }
-}
-
-#[async_trait]
-impl DapLeader<DaphneAuth> for crate::App {
+impl DapLeader for crate::App {
     async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> {
         let task_config = self
             .get_task_config_for(task_id)
@@ -121,19 +99,11 @@ impl DapLeader<DaphneAuth> for crate::App {
         self.test_leader_state.lock().await.enqueue_work(items)
     }
 
-    async fn send_http_post(
-        &self,
-        req: DapRequest<DaphneAuth>,
-        url: Url,
-    ) -> Result<DapResponse, DapError> {
+    async fn send_http_post(&self, req: DapRequest, url: Url) -> Result<DapResponse, DapError> {
         self.send_http(req, Method::POST, url).await
     }
 
-    async fn send_http_put(
-        &self,
-        req: DapRequest<DaphneAuth>,
-        url: Url,
-    ) -> Result<DapResponse, DapError> {
+    async fn send_http_put(&self, req: DapRequest, url: Url) -> Result<DapResponse, DapError> {
         self.send_http(req, Method::PUT, url).await
     }
 }
@@ -141,11 +111,16 @@ impl DapLeader<DaphneAuth> for crate::App {
 impl crate::App {
     async fn send_http(
         &self,
-        req: DapRequest<DaphneAuth>,
+        req: DapRequest,
         method: Method,
         url: Url,
     ) -> Result<DapResponse, DapError> {
         use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
+        let Some(task_id) = req.task_id else {
+            return Err(fatal_error!(
+                err = "Cannot authorize request with missing task ID"
+            ));
+        };
 
         let content_type = req
             .media_type
@@ -165,18 +140,48 @@ impl crate::App {
                 .map_err(|e| fatal_error!(err = ?e, "failed to construct content-type header"))?,
         );
 
-        if let Some(bearer_token) = req.sender_auth.and_then(|auth| auth.bearer_token) {
-            headers.insert(
-                HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
-                HeaderValue::from_str(bearer_token.as_ref()).map_err(|e| {
-                    fatal_error!(
-                        err = ?e,
-                        "failed to construct {} header",
-                        http_headers::DAP_AUTH_TOKEN
-                    )
-                })?,
-            );
-        }
+        // TODO: it is possible for a taskprov request to not contain this, for example, when
+        // sending a collect request to a helper, this parameter is not necessary since the task
+        // must have already been configured in a previous request.
+        //
+        // Therefore a better way to handle taskprov auth here must be designed.
+        let bearer_token = if req.taskprov.is_some() {
+            if let Some(bearer_token) = self
+                .service_config
+                .taskprov
+                .as_ref()
+                .and_then(|t| t.self_bearer_token.as_ref())
+            {
+                Cow::Borrowed(bearer_token)
+            } else {
+                return Err(DapError::Abort(DapAbort::UnauthorizedRequest {
+                    detail: format!(
+                        "taskprov authentication not setup for authentication with peer at {url}",
+                    ),
+                    task_id,
+                }));
+            }
+        } else if let Some(bearer_token) = self
+            .bearer_tokens()
+            .get(daphne::DapSender::Leader, task_id)
+            .await
+            .map_err(|e| fatal_error!(err = ?e, "failed to get leader bearer token"))?
+        {
+            Cow::Owned(bearer_token)
+        } else {
+            return Err(DapError::Abort(DapAbort::UnauthorizedRequest {
+                detail: format!(
+                    "no suitable authentication method found for authenticating with peer at {url}",
+                ),
+                task_id,
+            }));
+        };
+
+        headers.insert(
+            HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
+            HeaderValue::from_str(bearer_token.as_str())
+                .map_err(|e| fatal_error!(err = ?e, "failed to construct authentication header"))?,
+        );
 
         if let Some(taskprov_advertisement) = req.taskprov.as_deref() {
             headers.insert(
@@ -224,19 +229,29 @@ impl crate::App {
             })
         } else {
             error!("{}: request failed: {:?}", url, reqwest_resp);
-            if status == 400 {
-                if let Some(content_type) =
-                    reqwest_resp.headers().get(reqwest::header::CONTENT_TYPE)
-                {
-                    if content_type == "application/problem+json" {
-                        error!(
-                            "Problem details: {}",
-                            reqwest_resp.text().await.map_err(
-                                |e| fatal_error!(err = ?e, "failed to read body of helper error response")
-                            )?
-                        );
+            match status {
+                StatusCode::BAD_REQUEST => {
+                    if let Some(content_type) =
+                        reqwest_resp.headers().get(reqwest::header::CONTENT_TYPE)
+                    {
+                        if content_type == "application/problem+json" {
+                            error!(
+                                "Problem details: {}",
+                                reqwest_resp.text().await.map_err(
+                                    |e| fatal_error!(err = ?e, "failed to read body of helper error response")
+                                )?
+                            );
+                        }
+                    }
+                }
+                StatusCode::UNAUTHORIZED => {
+                    return Err(DapAbort::UnauthorizedRequest {
+                        detail: format!("helper at {url} didn't authorize our request"),
+                        task_id,
                     }
+                    .into())
                 }
+                _ => {}
             }
             Err(fatal_error!(err = "request aborted by peer"))
         }
diff --git a/crates/daphne-server/src/roles/mod.rs b/crates/daphne-server/src/roles/mod.rs
index 0d460be22..b0566d038 100644
--- a/crates/daphne-server/src/roles/mod.rs
+++ b/crates/daphne-server/src/roles/mod.rs
@@ -1,9 +1,13 @@
 // Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
-use daphne::ReplayProtection;
+use daphne::{messages::TaskId, DapSender, ReplayProtection};
+use daphne_service_utils::bearer_token::BearerToken;
 
-use crate::storage_proxy_connection::kv::{self, Kv, KvGetOptions};
+use crate::storage_proxy_connection::{
+    self,
+    kv::{self, Kv, KvGetOptions},
+};
 
 mod aggregator;
 mod helper;
@@ -32,18 +36,85 @@ pub async fn fetch_replay_protection_override(kv: Kv<'_>) -> ReplayProtection {
     }
 }
 
+/// Bearer token for tasks configured manually or via the [ppm-dap-interop-test][interop] draft.
+///
+/// [interop]: https://divergentdave.github.io/draft-dcook-ppm-dap-interop-test-design/draft-dcook-ppm-dap-interop-test-design.html
+pub(crate) struct BearerTokens<'s> {
+    kv: kv::Kv<'s>,
+}
+
+impl<'s> From<kv::Kv<'s>> for BearerTokens<'s> {
+    fn from(kv: kv::Kv<'s>) -> Self {
+        Self { kv }
+    }
+}
+
+impl BearerTokens<'_> {
+    #[cfg(feature = "test-utils")]
+    pub async fn put_if_not_exists(
+        &self,
+        role: DapSender,
+        task_id: TaskId,
+        token: BearerToken,
+    ) -> Result<Option<BearerToken>, storage_proxy_connection::Error> {
+        self.kv
+            .put_if_not_exists::<kv::prefix::KvBearerToken>(&(role, task_id).into(), token)
+            .await
+    }
+
+    /// Checks if a presented token matches the expected token of a task.
+    ///
+    /// # Returns
+    ///
+    /// - `Ok(true)` if the task exists and the token matches
+    /// - `Ok(false)` if the task doesn't exist or the token doesn't match
+    /// - `Err(error)` if any I/O error occurs while fetching
+    pub async fn matches(
+        &self,
+        role: DapSender,
+        task_id: TaskId,
+        token: &BearerToken,
+    ) -> Result<bool, storage_proxy_connection::Error> {
+        self.kv
+            .peek::<kv::prefix::KvBearerToken, _, _>(
+                &(role, task_id).into(),
+                &kv::KvGetOptions {
+                    cache_not_found: false,
+                },
+                |stored_token| stored_token == token,
+            )
+            .await
+            .map(|s| s.is_some_and(|matches| matches))
+    }
+
+    pub async fn get(
+        &self,
+        role: DapSender,
+        task_id: TaskId,
+    ) -> Result<Option<BearerToken>, storage_proxy_connection::Error> {
+        self.kv
+            .get_cloned::<kv::prefix::KvBearerToken>(
+                &(role, task_id).into(),
+                &kv::KvGetOptions {
+                    cache_not_found: false,
+                },
+            )
+            .await
+    }
+}
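A sketch of the intended round trip through this KV-backed store, not part of the patch: app and the token values are made up, and put_if_not_exists assumes the test-utils feature is enabled.

// Store a Leader token for a task, then verify lookups against it.
async fn example(
    app: &crate::App,
    task_id: daphne::messages::TaskId,
) -> Result<(), crate::storage_proxy_connection::Error> {
    let tokens = app.bearer_tokens();
    // A repeated write for the same (sender, task) pair leaves the stored
    // token in place rather than overwriting it.
    let _ = tokens
        .put_if_not_exists(DapSender::Leader, task_id, BearerToken::from("secret"))
        .await?;
    assert!(tokens
        .matches(DapSender::Leader, task_id, &BearerToken::from("secret"))
        .await?);
    assert!(!tokens
        .matches(DapSender::Leader, task_id, &BearerToken::from("wrong"))
        .await?);
    Ok(())
}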
 
 #[cfg(feature = "test-utils")]
 mod test_utils {
     use daphne::{
-        auth::BearerToken,
         fatal_error,
         hpke::{HpkeConfig, HpkeReceiverConfig},
         messages::decode_base64url_vec,
         roles::DapAggregator,
         vdaf::{Prio3Config, VdafConfig},
-        DapError, DapQueryConfig, DapTaskConfig, DapVersion,
+        DapError, DapQueryConfig, DapSender, DapTaskConfig, DapVersion,
     };
     use daphne_service_utils::{
+        bearer_token::BearerToken,
         test_route_types::{InternalTestAddTask, InternalTestEndpointForTask},
         DapRole,
     };
@@ -61,7 +132,7 @@ mod test_utils {
 
         self.http
             .delete(self.storage_proxy_config.url.join(PURGE_STORAGE).unwrap())
-            .bearer_auth(&self.storage_proxy_config.auth_token)
+            .bearer_auth(self.storage_proxy_config.auth_token.as_str())
             .send()
             .await
             .map_err(
@@ -77,7 +148,7 @@ mod test_utils {
         use daphne_service_utils::durable_requests::STORAGE_READY;
         self.http
             .get(self.storage_proxy_config.url.join(STORAGE_READY).unwrap())
-            .bearer_auth(&self.storage_proxy_config.auth_token)
+            .bearer_auth(self.storage_proxy_config.auth_token.as_str())
             .send()
             .await
             .map_err(|e| fatal_error!(err = ?e, "failed to send ready check request to storage proxy"))?
@@ -155,30 +226,26 @@ mod test_utils {
             // Leader authentication token.
             let token = BearerToken::from(cmd.leader_authentication_token);
             if self
-                .kv()
-                .put_if_not_exists::<kv::prefix::LeaderBearerToken>(&cmd.task_id, token)
+                .bearer_tokens()
+                .put_if_not_exists(DapSender::Leader, cmd.task_id, token)
                 .await
-                .map_err(|e| fatal_error!(err = ?e, "failed to fetch leader bearer token"))?
-                .is_some()
+                .is_err()
             {
                 return Err(fatal_error!(
                     err = "command failed: token already exists for the given task and bearer role (leader)",
                     task_id = %cmd.task_id,
                 ));
-            }
+            };
 
             // Collector authentication token.
             match (cmd.role, cmd.collector_authentication_token) {
                 (DapRole::Leader, Some(token_string)) => {
                     let token = BearerToken::from(token_string);
                     if self
-                        .kv()
-                        .put_if_not_exists::<kv::prefix::CollectorBearerToken>(&cmd.task_id, token)
+                        .bearer_tokens()
+                        .put_if_not_exists(DapSender::Collector, cmd.task_id, token)
                         .await
-                        .map_err(
-                            |e| fatal_error!(err = ?e, "failed to put collector bearer token"),
-                        )?
-                        .is_some()
+                        .is_err()
                     {
                         return Err(fatal_error!(err = format!(
                             "command failed: token already exists for the given task ({}) and bearer role (collector)",
diff --git a/crates/daphne-server/src/router/aggregator.rs b/crates/daphne-server/src/router/aggregator.rs
index 9b66d15d3..4cd33bf26 100644
--- a/crates/daphne-server/src/router/aggregator.rs
+++ b/crates/daphne-server/src/router/aggregator.rs
@@ -5,7 +5,7 @@ use std::sync::Arc;
 
 use axum::{
     body::HttpBody,
-    extract::{Query, State},
+    extract::{Path, Query, State},
     response::{AppendHeaders, IntoResponse},
     routing::get,
 };
@@ -14,17 +14,17 @@ use daphne::{
     fatal_error,
     messages::{encode_base64url, TaskId},
     roles::{aggregator, DapAggregator},
-    DapError, DapResponse,
+    DapError, DapResponse, DapVersion,
 };
-use daphne_service_utils::{auth::DaphneAuth, http_headers};
+use daphne_service_utils::http_headers;
 use p256::ecdsa::{signature::Signer, Signature, SigningKey};
 use serde::Deserialize;
 
-use super::{AxumDapResponse, DapRequestExtractor, DaphneService};
+use super::{AxumDapResponse, DaphneService};
 
 pub fn add_aggregator_routes<A, B>(router: super::Router<A, B>) -> super::Router<A, B>
 where
-    A: DapAggregator<DaphneAuth> + DaphneService + Send + Sync + 'static,
+    A: DapAggregator + DaphneService + Send + Sync + 'static,
     B: Send + HttpBody + 'static,
     B::Data: Send,
     B::Error: Send + Sync,
@@ -38,16 +38,16 @@ struct QueryTaskId {
     task_id: Option<TaskId>,
 }
 
-#[tracing::instrument(skip(app, req), fields(version = ?req.version))]
+#[tracing::instrument(skip(app), fields(version, task_id))]
 async fn hpke_config<A>(
     State(app): State<Arc<A>>,
     Query(QueryTaskId { task_id }): Query<QueryTaskId>,
-    DapRequestExtractor(req): DapRequestExtractor,
+    Path(version): Path<DapVersion>,
 ) -> impl IntoResponse
 where
-    A: DapAggregator<DaphneAuth> + DaphneService,
+    A: DapAggregator + DaphneService,
 {
-    match aggregator::handle_hpke_config_req(&*app, &req, task_id).await {
+    match aggregator::handle_hpke_config_req(&*app, version, task_id).await {
         Ok(resp) => match app.signing_key().map(|k| sign_dap_response(k, &resp)) {
             None => AxumDapResponse::new_success(resp, app.server_metrics()).into_response(),
             Some(Ok(signed)) => (
diff --git a/crates/daphne-server/src/router/helper.rs b/crates/daphne-server/src/router/helper.rs
index 087c5370a..b672ed754 100644
--- a/crates/daphne-server/src/router/helper.rs
+++ b/crates/daphne-server/src/router/helper.rs
@@ -13,7 +13,6 @@ use daphne::{
     error::DapAbort,
     roles::{helper, DapHelper},
 };
-use daphne_service_utils::auth::DaphneAuth;
 use http::StatusCode;
 
 use crate::{roles::fetch_replay_protection_override, App};
@@ -80,7 +79,7 @@ async fn agg_share<A, B>(
     DapRequestExtractor(req): DapRequestExtractor,
 ) -> AxumDapResponse
 where
-    A: DapHelper<DaphneAuth> + DaphneService + Send + Sync,
+    A: DapHelper + DaphneService + Send + Sync,
 {
     AxumDapResponse::from_result(
         helper::handle_agg_share_req(&*app, &req).await,
diff --git a/crates/daphne-server/src/router/leader.rs b/crates/daphne-server/src/router/leader.rs
index 95208ef42..05e101017 100644
--- a/crates/daphne-server/src/router/leader.rs
+++ b/crates/daphne-server/src/router/leader.rs
@@ -16,14 +16,15 @@ use daphne::{
     roles::leader::{self, DapLeader},
     DapError, DapVersion,
 };
-use daphne_service_utils::auth::DaphneAuth;
 use prio::codec::ParameterizedEncode;
 
-use super::{AxumDapResponse, DapRequestExtractor, DaphneService};
+use super::{
+    AxumDapResponse, DapRequestExtractor, DaphneService, UnauthenticatedDapRequestExtractor,
+};
 
 pub(super) fn add_leader_routes<A, B>(router: super::Router<A, B>) -> super::Router<A, B>
 where
-    A: DapLeader<DaphneAuth> + DaphneService + Send + Sync + 'static,
+    A: DapLeader + DaphneService + Send + Sync + 'static,
     B: Send + HttpBody + 'static,
     B::Data: Send,
     B::Error: Send + Sync,
@@ -49,10 +50,10 @@ where
 )]
 async fn upload<A>(
     State(app): State<Arc<A>>,
-    DapRequestExtractor(req): DapRequestExtractor,
+    UnauthenticatedDapRequestExtractor(req): UnauthenticatedDapRequestExtractor,
 ) -> Response
 where
-    A: DapLeader<DaphneAuth> + DaphneService + Send + Sync,
+    A: DapLeader + DaphneService + Send + Sync,
 {
     match leader::handle_upload_req(&*app, &req).await {
         Ok(()) => StatusCode::OK.into_response(),
@@ -72,7 +73,7 @@ async fn get_collect_uri<A>(
     DapRequestExtractor(req): DapRequestExtractor,
 ) -> Response
 where
-    A: DapLeader<DaphneAuth> + DaphneService + Send + Sync,
+    A: DapLeader + DaphneService + Send + Sync,
 {
     match (leader::handle_coll_job_req(&*app, &req).await, req.version) {
         (Ok(collect_uri), DapVersion::Draft09 | DapVersion::Latest) => {
@@ -94,7 +95,7 @@ async fn collect<A>(
     DapRequestExtractor(req): DapRequestExtractor,
 ) -> Response
 where
-    A: DapLeader<DaphneAuth> + DaphneService + Send + Sync,
+    A: DapLeader + DaphneService + Send + Sync,
 {
     let task_id = match req.task_id() {
         Ok(id) => id,
diff --git a/crates/daphne-server/src/router/mod.rs b/crates/daphne-server/src/router/mod.rs
index bcf616c2c..d5acd9b0f 100644
--- a/crates/daphne-server/src/router/mod.rs
+++ b/crates/daphne-server/src/router/mod.rs
@@ -19,20 +19,20 @@ use axum::{
     Json,
 };
 use daphne::{
-    auth::BearerToken,
     constants::DapMediaType,
     error::DapAbort,
     fatal_error,
     messages::{AggregationJobId, CollectionJobId, TaskId},
-    DapError, DapRequest, DapResource, DapResponse, DapVersion,
+    DapError, DapRequest, DapResource, DapResponse, DapSender, DapVersion,
 };
 use daphne_service_utils::{
-    auth::{DaphneAuth, TlsClientAuth},
+    bearer_token::BearerToken,
     http_headers,
     metrics::{self, DaphneServiceMetrics},
     DapRole,
 };
-use http::Request;
+use either::Either;
+use http::{HeaderMap, Request};
 use serde::Deserialize;
 
 use crate::App;
@@ -40,6 +40,7 @@ use crate::App;
 type Router<A, B> = axum::Router<Arc<A>, B>;
 
 /// Capabilities necessary when running a native daphne service.
+#[async_trait]
 pub trait DaphneService {
     /// The service metrics
     fn server_metrics(&self) -> &dyn DaphneServiceMetrics;
@@ -47,11 +48,29 @@ pub trait DaphneService {
     fn signing_key(&self) -> Option<&p256::ecdsa::SigningKey> {
         None
     }
+
+    /// Checks if a bearer token is accepted.
+    ///
+    /// # Errors
+    ///
+    /// Returns an `Either`:
+    /// - left: an error message with the reason why the token wasn't accepted.
+    /// - right: an internal error that made checking the token impossible.
+    async fn check_bearer_token(
+        &self,
+        presented_token: &BearerToken,
+        sender: DapSender,
+        task_id: TaskId,
+        is_taskprov: bool,
+    ) -> Result<(), Either<String, DapError>>;
+
+    async fn is_taskprov(&self, req: &DapRequest) -> Result<bool, DapError>;
 }
 
+#[async_trait]
 impl<S> DaphneService for Arc<S>
 where
-    S: DaphneService,
+    S: DaphneService + Send + Sync,
 {
     fn server_metrics(&self) -> &dyn DaphneServiceMetrics {
         S::server_metrics(&**self)
@@ -60,6 +79,20 @@ where
     fn signing_key(&self) -> Option<&p256::ecdsa::SigningKey> {
         S::signing_key(&**self)
     }
+
+    async fn check_bearer_token(
+        &self,
+        presented_token: &BearerToken,
+        sender: DapSender,
+        task_id: TaskId,
+        is_taskprov: bool,
+    ) -> Result<(), Either<String, DapError>> {
+        S::check_bearer_token(&**self, presented_token, sender, task_id, is_taskprov).await
+    }
+
+    async fn is_taskprov(&self, req: &DapRequest) -> Result<bool, DapError> {
+        S::is_taskprov(&**self, req).await
+    }
 }
 
 pub fn new<B>(role: DapRole, aggregator: App) -> axum::Router<(), B>
@@ -211,10 +244,10 @@ impl IntoResponse for AxumDapResponse {
 
 /// An axum extractor capable of parsing a [`DapRequest`].
 #[derive(Debug)]
-struct DapRequestExtractor(pub DapRequest<DaphneAuth>);
+struct UnauthenticatedDapRequestExtractor(pub DapRequest);
 
 #[async_trait]
-impl<S, B> FromRequest<S, B> for DapRequestExtractor
+impl<S, B> FromRequest<S, B> for UnauthenticatedDapRequestExtractor
 where
     S: DaphneService + Send + Sync,
     B: HttpBody + Send + 'static,
@@ -245,34 +278,6 @@ where
         .await
         .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
 
-        let extract_header_as_string = |header: &'static str| -> Option<String> {
-            parts
-                .headers
-                .get(header)?
-                .to_str()
-                .ok()
-                .map(ToString::to_string)
-        };
-
-        let sender_auth = DaphneAuth {
-            bearer_token: extract_header_as_string(http_headers::DAP_AUTH_TOKEN)
-                .map(BearerToken::from),
-
-            cf_tls_client_auth: extract_header_as_string("X-Client-Cert-Verified")
-                .map(|verified| TlsClientAuth { verified }),
-        };
-
-        if sender_auth.bearer_token.is_some() {
-            state
-                .server_metrics()
-                .auth_method_inc(metrics::AuthMethod::BearerToken);
-        }
-        if sender_auth.cf_tls_client_auth.is_some() {
-            state
-                .server_metrics()
-                .auth_method_inc(metrics::AuthMethod::TlsClientAuth);
-        }
-
         let media_type = if let Some(content_type) = parts.headers.get(CONTENT_TYPE) {
             let content_type = content_type.to_str().map_err(|_| {
                 let msg = "header value contains non ascii or invisible characters".into();
@@ -285,7 +290,7 @@ where
             None
         };
 
-        let taskprov = extract_header_as_string(http_headers::DAP_TASKPROV);
+        let taskprov = extract_header_as_string(&parts.headers, http_headers::DAP_TASKPROV);
 
         // TODO(mendess): this is very eager, we could redesign DapResponse later to allow for
         // streaming of data.
@@ -324,21 +329,110 @@ where
             (task_id, resource)
         };
 
-        Ok(DapRequestExtractor(DapRequest {
+        let request = DapRequest {
             version,
             task_id,
             resource,
             payload: payload.to_vec(),
             media_type,
-            sender_auth: Some(sender_auth),
             taskprov,
-        }))
+        };
+
+        Ok(UnauthenticatedDapRequestExtractor(request))
     }
 }
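With this split, a route opts into authentication purely through its handler's argument type; see add_leader_routes earlier in this patch, where only upload stays unauthenticated. An abridged sketch (handler bodies elided, signatures illustrative):

// Report upload: DAP Clients hold no bearer token, so parsing must succeed
// without credentials.
async fn upload(UnauthenticatedDapRequestExtractor(_req): UnauthenticatedDapRequestExtractor) {}

// Collection: the Collector must pass the bearer-token or taskprov mTLS checks
// before the handler ever runs.
async fn collect(DapRequestExtractor(_req): DapRequestExtractor) {}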
+
+/// An axum extractor capable of parsing a [`DapRequest`].
+///
+/// This extractor asserts that the request is authenticated.
+#[derive(Debug)]
+struct DapRequestExtractor(pub DapRequest);
+
+#[async_trait]
+impl<S, B> FromRequest<S, B> for DapRequestExtractor
+where
+    S: DaphneService + Send + Sync,
+    B: HttpBody + Send + 'static,
+    <B as HttpBody>::Data: Send,
+{
+    type Rejection = (StatusCode, String);
+
+    async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
+        let bearer_token = extract_header_as_string(req.headers(), http_headers::DAP_AUTH_TOKEN)
+            .map(BearerToken::from);
+        let cf_tls_client_auth = extract_header_as_string(req.headers(), "X-Client-Cert-Verified");
+
+        let request = UnauthenticatedDapRequestExtractor::from_request(req, state)
+            .await?
+            .0;
+
+        let is_taskprov = state
+            .is_taskprov(&request)
+            .await
+            .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "internal error".into()))?;
+
+        let bearer_authed = if let Some((token, sender)) =
+            bearer_token.zip(request.media_type.map(|m| m.sender()))
+        {
+            state
+                .server_metrics()
+                .auth_method_inc(metrics::AuthMethod::BearerToken);
+            state
+                .check_bearer_token(
+                    &token,
+                    sender,
+                    request.task_id.unwrap(), // task ids are mandatory
+                    is_taskprov,
+                )
+                .await
+                .map_err(|reason| {
+                    reason.either(
+                        |reason| (StatusCode::UNAUTHORIZED, reason),
+                        |_| (StatusCode::INTERNAL_SERVER_ERROR, "internal error".into()),
+                    )
+                })?;
+            true
+        } else {
+            false
+        };
+        let mtls_authed = if let Some(verification_result) = cf_tls_client_auth {
+            state
+                .server_metrics()
+                .auth_method_inc(metrics::AuthMethod::TlsClientAuth);
+            // we always check if mtls succeeded even if ...
+            if verification_result != "SUCCESS" {
+                return Err((
+                    StatusCode::UNAUTHORIZED,
+                    format!("Invalid TLS certificate ({verification_result})"),
+                ));
+            }
+            // ... we only allow mtls auth for taskprov tasks
+            is_taskprov
+        } else {
+            false
+        };
+
+        if bearer_authed || mtls_authed {
+            Ok(Self(request))
+        } else {
+            Err((
+                StatusCode::UNAUTHORIZED,
+                "No suitable authorization method was found".into(),
+            ))
+        }
+    }
+}
+
+fn extract_header_as_string(headers: &HeaderMap, header: &'static str) -> Option<String> {
+    headers.get(header)?.to_str().ok().map(ToString::to_string)
+}
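The two flags above reduce to the following acceptance matrix, which the tests at the end of this file exercise; the summary is added here for orientation and is not present in the source:

// bearer token   X-Client-Cert-Verified   taskprov?   outcome
// ------------   ----------------------   ---------   --------------------------------
// valid          absent                   any         authorized (bearer)
// valid          "SUCCESS"                any         authorized (bearer and/or mTLS)
// valid          other value              any         401, certificate rejected
// invalid        any                      any         401, token rejected
// absent         "SUCCESS"                yes         authorized (mTLS, taskprov only)
// absent         "SUCCESS"                no          401, no suitable method
// absent         other value              any         401, certificate rejected
// absent         absent                   any         401, no suitable method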
 
 #[cfg(test)]
 mod test {
-    use std::sync::{Arc, OnceLock};
+    use std::{
+        sync::{Arc, OnceLock},
+        time::Duration,
+    };
 
     use axum::{
         body::{Body, HttpBody},
         extract::State,
         http::{header::CONTENT_TYPE, Request, StatusCode},
         response::IntoResponse,
         routing::get,
         Router,
     };
     use daphne::{
-        async_test_version, async_test_versions,
+        async_test_versions,
         messages::{AggregationJobId, Base64Encode, TaskId},
-        DapRequest, DapResource, DapVersion,
+        DapError, DapRequest, DapResource, DapSender, DapVersion,
+    };
+    use daphne_service_utils::{
+        bearer_token::BearerToken, http_headers, metrics::DaphnePromServiceMetrics,
     };
-    use daphne_service_utils::{auth::DaphneAuth, metrics::DaphnePromServiceMetrics};
-    use futures::future::BoxFuture;
+    use either::Either::{self, Left};
+    use futures::{future::BoxFuture, FutureExt};
     use rand::{thread_rng, Rng};
-    use tokio::sync::mpsc::{self, Sender};
+    use tokio::{
+        sync::mpsc::{self, Sender},
+        time::timeout,
+    };
     use tower::ServiceExt;
 
-    use super::DapRequestExtractor;
+    use crate::router::DapRequestExtractor;
+
+    use super::UnauthenticatedDapRequestExtractor;
 
-    /// Return a function that will parse a request using the [`DapRequestExtractor`] and return
-    /// the parsed request.
+    const BEARER_TOKEN: &str = "test-token";
+
+    /// Return a function that will parse a request using the [`DapRequestExtractor`] or
+    /// [`UnauthenticatedDapRequestExtractor`] and return the parsed request.
     ///
     /// The possible request URIs that are supported by this parser are:
-    /// - `/:version/parse-version`
-    /// - `/:version/:task_id/parse-task-id`
-    /// - `/:version/:agg_job_id/parse-agg-job-id`
-    /// - `/:version/:collect_job_id/parse-collect-job-id`
-    fn test_router<B>() -> impl FnOnce(Request<B>) -> BoxFuture<'static, DapRequest<DaphneAuth>>
+    /// - `/:version/:task_id/auth` uses the [`DapRequestExtractor`]
+    /// - `/:version/:task_id/parse-mandatory-fields` uses the [`UnauthenticatedDapRequestExtractor`]
+    /// - `/:version/:task_id/:agg_job_id/parse-agg-job-id` uses the [`UnauthenticatedDapRequestExtractor`]
+    /// - `/:version/:task_id/:collect_job_id/parse-collect-job-id` uses the [`UnauthenticatedDapRequestExtractor`]
+    fn test_router<B>(
+    ) -> impl FnOnce(Request<B>) -> BoxFuture<'static, Result<DapRequest, StatusCode>>
     where
         B: Send + Sync + 'static + HttpBody,
         B::Data: Send,
         B::Error: Send + Sync + std::error::Error,
     {
-        type Channel = Sender<DapRequest<DaphneAuth>>;
+        type Channel = Sender<DapRequest>;
 
+        #[axum::async_trait]
         impl super::DaphneService for Channel {
             fn server_metrics(&self) -> &dyn daphne_service_utils::metrics::DaphneServiceMetrics {
                 // These tests don't care about metrics so we just store a static instance here so I
@@ -390,30 +496,63 @@ mod test {
             fn signing_key(&self) -> Option<&p256::ecdsa::SigningKey> {
                 None
             }
-        }
 
-        async fn handler(
-            State(ch): State<Arc<Channel>>,
-            DapRequestExtractor(req): DapRequestExtractor,
-        ) -> impl IntoResponse {
-            ch.send(req).await.unwrap();
+            async fn check_bearer_token(
+                &self,
+                token: &BearerToken,
+                _sender: DapSender,
+                _task_id: TaskId,
+                _is_taskprov: bool,
+            ) -> Result<(), Either<String, DapError>> {
+                (token.as_str() == BEARER_TOKEN)
+                    .then_some(())
+                    .ok_or_else(|| Left("invalid token".into()))
+            }
+
+            async fn is_taskprov(&self, req: &DapRequest) -> Result<bool, DapError> {
+                Ok(req.taskprov.is_some())
+            }
         }
 
+        // setup a channel to "smuggle" the parsed request out of a handler
         let (tx, mut rx) = mpsc::channel(1);
 
+        // create a router that takes the send end of the channel as state
         let router = Router::new()
-            .route("/:version/parse-version", get(handler))
-            .route("/:version/:task_id/parse-task-id", get(handler))
-            .route("/:version/:agg_job_id/parse-agg-job-id", get(handler))
+            .route("/:version/:task_id/auth", get(auth_handler))
+            .route("/:version/:task_id/parse-mandatory-fields", get(handler))
+            .route(
+                "/:version/:task_id/:agg_job_id/parse-agg-job-id",
+                get(handler),
+            )
             .route(
-                "/:version/:collect_job_id/parse-collect-job-id",
+                "/:version/:task_id/:collect_job_id/parse-collect-job-id",
                 get(handler),
             )
             .with_state(Arc::new(tx));
 
+        // unauthenticated handler that simply sends the received request through the channel
+        async fn handler(
+            State(ch): State<Arc<Channel>>,
+            UnauthenticatedDapRequestExtractor(req): UnauthenticatedDapRequestExtractor,
+        ) -> impl IntoResponse {
+            ch.send(req).await.unwrap();
+        }
+
+        // authenticated handler that simply sends the received request through the channel
+        async fn auth_handler(
+            State(ch): State<Arc<Channel>>,
+            DapRequestExtractor(req): DapRequestExtractor,
+        ) -> impl IntoResponse {
+            ch.send(req).await.unwrap();
+        }
+
         move |req| {
             Box::pin(async move {
-                let resp = match router.oneshot(req).await {
+                let resp = match timeout(Duration::from_secs(1), router.oneshot(req))
+                    .await
+                    .unwrap()
+                {
                    Ok(resp) => resp,
                    Err(i) => match i {},
                };
@@ -428,71 +567,162 @@ mod test {
                        )
                    )
                }
-                    code => assert_eq!(code, StatusCode::OK),
+                    // get the request sent through the channel in the handler
+                    StatusCode::OK => Ok(rx.recv().now_or_never().unwrap().unwrap()),
+                    code => Err(code),
                }
-
-                rx.recv().await.unwrap()
            })
        }
    }
 
-    async fn 
parse_version(version: DapVersion) { + fn mk_task_id() -> TaskId { + TaskId(thread_rng().gen()) + } + + async fn parse_mandatory_fields(version: DapVersion) { let test = test_router(); + let task_id = mk_task_id(); let req = test( Request::builder() - .uri(format!("/{version}/parse-version")) + .uri(format!( + "/{version}/{}/parse-mandatory-fields", + task_id.to_base64url() + )) + .header(CONTENT_TYPE, "application/dap-aggregation-job-init-req") + .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN) .body(Body::empty()) .unwrap(), ) - .await; + .await + .unwrap(); assert_eq!(req.version, version); + assert_eq!(req.task_id.unwrap(), task_id); } - async_test_versions! { parse_version } + async_test_versions! { parse_mandatory_fields } - async fn parse_task_id(version: DapVersion) { + async fn parse_agg_job_id(version: DapVersion) { let test = test_router(); - let task_id = TaskId(thread_rng().gen()); + let task_id = mk_task_id(); + let agg_job_id = AggregationJobId(thread_rng().gen()); let req = test( Request::builder() .uri(format!( - "/{version}/{}/parse-task-id", - task_id.to_base64url() + "/{version}/{}/{}/parse-agg-job-id", + task_id.to_base64url(), + agg_job_id.to_base64url(), )) + .header(CONTENT_TYPE, "application/dap-aggregation-job-init-req") + .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN) .body(Body::empty()) .unwrap(), ) - .await; + .await + .unwrap(); - assert_eq!(req.task_id, Some(task_id)); + assert_eq!(req.resource, DapResource::AggregationJob(agg_job_id)); + assert_eq!(req.task_id.unwrap(), task_id); } - async_test_versions! { parse_task_id } + async_test_versions! { parse_agg_job_id } - async fn parse_agg_job_id(version: DapVersion) { + async fn incorrect_bearer_tokens_are_rejected(version: DapVersion) { let test = test_router(); - let agg_job_id = AggregationJobId(thread_rng().gen()); + let status_code = test( + Request::builder() + .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) + .header(CONTENT_TYPE, "application/dap-aggregation-job-init-req") + .header(http_headers::DAP_AUTH_TOKEN, "something incorrect") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap_err(); + + assert_eq!(status_code, StatusCode::UNAUTHORIZED); + } + + async_test_versions! { incorrect_bearer_tokens_are_rejected } + + async fn missing_auth_is_rejected(version: DapVersion) { + let test = test_router(); + + let status_code = test( + Request::builder() + .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) + .header(CONTENT_TYPE, "application/dap-aggregation-job-init-req") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap_err(); + + assert_eq!(status_code, StatusCode::UNAUTHORIZED); + } + + async_test_versions! { missing_auth_is_rejected } + + async fn mtls_auth_is_enough(version: DapVersion) { + let test = test_router(); let req = test( Request::builder() - .uri(format!( - "/{version}/{}/parse-agg-job-id", - agg_job_id.to_base64url() - )) + .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) .header(CONTENT_TYPE, "application/dap-aggregation-job-init-req") + .header("X-Client-Cert-Verified", "SUCCESS") + .header(http_headers::DAP_TASKPROV, "some-taskprov-string") .body(Body::empty()) .unwrap(), ) .await; - assert_eq!(req.resource, DapResource::AggregationJob(agg_job_id)); + req.unwrap(); + } + + async_test_versions! 
{ mtls_auth_is_enough } + + async fn incorrect_bearer_tokens_are_rejected_even_with_mtls_auth(version: DapVersion) { + let test = test_router(); + + let code = test( + Request::builder() + .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) + .header(CONTENT_TYPE, "application/dap-aggregation-job-init-req") + .header(http_headers::DAP_AUTH_TOKEN, "something incorrect") + .header("X-Client-Cert-Verified", "SUCCESS") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap_err(); + + assert_eq!(code, StatusCode::UNAUTHORIZED); + } + + async_test_versions! { incorrect_bearer_tokens_are_rejected_even_with_mtls_auth } + + async fn invalid_mtls_auth_is_rejected_despite_correct_bearer_token(version: DapVersion) { + let test = test_router(); + + let code = test( + Request::builder() + .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) + .header(CONTENT_TYPE, "application/dap-aggregation-job-init-req") + .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN) + .header("X-Client-Cert-Verified", "FAILED") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap_err(); + + assert_eq!(code, StatusCode::UNAUTHORIZED); } - async_test_version! { parse_agg_job_id, Draft09 } - async_test_version! { parse_agg_job_id, Latest } + async_test_versions! { invalid_mtls_auth_is_rejected_despite_correct_bearer_token } } diff --git a/crates/daphne-server/src/storage_proxy_connection/kv/mod.rs b/crates/daphne-server/src/storage_proxy_connection/kv/mod.rs index 387d8e643..1cbd349d3 100644 --- a/crates/daphne-server/src/storage_proxy_connection/kv/mod.rs +++ b/crates/daphne-server/src/storage_proxy_connection/kv/mod.rs @@ -33,10 +33,16 @@ pub trait KvPrefix { } pub mod prefix { - use std::{fmt::Display, marker::PhantomData}; - - use daphne::{auth::BearerToken, messages::TaskId, taskprov, DapTaskConfig, DapVersion}; - use daphne_service_utils::config::HpkeRecieverConfigList; + use std::{ + fmt::{self, Display}, + marker::PhantomData, + }; + + use daphne::{ + messages::{Base64Encode, TaskId}, + taskprov, DapSender, DapTaskConfig, DapVersion, + }; + use daphne_service_utils::{bearer_token::BearerToken, config::HpkeRecieverConfigList}; use serde::{de::DeserializeOwned, Serialize}; use super::KvPrefix; @@ -98,20 +104,32 @@ pub mod prefix { type Value = HpkeRecieverConfigList; } - pub struct LeaderBearerToken(); - impl KvPrefix for LeaderBearerToken { - const PREFIX: &'static str = "bearer_token/leader/task"; + pub struct KvBearerToken(); + impl KvPrefix for KvBearerToken { + const PREFIX: &'static str = "bearer_token"; - type Key = TaskId; + type Key = KvBearerTokenKey; type Value = BearerToken; } - pub struct CollectorBearerToken(); - impl KvPrefix for CollectorBearerToken { - const PREFIX: &'static str = "bearer_token/collector/task"; - - type Key = TaskId; - type Value = BearerToken; + #[derive(Debug)] + pub struct KvBearerTokenKey(DapSender, TaskId); + impl From<(DapSender, TaskId)> for KvBearerTokenKey { + fn from((s, t): (DapSender, TaskId)) -> Self { + Self(s, t) + } + } + impl fmt::Display for KvBearerTokenKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(sender, task_id) = self; + let task_id = task_id.to_base64url(); + match sender { + DapSender::Client => write!(f, "client/task/{task_id}"), + DapSender::Collector => write!(f, "collector/task/{task_id}"), + DapSender::Helper => write!(f, "helper/task/{task_id}"), + DapSender::Leader => write!(f, "leader/task/{task_id}"), + } + } } } @@ -233,7 +251,7 @@ impl<'h> Kv<'h> { let resp = self .http 
.get(self.config.url.join(&key).unwrap()) - .bearer_auth(&self.config.auth_token) + .bearer_auth(self.config.auth_token.as_str()) .send() .await?; if resp.status() == StatusCode::NOT_FOUND { @@ -275,7 +293,7 @@ impl<'h> Kv<'h> { let mut request = self .http .post(self.config.url.join(&key).unwrap()) - .bearer_auth(&self.config.auth_token) + .bearer_auth(self.config.auth_token.as_str()) .body(serde_json::to_vec(&value).unwrap()); if let Some(expiration) = expiration { @@ -336,7 +354,7 @@ impl<'h> Kv<'h> { let mut request = self .http .put(self.config.url.join(&key).unwrap()) - .bearer_auth(&self.config.auth_token) + .bearer_auth(self.config.auth_token.as_str()) .body(serde_json::to_vec(&value).unwrap()); if let Some(expiration) = expiration { diff --git a/crates/daphne-server/src/storage_proxy_connection/mod.rs b/crates/daphne-server/src/storage_proxy_connection/mod.rs index 8798345ca..19ed049ee 100644 --- a/crates/daphne-server/src/storage_proxy_connection/mod.rs +++ b/crates/daphne-server/src/storage_proxy_connection/mod.rs @@ -79,7 +79,7 @@ impl<'d, B: DurableMethod + Debug, P: AsRef<[u8]>> RequestBuilder<'d, B, P> { .http .post(url) .body(self.request.into_bytes()) - .bearer_auth(&self.durable.config.auth_token) + .bearer_auth(self.durable.config.auth_token.as_str()) .send() .await?; diff --git a/crates/daphne-server/tests/e2e/e2e.rs b/crates/daphne-server/tests/e2e/e2e.rs index 2e340c93c..6198cc10c 100644 --- a/crates/daphne-server/tests/e2e/e2e.rs +++ b/crates/daphne-server/tests/e2e/e2e.rs @@ -1087,6 +1087,7 @@ async fn fixed_size() { } async fn leader_collect_taskprov_ok(version: DapVersion) { + const DAP_TASKPROV_COLLECTOR_TOKEN: &str = "I-am-the-collector"; let t = TestRunner::default_with_version(version).await; let batch_interval = t.batch_interval(); @@ -1149,7 +1150,7 @@ async fn leader_collect_taskprov_ok(version: DapVersion) { let collect_uri = t .leader_post_collect_using_token( client, - "I-am-the-collector", // DAP_TASKPROV_COLLECTOR_AUTH + DAP_TASKPROV_COLLECTOR_TOKEN, Some(&taskprov_advertisement), Some(&task_id), collect_req.get_encoded_with_param(&t.version).unwrap(), @@ -1159,8 +1160,20 @@ async fn leader_collect_taskprov_ok(version: DapVersion) { println!("collect_uri: {collect_uri}"); // Poll the collect URI before the CollectResp is ready. - let resp = t.poll_collection_url(client, &collect_uri).await.unwrap(); - assert_eq!(resp.status(), 202, "response: {resp:?}"); + let resp = t + .poll_collection_url_using_token(client, &collect_uri, DAP_TASKPROV_COLLECTOR_TOKEN) + .await + .unwrap(); + #[allow(clippy::format_in_format_args)] + { + assert_eq!( + resp.status(), + 202, + "response: {} {}", + format!("{resp:?}"), + resp.text().await.unwrap() + ); + } // The reports are aggregated in the background. let agg_telem = t.internal_process(client).await.unwrap(); @@ -1178,7 +1191,10 @@ async fn leader_collect_taskprov_ok(version: DapVersion) { ); // Poll the collect URI. - let resp = t.poll_collection_url(client, &collect_uri).await.unwrap(); + let resp = t + .poll_collection_url_using_token(client, &collect_uri, DAP_TASKPROV_COLLECTOR_TOKEN) + .await + .unwrap(); assert_eq!(resp.status(), 200); let collection = @@ -1203,7 +1219,10 @@ async fn leader_collect_taskprov_ok(version: DapVersion) { // Poll the collect URI once more. Expect the response to be the same as the first, per HTTP // GET semantics. 
-    let resp = t.poll_collection_url(client, &collect_uri).await.unwrap();
+    let resp = t
+        .poll_collection_url_using_token(client, &collect_uri, DAP_TASKPROV_COLLECTOR_TOKEN)
+        .await
+        .unwrap();
     assert_eq!(resp.status(), 200);
     assert_eq!(
         resp.bytes().await.unwrap(),
diff --git a/crates/daphne-server/tests/e2e/test_runner.rs b/crates/daphne-server/tests/e2e/test_runner.rs
index cb6c4767b..a1e6ef4d5 100644
--- a/crates/daphne-server/tests/e2e/test_runner.rs
+++ b/crates/daphne-server/tests/e2e/test_runner.rs
@@ -72,6 +72,7 @@ impl TestRunner {
     }
 
     async fn with(version: DapVersion, query_config: &DapQueryConfig) -> Self {
+        println!("\n############ starting test prep ############");
         let mut rng = thread_rng();
         let now = SystemTime::now()
             .duration_since(SystemTime::UNIX_EPOCH)
@@ -485,6 +486,10 @@ impl TestRunner {
                 .context("no string for version")?
                 .parse()?,
         );
+        headers.insert(
+            reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
+            reqwest::header::HeaderValue::from_str(&self.collector_bearer_token)?,
+        );
         if let Some(taskprov_advertisement) = taskprov {
             headers.insert(
                 reqwest::header::HeaderName::from_static(http_headers::DAP_TASKPROV),
@@ -643,8 +648,8 @@ impl TestRunner {
         anyhow::ensure!(
             resp.status() == 200,
             "unexpected response status. Expected {} got {}: Body is {:?}",
-            resp.status(),
             reqwest::StatusCode::OK,
+            resp.status(),
             resp.text().await?,
         );
         Ok(resp.json().await?)
@@ -671,7 +676,9 @@ impl TestRunner {
             .context("request failed")?;
         anyhow::ensure!(
             resp.status() == 200,
-            "request to {url} failed: response: {resp:?}"
+            "request to {url} failed: response: {} {}",
+            format!("{resp:?}"), // text() moves so we have to format here
+            resp.text().await.unwrap_or("".into()),
         );
         let t = resp.text().await.context("failed to extract text")?;
         // This is needed so we can have tests that call this expecting nothing and have it work
@@ -756,6 +763,16 @@ impl TestRunner {
         &self,
         client: &reqwest::Client,
         url: &Url,
+    ) -> anyhow::Result<reqwest::Response> {
+        self.poll_collection_url_using_token(client, url, &self.collector_bearer_token)
+            .await
+    }
+
+    pub async fn poll_collection_url_using_token(
+        &self,
+        client: &reqwest::Client,
+        url: &Url,
+        token: &str,
     ) -> anyhow::Result<reqwest::Response> {
         let builder = client.post(url.as_str());
         let mut headers = reqwest::header::HeaderMap::new();
@@ -767,6 +784,10 @@ impl TestRunner {
                 .context("no string for version")?,
             )?,
         );
+        headers.insert(
+            reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
+            reqwest::header::HeaderValue::from_str(token)?,
+        );
         Ok(builder.headers(headers).send().await?)
     }
 }
diff --git a/crates/daphne-service-utils/src/auth.rs b/crates/daphne-service-utils/src/auth.rs
deleted file mode 100644
index 4a7898a1c..000000000
--- a/crates/daphne-service-utils/src/auth.rs
+++ /dev/null
@@ -1,124 +0,0 @@
-// Copyright (c) 2023 Cloudflare, Inc. All rights reserved.
-// SPDX-License-Identifier: BSD-3-Clause
-
-//! Authorization methods for Daphne-Worker.
-
-use std::fmt::Debug;
-
-use daphne::auth::BearerToken;
-use serde::{Deserialize, Serialize};
-
-#[derive(PartialEq, Eq)]
-pub struct TlsClientAuth {
-    pub verified: String,
-}
-
-/// HTTP client authorization for Daphne-Worker.
-///
-/// Multiple authorization methods can be configured. The sender may present multiple authorization
-/// methods; the request is authorized if validation of all presented methods succeed. If an
-/// authorization method is presented, but the server is not configured to validate it, then
-/// validation of that method will fail.
-//
-// TODO(cjpatton) Add an authorization method for Cloudflare Access
-// (https://www.cloudflare.com/products/zero-trust/access/). This allows us to delegate access
-// control to that service; Daphne-Worker would just need to verify that Access granted access.
-#[derive(PartialEq)]
-pub struct DaphneAuth {
-    /// Bearer token, expected to appear in the
-    /// [`DAP_AUTH_TOKEN`](crate::http_headers::DAP_AUTH_TOKEN) header.
-    pub bearer_token: Option<BearerToken>,
-
-    /// TLS client authentication. The client uses a certificate when establishing the TLS
-    /// connection. This authorization method is Cloudflare-specific: Verifying the certificate
-    /// itself is handled by the process that invoked this Worker. The customer zone is also
-    /// expected to be configured to require mutual TLS for the route on which this Worker is
-    /// listening.
-    ///
-    /// When this authorization method is used, we verify that a certificate was presented and was
-    /// successfully verified by the TLS server.
-    ///
-    /// # Caveats
-    ///
-    /// * For now, only the Helper supports TLS client auth; the Leader still expects a bearer
-    ///   token to be configured for the task.
-    ///
-    /// * For now, TLS client auth is only enabled if the taskprov extension is configured.
-    ///   Enabling this feature for other tasks will require a bit plumbing.
-    pub cf_tls_client_auth: Option<TlsClientAuth>,
-}
-
-// Custom debug implementation to avoid exposing sensitive information.
-impl Debug for DaphneAuth {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        // pattern match on self to get a compiler error if the struct changes
-        let Self {
-            bearer_token,
-            cf_tls_client_auth,
-        } = self;
-
-        fn opt_to_str<T>(o: &Option<T>) -> &dyn Debug {
-            if o.is_some() {
-                &"is-present"
-            } else {
-                &"is-missing"
-            }
-        }
-
-        f.debug_struct("DaphneAuth")
-            .field("bearer_token", opt_to_str(bearer_token))
-            .field("cf_tls_client_auth", opt_to_str(cf_tls_client_auth))
-            .finish()
-    }
-}
-
-// TODO(mendess): remove this implementation. Implementations of AsRef should never panic
-impl AsRef<BearerToken> for DaphneAuth {
-    fn as_ref(&self) -> &BearerToken {
-        if let Some(ref bearer_token) = self.bearer_token {
-            bearer_token
-        } else {
-            // We would only try this method if we previously resolved to use a bearer token for
-            // authorization.
-            unreachable!("no bearer token provided by sender")
-        }
-    }
-}
-
-#[derive(Debug, Deserialize, Serialize, Clone)]
-pub struct DaphneWorkerAuthMethod {
-    /// Expected bearer token.
-    #[serde(default, skip_serializing_if = "Option::is_none")]
-    pub bearer_token: Option<BearerToken>,
-}
-
-// TODO(mendess): remove this implementation. Implementations of AsRef should never panic
-impl AsRef<BearerToken> for DaphneWorkerAuthMethod {
-    fn as_ref(&self) -> &BearerToken {
-        if let Some(ref bearer_token) = self.bearer_token {
-            bearer_token
-        } else {
-            // We would only try this method if we previously resolved to use a bearer token for
-            // authorization.
-            unreachable!("no bearer token provided by sender")
-        }
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use super::{BearerToken, DaphneWorkerAuthMethod};
-
-    #[test]
-    fn daphne_worker_auth_method_json_serialization() {
-        let daphne_worker_auth_method: DaphneWorkerAuthMethod =
-            serde_json::from_str(r#"{"bearer_token":"the bearer token"}"#).unwrap();
-        assert_eq!(
-            daphne_worker_auth_method.bearer_token,
-            Some(BearerToken::from("the bearer token".to_string()))
-        );
-
-        let daphne_worker_auth_method: DaphneWorkerAuthMethod = serde_json::from_str("{}").unwrap();
-        assert!(daphne_worker_auth_method.bearer_token.is_none());
-    }
-}
diff --git a/crates/daphne-service-utils/src/bearer_token.rs b/crates/daphne-service-utils/src/bearer_token.rs
new file mode 100644
index 000000000..d0c8d332f
--- /dev/null
+++ b/crates/daphne-service-utils/src/bearer_token.rs
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 Cloudflare, Inc. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+
+//! DAP request authorization.
+
+use core::fmt;
+
+use daphne::messages::constant_time_eq;
+use serde::{Deserialize, Serialize};
+
+/// A bearer token used for authorizing DAP requests.
+#[derive(Clone, Deserialize, Serialize, Eq)]
+#[serde(transparent)]
+pub struct BearerToken {
+    raw: String,
+}
+
+impl BearerToken {
+    pub fn as_str(&self) -> &str {
+        self.raw.as_str()
+    }
+}
+
+impl fmt::Debug for BearerToken {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        #[cfg(feature = "test-utils")]
+        {
+            write!(f, "BearerToken({})", self.raw)
+        }
+        #[cfg(not(feature = "test-utils"))]
+        write!(f, "BearerToken(REDACTED)")
+    }
+}
+
+impl AsRef<str> for BearerToken {
+    fn as_ref(&self) -> &str {
+        self.as_str()
+    }
+}
+
+impl PartialEq for BearerToken {
+    fn eq(&self, other: &Self) -> bool {
+        constant_time_eq(self.raw.as_bytes(), other.raw.as_bytes())
+    }
+}
+
+impl From<String> for BearerToken {
+    fn from(raw: String) -> Self {
+        Self { raw }
+    }
+}
+
+impl From<&str> for BearerToken {
+    fn from(raw: &str) -> Self {
+        Self::from(raw.to_string())
+    }
+}
diff --git a/crates/daphne-service-utils/src/config.rs b/crates/daphne-service-utils/src/config.rs
index 4b729c357..17e69ee09 100644
--- a/crates/daphne-service-utils/src/config.rs
+++ b/crates/daphne-service-utils/src/config.rs
@@ -9,7 +9,7 @@ use p256::ecdsa::SigningKey;
 use serde::{Deserialize, Serialize};
 use url::Url;
 
-use crate::{auth::DaphneWorkerAuthMethod, DapRole};
+use crate::{bearer_token::BearerToken, DapRole};
 
 /// draft-wang-ppm-dap-taskprov: Long-lived parameters for the taskprov extension.
 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -22,13 +22,22 @@ pub struct TaskprovConfig {
     #[serde(with = "hex")]
     pub vdaf_verify_key_init: [u8; 32],
 
-    /// Leader, Helper: Method for authorizing Leader requests.
-    #[serde(with = "from_raw_string")]
-    pub leader_auth: DaphneWorkerAuthMethod,
+    /// Peer's bearer token.
+    pub peer_auth: PeerBearerToken,
 
-    /// Leader: Method for authorizing Collector requests.
-    #[serde(default, with = "from_raw_string")]
-    pub collector_auth: Option<DaphneWorkerAuthMethod>,
+    /// Bearer token used when trying to communicate with an aggregator using taskprov.
+    #[serde(default)]
+    pub self_bearer_token: Option<BearerToken>,
+}
+
+/// Peer authentication tokens for incoming requests. Different roles have different peers:
+/// - Helpers have a Leader peer.
+/// - Leaders have a Collector peer.
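+///
+/// A sketch of the expected representation (externally tagged, with lowercase
+/// variant names per the serde attributes below; the token value here is a
+/// placeholder). In a TOML config this nests as
+/// `peer_auth.leader.expected_token = "..."`:
+///
+/// ```ignore
+/// let helper_side: PeerBearerToken =
+///     serde_json::from_str(r#"{"leader":{"expected_token":"tok"}}"#)?;
+/// let leader_side: PeerBearerToken =
+///     serde_json::from_str(r#"{"collector":{"expected_token":"tok"}}"#)?;
+/// ```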
+#[derive(Serialize, Deserialize, Debug, Clone)] +#[serde(rename_all = "lowercase")] +pub enum PeerBearerToken { + Leader { expected_token: BearerToken }, + Collector { expected_token: BearerToken }, } pub type HpkeRecieverConfigList = Vec; diff --git a/crates/daphne-service-utils/src/lib.rs b/crates/daphne-service-utils/src/lib.rs index e311bd846..240df21ad 100644 --- a/crates/daphne-service-utils/src/lib.rs +++ b/crates/daphne-service-utils/src/lib.rs @@ -5,12 +5,13 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; -pub mod auth; +pub mod bearer_token; pub mod config; #[cfg(feature = "durable_requests")] pub mod durable_requests; pub mod http_headers; pub mod metrics; +#[cfg(feature = "test-utils")] pub mod test_route_types; // the generated code expects this module to be defined at the root of the library. diff --git a/crates/daphne/src/auth.rs b/crates/daphne/src/auth.rs deleted file mode 100644 index 46caca4db..000000000 --- a/crates/daphne/src/auth.rs +++ /dev/null @@ -1,170 +0,0 @@ -// Copyright (c) 2022 Cloudflare, Inc. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause - -//! DAP request authorization. - -use std::fmt::Display; - -use crate::{ - constants::DapMediaType, - fatal_error, - messages::{constant_time_eq, TaskId}, - DapError, DapRequest, DapSender, DapTaskConfig, -}; -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - -/// A bearer token used for authorizing DAP requests. -#[derive(Clone, Debug, Deserialize, Serialize)] -#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] -#[serde(transparent)] -pub struct BearerToken { - raw: String, -} - -impl BearerToken { - pub fn as_str(&self) -> &str { - self.raw.as_str() - } -} - -impl AsRef for BearerToken { - fn as_ref(&self) -> &str { - self.as_str() - } -} - -impl PartialEq for BearerToken { - fn eq(&self, other: &Self) -> bool { - constant_time_eq(self.raw.as_bytes(), other.raw.as_bytes()) - } -} - -impl From for BearerToken { - fn from(raw: String) -> Self { - Self { raw } - } -} - -impl From<&str> for BearerToken { - fn from(raw: &str) -> Self { - Self::from(raw.to_string()) - } -} - -impl AsRef for BearerToken { - fn as_ref(&self) -> &Self { - self - } -} - -impl Display for BearerToken { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_str()) - } -} - -/// A source of bearer tokens used for authorizing DAP requests. -#[async_trait] -pub trait BearerTokenProvider { - /// A reference to a bearer token owned by the provider. - type WrappedBearerToken<'a>: AsRef + Send - where - Self: 'a; - - /// Fetch the Leader's bearer token for the given task, if the task is recognized. - async fn get_leader_bearer_token_for<'s>( - &'s self, - task_id: &'s TaskId, - task_config: &DapTaskConfig, - ) -> Result>, DapError>; - - /// Fetch the Collector's bearer token for the given task, if the task is recognized. - async fn get_collector_bearer_token_for<'s>( - &'s self, - task_id: &'s TaskId, - task_config: &DapTaskConfig, - ) -> Result>, DapError>; - - /// Return a bearer token that can be used to authorize a request with the given task ID and - /// media type. - async fn authorize_with_bearer_token<'s>( - &'s self, - task_id: &'s TaskId, - task_config: &DapTaskConfig, - media_type: &DapMediaType, - ) -> Result, DapError> { - if matches!(media_type.sender(), DapSender::Leader) { - let token = self - .get_leader_bearer_token_for(task_id, task_config) - .await? 
- .ok_or_else(|| { - fatal_error!(err = "attempted to authorize request with unknown task ID",) - })?; - return Ok(token); - } - - Err(fatal_error!( - err = "attempted to authorize request of type", - ?media_type, - )) - } - - /// Check that the bearer token carried by a request can be used to authorize that request. - /// - /// Return `None` if the request is authorized. Otherwise return `Some(reason)`, where `reason` - /// is the reason for the failure. - async fn bearer_token_authorized + Send + Sync>( - &self, - task_config: &DapTaskConfig, - req: &DapRequest, - ) -> Result, DapError> { - if req.task_id.is_none() { - // Can't authorize request with missing task ID. - return Ok(Some( - "Cannot authorize request with missing task ID.".into(), - )); - } - let task_id = req.task_id.as_ref().unwrap(); - - // TODO spec: Decide whether to check that the bearer token has the right format, say, - // following RFC 6750, Section 2.1. Note that we would also need to replace `From - // for BearerToken` with `TryFrom` so that a `DapError` can be returned if the - // token is not formatted properly. - if matches!(req.sender(), Some(DapSender::Leader)) { - if let Some(ref got) = req.sender_auth { - if let Some(expected) = self - .get_leader_bearer_token_for(task_id, task_config) - .await? - { - return Ok(if got.as_ref() == expected.as_ref() { - None - } else { - Some("The indicated bearer token is incorrect for the Leader.".into()) - }); - } - } - } - - if matches!(req.sender(), Some(DapSender::Collector)) { - if let Some(ref got) = req.sender_auth { - if let Some(expected) = self - .get_collector_bearer_token_for(task_id, task_config) - .await? - { - return Ok(if got.as_ref() == expected.as_ref() { - None - } else { - Some("The indicated bearer token is incorrect for the Collector.".into()) - }); - } - } - } - - // Deny request with unhandled or unknown media type. - Ok(Some(format!( - "Cannot resolve sender due to unexpected media type ({:?}).", - req.media_type - ))) - } -} diff --git a/crates/daphne/src/error/aborts.rs b/crates/daphne/src/error/aborts.rs index 94391a284..f653fbd18 100644 --- a/crates/daphne/src/error/aborts.rs +++ b/crates/daphne/src/error/aborts.rs @@ -181,7 +181,7 @@ impl DapAbort { } /// Abort due to unexpected value for HTTP content-type header. - pub fn content_type(req: &DapRequest, expected: DapMediaType) -> Self { + pub fn content_type(req: &DapRequest, expected: DapMediaType) -> Self { let want_content_type = expected.as_str_for_version(req.version).unwrap_or_else(|| { unreachable!("unexpected content-type for DAP version {:?}", req.version) }); diff --git a/crates/daphne/src/lib.rs b/crates/daphne/src/lib.rs index 06e20b5f7..35481621a 100644 --- a/crates/daphne/src/lib.rs +++ b/crates/daphne/src/lib.rs @@ -41,7 +41,6 @@ //! > requests to a collect job URI whose results have been removed. pub mod audit_log; -pub mod auth; pub mod constants; pub mod error; pub mod hpke; @@ -1074,7 +1073,7 @@ pub enum DapResource { /// DAP request. #[derive(Debug)] -pub struct DapRequest { +pub struct DapRequest { /// Protocol version indicated by the request. pub version: DapVersion, @@ -1091,15 +1090,12 @@ pub struct DapRequest { /// Request payload. pub payload: Vec, - /// Sender authorization, e.g., a bearer token. - pub sender_auth: Option, - /// taskprov: The task advertisement, sent in the `dap-taskprov` header. 
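+    /// The value is the base64url-encoded taskprov task configuration; see
+    /// `resolve_advertised_task_config` in `taskprov.rs`, which decodes it and
+    /// checks that it hashes to the task ID indicated by the request.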
     pub taskprov: Option<String>,
 }
 
 #[cfg(test)]
-impl<S> Default for DapRequest<S> {
+impl Default for DapRequest {
     fn default() -> Self {
         Self {
             version: DapVersion::Draft09,
@@ -1107,13 +1103,12 @@ impl<S> Default for DapRequest<S> {
             task_id: Default::default(),
             resource: Default::default(),
             payload: Default::default(),
-            sender_auth: Default::default(),
             taskprov: Default::default(),
         }
     }
 }
 
-impl<S> DapRequest<S> {
+impl DapRequest {
     /// Return the task ID, handling a missing ID as a user error.
     pub fn task_id(&self) -> Result<&TaskId, DapAbort> {
         if let Some(ref id) = self.task_id {
diff --git a/crates/daphne/src/roles/aggregator.rs b/crates/daphne/src/roles/aggregator.rs
index edc4d3978..08fd82910 100644
--- a/crates/daphne/src/roles/aggregator.rs
+++ b/crates/daphne/src/roles/aggregator.rs
@@ -15,7 +15,7 @@ use crate::{
     metrics::{DaphneMetrics, DaphneRequestType},
     protocol::aggregator::{EarlyReportStateConsumed, EarlyReportStateInitialized},
     taskprov, DapAggregateShare, DapAggregateSpan, DapAggregationParam, DapError, DapGlobalConfig,
-    DapRequest, DapResponse, DapTaskConfig,
+    DapRequest, DapResponse, DapTaskConfig, DapVersion,
 };
 
 /// Report initializer. Used by a DAP Aggregator [`DapAggregator`] when initializing an aggregation
@@ -45,23 +45,12 @@ pub enum MergeAggShareError {
 
 /// DAP Aggregator functionality.
 #[async_trait]
-pub trait DapAggregator<S: Sync>: HpkeProvider + DapReportInitializer + Sized {
+pub trait DapAggregator: HpkeProvider + DapReportInitializer + Sized {
     /// A reference to a task configuration stored by the Aggregator.
     type WrappedDapTaskConfig<'a>: AsRef<DapTaskConfig> + Send
     where
         Self: 'a;
 
-    /// Decide whether the given DAP request is authorized.
-    ///
-    /// If the return value is `None`, then the request is authorized. If the return value is
-    /// `Some(reason)`, then the request is denied and `reason` conveys details about how the
-    /// decision was reached.
-    async fn unauthorized_reason(
-        &self,
-        task_config: &DapTaskConfig,
-        req: &DapRequest<S>,
-    ) -> Result<Option<String>, DapError>;
-
     /// Look up the DAP global configuration.
     async fn get_global_config(&self) -> Result<DapGlobalConfig, DapError>;
 
@@ -95,7 +84,7 @@ pub trait DapAggregator<S: Sync>: HpkeProvider + DapReportInitializer + Sized {
     /// nothing.
     async fn taskprov_put(
         &self,
-        req: &DapRequest<S>,
+        req: &DapRequest,
         task_config: DapTaskConfig,
     ) -> Result<(), DapError>;
 
@@ -164,19 +153,18 @@ pub trait DapAggregator<S: Sync>: HpkeProvider + DapReportInitializer + Sized {
 }
 
 /// Handle request for the Aggregator's HPKE configuration.
-pub async fn handle_hpke_config_req<S, A>(
+pub async fn handle_hpke_config_req<A>(
     aggregator: &A,
-    req: &DapRequest<S>,
+    version: DapVersion,
     task_id: Option<TaskId>,
 ) -> Result<DapResponse, DapError>
 where
-    S: Sync,
-    A: DapAggregator<S>,
+    A: DapAggregator,
 {
     let metrics = aggregator.metrics();
 
     let hpke_config = aggregator
-        .get_hpke_config_for(req.version, task_id.as_ref())
+        .get_hpke_config_for(version, task_id.as_ref())
         .await?;
 
     if let Some(task_id) = task_id {
@@ -186,10 +174,8 @@
             .ok_or(DapAbort::UnrecognizedTask { task_id })?;
 
         // Check whether the DAP version in the request matches the task config.
- if task_config.as_ref().version != req.version { - return Err( - DapAbort::version_mismatch(req.version, task_config.as_ref().version).into(), - ); + if task_config.as_ref().version != version { + return Err(DapAbort::version_mismatch(version, task_config.as_ref().version).into()); } } @@ -202,7 +188,7 @@ where metrics.inbound_req_inc(DaphneRequestType::HpkeConfig); Ok(DapResponse { - version: req.version, + version, media_type: DapMediaType::HpkeConfigList, payload, }) diff --git a/crates/daphne/src/roles/helper.rs b/crates/daphne/src/roles/helper.rs index 74835bd73..97235b888 100644 --- a/crates/daphne/src/roles/helper.rs +++ b/crates/daphne/src/roles/helper.rs @@ -5,7 +5,6 @@ use std::{collections::HashMap, sync::Once}; use async_trait::async_trait; use prio::codec::{Encode, ParameterizedDecode}; -use tracing::error; use super::{check_batch, check_request_content_type, resolve_taskprov, DapAggregator}; use crate::{ @@ -24,11 +23,11 @@ use crate::{ /// DAP Helper functionality. #[async_trait] -pub trait DapHelper: DapAggregator {} +pub trait DapHelper: DapAggregator {} -pub async fn handle_agg_job_init_req<'req, S: Sync, A: DapHelper>( +pub async fn handle_agg_job_init_req<'req, A: DapHelper>( aggregator: &A, - req: &'req DapRequest, + req: &'req DapRequest, replay_protection: ReplayProtection, ) -> Result { let global_config = aggregator.get_global_config().await?; @@ -51,15 +50,6 @@ pub async fn handle_agg_job_init_req<'req, S: Sync, A: DapHelper>( .ok_or(DapAbort::UnrecognizedTask { task_id: *task_id })?; let task_config = wrapped_task_config.as_ref(); - if let Some(reason) = aggregator.unauthorized_reason(task_config, req).await? { - error!("aborted unauthorized collect request: {reason}"); - return Err(DapAbort::UnauthorizedRequest { - detail: reason, - task_id: *task_id, - } - .into()); - } - let DapResource::AggregationJob(_agg_job_id) = req.resource else { return Err(DapAbort::BadRequest("missing aggregation job ID".to_string()).into()); }; @@ -120,9 +110,9 @@ pub async fn handle_agg_job_init_req<'req, S: Sync, A: DapHelper>( } /// Handle a request pertaining to an aggregation job. -pub async fn handle_agg_job_req<'req, S: Sync, A: DapHelper>( +pub async fn handle_agg_job_req<'req, A: DapHelper>( aggregator: &A, - req: &DapRequest, + req: &DapRequest, replay_protection: ReplayProtection, ) -> Result { match req.media_type { @@ -135,9 +125,9 @@ pub async fn handle_agg_job_req<'req, S: Sync, A: DapHelper>( /// Handle a request for an aggregate share. This is called by the Leader to complete a /// collection job. -pub async fn handle_agg_share_req<'req, S: Sync, A: DapHelper>( +pub async fn handle_agg_share_req<'req, A: DapHelper>( aggregator: &A, - req: &DapRequest, + req: &DapRequest, ) -> Result { let global_config = aggregator.get_global_config().await?; let now = aggregator.get_current_time(); @@ -156,15 +146,6 @@ pub async fn handle_agg_share_req<'req, S: Sync, A: DapHelper>( .ok_or(DapAbort::UnrecognizedTask { task_id: *task_id })?; let task_config = wrapped_task_config.as_ref(); - if let Some(reason) = aggregator.unauthorized_reason(task_config, req).await? { - error!("aborted unauthorized collect request: {reason}"); - return Err(DapAbort::UnauthorizedRequest { - detail: reason, - task_id: *task_id, - } - .into()); - } - // Check whether the DAP version in the request matches the task config. 
if task_config.version != req.version { return Err(DapAbort::version_mismatch(req.version, task_config.version).into()); @@ -276,8 +257,8 @@ fn check_part_batch( Ok(()) } -async fn finish_agg_job_and_aggregate( - helper: &impl DapHelper, +async fn finish_agg_job_and_aggregate( + helper: &impl DapHelper, task_id: &TaskId, task_config: &DapTaskConfig, part_batch_sel: &PartialBatchSelector, diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index 78e184aeb..816ef7a8e 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -9,7 +9,7 @@ use async_trait::async_trait; use futures::future::try_join_all; use prio::codec::{Decode, Encode, ParameterizedDecode, ParameterizedEncode}; use rand::{thread_rng, Rng}; -use tracing::{debug, error}; +use tracing::debug; use url::Url; use super::{ @@ -45,8 +45,8 @@ enum LeaderHttpRequestMethod { Put, } -async fn leader_send_http_request( - role: &impl DapLeader, +async fn leader_send_http_request( + role: &impl DapLeader, task_id: &TaskId, task_config: &DapTaskConfig, opts: LeaderHttpRequestOptions<'_>, @@ -71,10 +71,6 @@ async fn leader_send_http_request( media_type: Some(req_media_type), task_id: Some(*task_id), resource, - sender_auth: Some( - role.authorize(task_id, task_config, &req_media_type, &req_data) - .await?, - ), payload: req_data, taskprov, }; @@ -88,19 +84,6 @@ async fn leader_send_http_request( Ok(resp) } -/// A party in the DAP protocol who is authorized to send requests to another party. -#[async_trait] -pub trait DapAuthorizedSender { - /// Add authorization to an outbound DAP request with the given task ID, media type, and payload. - async fn authorize( - &self, - task_id: &TaskId, - task_config: &DapTaskConfig, - media_type: &DapMediaType, - payload: &[u8], - ) -> Result; -} - /// A work item, either an aggregation job or collection job. #[derive(Debug)] #[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] @@ -130,7 +113,7 @@ impl WorkItem { /// DAP Leader functionality. #[async_trait] -pub trait DapLeader: DapAuthorizedSender + DapAggregator { +pub trait DapLeader: DapAggregator { /// Store a report for use later on. async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError>; @@ -171,16 +154,16 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { ) -> Result<(), DapError>; /// Send an HTTP POST request. - async fn send_http_post(&self, req: DapRequest, url: Url) -> Result; + async fn send_http_post(&self, req: DapRequest, url: Url) -> Result; /// Send an HTTP PUT request. - async fn send_http_put(&self, req: DapRequest, url: Url) -> Result; + async fn send_http_put(&self, req: DapRequest, url: Url) -> Result; } /// Handle a report from a Client. -pub async fn handle_upload_req>( +pub async fn handle_upload_req( aggregator: &A, - req: &DapRequest, + req: &DapRequest, ) -> Result<(), DapError> { let global_config = aggregator.get_global_config().await?; let metrics = aggregator.metrics(); @@ -255,9 +238,9 @@ pub async fn handle_upload_req>( /// Handle a collect job from the Collector. The response is the URI that the Collector will /// poll later on to get the collection. 
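+/// (A sketch of the exchange: the Collector POSTs a `CollectionReq` to the
+/// collection-job resource, the Leader records the job and answers with a URI,
+/// and the Collector polls that URI until the result is ready, as the
+/// end-to-end tests do via `poll_collection_url_using_token`.)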
-pub async fn handle_coll_job_req>( +pub async fn handle_coll_job_req( aggregator: &A, - req: &DapRequest, + req: &DapRequest, ) -> Result { let global_config = aggregator.get_global_config().await?; let now = aggregator.get_current_time(); @@ -271,22 +254,12 @@ pub async fn handle_coll_job_req>( resolve_taskprov(aggregator, task_id, req, &global_config).await?; } - let task_id = req.task_id()?; let wrapped_task_config = aggregator .get_task_config_for(task_id) .await? .ok_or(DapAbort::UnrecognizedTask { task_id: *task_id })?; let task_config = wrapped_task_config.as_ref(); - if let Some(reason) = aggregator.unauthorized_reason(task_config, req).await? { - error!("aborted unauthorized collect request: {reason}"); - return Err(DapAbort::UnauthorizedRequest { - detail: reason, - task_id: *task_id, - } - .into()); - } - let coll_job_req = CollectionReq::get_decoded_with_param(&req.version, req.payload.as_ref()) .map_err(|e| DapAbort::from_codec_error(e, *task_id))?; @@ -334,7 +307,7 @@ pub async fn handle_coll_job_req>( /// Run an aggregation job for a set of reports. Return the number of reports that were /// aggregated successfully. -async fn run_agg_job>( +async fn run_agg_job( aggregator: &A, task_id: &TaskId, task_config: &DapTaskConfig, @@ -440,7 +413,7 @@ async fn run_agg_job>( /// Handle a pending collection job. If the results are ready, then compute the aggregate /// results and store them to be retrieved by the Collector later. Returns the number of /// reports in the batch. -async fn run_coll_job>( +async fn run_coll_job( aggregator: &A, task_id: &TaskId, task_config: &DapTaskConfig, @@ -547,7 +520,7 @@ async fn run_coll_job>( /// /// Collection jobs are processed in order. If a collection job is still pending once processed, it /// is pushed to the back of the work queue. 
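+///
+/// A minimal driver sketch (hypothetical names; assumes `aggregator`
+/// implements `DapLeader` and that the host label and batch size suit the
+/// deployment):
+///
+/// ```ignore
+/// loop {
+///     leader::process(&aggregator, "leader.example.com", 10).await?;
+///     tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+/// }
+/// ```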
-pub async fn process>( +pub async fn process( aggregator: &A, host: &str, num_items: usize, diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs index a8db54eee..29ef46beb 100644 --- a/crates/daphne/src/roles/mod.rs +++ b/crates/daphne/src/roles/mod.rs @@ -16,10 +16,10 @@ use tracing::warn; pub use aggregator::{DapAggregator, DapReportInitializer}; pub use helper::DapHelper; -pub use leader::{DapAuthorizedSender, DapLeader}; +pub use leader::DapLeader; -async fn check_batch( - agg: &impl DapAggregator, +async fn check_batch( + agg: &impl DapAggregator, task_config: &DapTaskConfig, task_id: &TaskId, query: &Query, @@ -92,10 +92,7 @@ async fn check_batch( Ok(()) } -fn check_request_content_type( - req: &DapRequest, - expected: DapMediaType, -) -> Result<(), DapAbort> { +fn check_request_content_type(req: &DapRequest, expected: DapMediaType) -> Result<(), DapAbort> { if req.media_type != Some(expected) { Err(DapAbort::content_type(req, expected)) } else { @@ -103,10 +100,10 @@ fn check_request_content_type( } } -async fn resolve_taskprov( - agg: &impl DapAggregator, +async fn resolve_taskprov( + agg: &impl DapAggregator, task_id: &TaskId, - req: &DapRequest, + req: &DapRequest, global_config: &DapGlobalConfig, ) -> Result<(), DapError> { if agg.get_task_config_for(task_id).await?.is_some() { @@ -146,12 +143,11 @@ async fn resolve_taskprov( #[cfg(test)] mod test { - use super::{aggregator, helper, leader, DapAuthorizedSender, DapLeader}; + use super::{aggregator, helper, leader, DapLeader}; #[cfg(feature = "experimental")] use crate::vdaf::{mastic::MasticWeight, MasticWeightConfig}; use crate::{ assert_metrics_include, async_test_versions, - auth::BearerToken, constants::DapMediaType, hpke::{HpkeKemId, HpkeProvider, HpkeReceiverConfig}, messages::{ @@ -179,8 +175,6 @@ mod test { pub(super) struct TestData { pub now: Time, global_config: DapGlobalConfig, - collector_token: BearerToken, - taskprov_collector_token: BearerToken, pub time_interval_task_id: TaskId, pub fixed_size_task_id: TaskId, pub expired_task_id: TaskId, @@ -188,10 +182,8 @@ mod test { pub mastic_task_id: TaskId, helper_registry: prometheus::Registry, tasks: HashMap, - pub leader_token: BearerToken, collector_hpke_receiver_config: HpkeReceiverConfig, taskprov_vdaf_verify_key_init: [u8; 32], - taskprov_leader_token: BearerToken, leader_registry: prometheus::Registry, } @@ -313,14 +305,8 @@ mod test { ); } - // Authorization tokens. These are normally chosen at random. 
- let leader_token = BearerToken::from("leader_token"); - let collector_token = BearerToken::from("collector_token"); - // taskprov let taskprov_vdaf_verify_key_init = rng.gen::<[u8; 32]>(); - let taskprov_leader_token = BearerToken::from("taskprov_leader_token"); - let taskprov_collector_token = BearerToken::from("taskprov_collector_token"); let helper_registry = prometheus::Registry::new_custom( Option::None, @@ -342,8 +328,6 @@ mod test { Self { now, global_config, - collector_token, - taskprov_collector_token, time_interval_task_id, fixed_size_task_id, expired_task_id, @@ -351,8 +335,6 @@ mod test { mastic_task_id, helper_registry, tasks, - leader_token, - taskprov_leader_token, collector_hpke_receiver_config, taskprov_vdaf_verify_key_init, leader_registry, @@ -366,11 +348,9 @@ mod test { .gen_hpke_receiver_config_list(thread_rng().gen()) .expect("failed to generate HPKE receiver config"), self.global_config.clone(), - self.leader_token.clone(), self.collector_hpke_receiver_config.config.clone(), &self.helper_registry, self.taskprov_vdaf_verify_key_init, - self.taskprov_leader_token.clone(), )) } @@ -381,13 +361,9 @@ mod test { .gen_hpke_receiver_config_list(thread_rng().gen()) .expect("failed to generate HPKE receiver config"), self.global_config, - self.leader_token, - self.collector_token.clone(), self.collector_hpke_receiver_config.config.clone(), &self.leader_registry, self.taskprov_vdaf_verify_key_init, - self.taskprov_leader_token, - self.taskprov_collector_token.clone(), Arc::clone(&helper), )); @@ -395,8 +371,6 @@ mod test { now: self.now, leader, helper, - collector_token: self.collector_token, - taskprov_collector_token: self.taskprov_collector_token, time_interval_task_id: self.time_interval_task_id, fixed_size_task_id: self.fixed_size_task_id, expired_task_id: self.expired_task_id, @@ -412,8 +386,6 @@ mod test { now: Time, leader: Arc, helper: Arc, - collector_token: BearerToken, - taskprov_collector_token: BearerToken, time_interval_task_id: TaskId, fixed_size_task_id: TaskId, expired_task_id: TaskId, @@ -430,11 +402,7 @@ mod test { data.with_leader(helper) } - pub async fn gen_test_upload_req( - &self, - report: Report, - task_id: &TaskId, - ) -> DapRequest { + pub async fn gen_test_upload_req(&self, report: Report, task_id: &TaskId) -> DapRequest { let task_config = self.leader.unchecked_get_task_config(task_id).await; let version = task_config.version; @@ -448,11 +416,7 @@ mod test { } } - pub async fn gen_test_coll_job_req( - &self, - query: Query, - task_id: &TaskId, - ) -> DapRequest { + pub async fn gen_test_coll_job_req(&self, query: Query, task_id: &TaskId) -> DapRequest { self.gen_test_coll_job_req_for_collection(query, DapAggregationParam::Empty, task_id) .await } @@ -462,10 +426,10 @@ mod test { query: Query, agg_param: DapAggregationParam, task_id: &TaskId, - ) -> DapRequest { + ) -> DapRequest { let task_config = self.leader.unchecked_get_task_config(task_id).await; - self.collector_authorized_req( + Self::collector_req( task_id, &task_config, DapMediaType::CollectReq, @@ -481,7 +445,7 @@ mod test { task_id: &TaskId, agg_param: DapAggregationParam, reports: Vec, - ) -> (DapAggregationJobState, DapRequest) { + ) -> (DapAggregationJobState, DapRequest) { let mut rng = thread_rng(); let task_config = self.leader.unchecked_get_task_config(task_id).await; let part_batch_sel = match task_config.query { @@ -508,38 +472,14 @@ mod test { ( leader_state, - self.leader_authorized_req( + Self::leader_req( task_id, &task_config, Some(&agg_job_id), 
DapMediaType::AggregationJobInitReq, agg_job_init_req, - ) - .await, - ) - } - - pub async fn gen_test_agg_share_req( - &self, - report_count: u64, - checksum: [u8; 32], - ) -> DapRequest { - let task_id = &self.time_interval_task_id; - let task_config = self.leader.unchecked_get_task_config(task_id).await; - - self.leader_authorized_req( - task_id, - &task_config, - None, - DapMediaType::AggregateShareReq, - AggregateShareReq { - batch_sel: BatchSelector::default(), - agg_param: Vec::default(), - report_count, - checksum, - }, + ), ) - .await } pub async fn gen_test_report(&self, task_id: &TaskId) -> Report { @@ -585,21 +525,14 @@ mod test { .unwrap() } - pub async fn leader_authorized_req>( - &self, + pub fn leader_req>( task_id: &TaskId, task_config: &DapTaskConfig, agg_job_id: Option<&AggregationJobId>, media_type: DapMediaType, msg: M, - ) -> DapRequest { + ) -> DapRequest { let payload = msg.get_encoded_with_param(&task_config.version).unwrap(); - let sender_auth = Some( - self.leader - .authorize(task_id, task_config, &media_type, &payload) - .await - .unwrap(), - ); DapRequest { version: task_config.version, media_type: Some(media_type), @@ -608,25 +541,18 @@ mod test { DapResource::AggregationJob(*id) }), payload, - sender_auth, ..Default::default() } } - pub fn collector_authorized_req>( - &self, + pub fn collector_req>( task_id: &TaskId, task_config: &DapTaskConfig, media_type: DapMediaType, msg: M, - ) -> DapRequest { + ) -> DapRequest { let mut rng = thread_rng(); let coll_job_id = CollectionJobId(rng.gen()); - let sender_auth = if task_config.method_is_taskprov() { - Some(self.taskprov_collector_token.clone()) - } else { - Some(self.collector_token.clone()) - }; DapRequest { version: task_config.version, @@ -634,7 +560,6 @@ mod test { task_id: Some(*task_id), resource: DapResource::CollectionJob(coll_job_id), payload: msg.get_encoded_with_param(&task_config.version).unwrap(), - sender_auth, ..Default::default() } } @@ -649,21 +574,19 @@ mod test { let agg_job_id = AggregationJobId(rng.gen()); // Helper expects "time_interval" query, but Leader indicates "fixed_size". - let req = t - .leader_authorized_req( - task_id, - &task_config, - Some(&agg_job_id), - DapMediaType::AggregationJobInitReq, - AggregationJobInitReq { - agg_param: Vec::default(), - part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { - batch_id: BatchId(rng.gen()), - }, - prep_inits: Vec::default(), + let req = Test::leader_req( + task_id, + &task_config, + Some(&agg_job_id), + DapMediaType::AggregationJobInitReq, + AggregationJobInitReq { + agg_param: Vec::default(), + part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { + batch_id: BatchId(rng.gen()), }, - ) - .await; + prep_inits: Vec::default(), + }, + ); assert_matches!( helper::handle_agg_job_req(&*t.helper, &req, Default::default()) .await @@ -706,36 +629,6 @@ mod test { // // async_test_versions! { handle_agg_job_req_init_expired_task } - async fn handle_agg_job_init_req_unauthorized_request(version: DapVersion) { - let t = Test::new(version); - let report = t.gen_test_report(&t.time_interval_task_id).await; - let (_, mut req) = t - .gen_test_agg_job_init_req( - &t.time_interval_task_id, - DapAggregationParam::Empty, - vec![report], - ) - .await; - req.sender_auth = None; - - // Expect failure due to missing bearer token. - assert_matches!( - helper::handle_agg_job_req(&*t.helper, &req, Default::default()).await, - Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. })) - ); - - // Expect failure due to incorrect bearer token. 
-        req.sender_auth = Some(BearerToken::from("incorrect auth token!".to_string()));
-        assert_matches!(
-            helper::handle_agg_job_req(&*t.helper, &req, Default::default()).await,
-            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
-        );
-
-        assert_eq!(t.helper.audit_log.invocations(), 0);
-    }
-
-    async_test_versions! { handle_agg_job_init_req_unauthorized_request }
-
     async fn handle_hpke_config_req_unrecognized_task(version: DapVersion) {
         let t = Test::new(version);
         let mut rng = thread_rng();
@@ -750,7 +643,7 @@
         };
 
         assert_eq!(
-            aggregator::handle_hpke_config_req(&*t.leader, &req, Some(task_id))
+            aggregator::handle_hpke_config_req(&*t.leader, req.version, Some(task_id))
                .await
                .unwrap_err(),
            DapError::Abort(DapAbort::UnrecognizedTask { task_id })
@@ -774,34 +667,13 @@
        // that Daphne-Worker does not implement this behavior. Instead it returns the HPKE config
        // used for all tasks.
        assert_matches!(
-            aggregator::handle_hpke_config_req(&*t.leader, &req, None).await,
+            aggregator::handle_hpke_config_req(&*t.leader, req.version, None).await,
            Err(DapError::Abort(DapAbort::MissingTaskId))
        );
    }

    async_test_versions! { handle_hpke_config_req_missing_task_id }

-    async fn handle_agg_share_req_unauthorized_request(version: DapVersion) {
-        let t = Test::new(version);
-        let mut req = t.gen_test_agg_share_req(0, [0; 32]).await;
-        req.sender_auth = None;
-
-        // Expect failure due to missing bearer token.
-        assert_matches!(
-            helper::handle_agg_share_req(&*t.helper, &req).await,
-            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
-        );
-
-        // Expect failure due to incorrect bearer token.
-        req.sender_auth = Some(BearerToken::from("incorrect auth token!".to_string()));
-        assert_matches!(
-            helper::handle_agg_share_req(&*t.helper, &req).await,
-            Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. }))
-        );
-    }
-
-    async_test_versions! { handle_agg_share_req_unauthorized_request }
-
    // Test that the Helper handles the batch selector sent from the Leader properly.
async fn handle_agg_share_req_invalid_batch_sel(version: DapVersion) { let mut rng = thread_rng(); @@ -812,22 +684,20 @@ mod test { .leader .unchecked_get_task_config(&t.time_interval_task_id) .await; - let req = t - .leader_authorized_req( - &t.time_interval_task_id, - &task_config, - None, - DapMediaType::AggregateShareReq, - AggregateShareReq { - batch_sel: BatchSelector::FixedSizeByBatchId { - batch_id: BatchId(rng.gen()), - }, - agg_param: Vec::default(), - report_count: 0, - checksum: [0; 32], + let req = Test::leader_req( + &t.time_interval_task_id, + &task_config, + None, + DapMediaType::AggregateShareReq, + AggregateShareReq { + batch_sel: BatchSelector::FixedSizeByBatchId { + batch_id: BatchId(rng.gen()), }, - ) - .await; + agg_param: Vec::default(), + report_count: 0, + checksum: [0; 32], + }, + ); assert_matches!( helper::handle_agg_share_req(&*t.helper, &req) .await @@ -840,22 +710,20 @@ mod test { .leader .unchecked_get_task_config(&t.fixed_size_task_id) .await; - let req = t - .leader_authorized_req( - &t.fixed_size_task_id, - &task_config, - None, - DapMediaType::AggregateShareReq, - AggregateShareReq { - batch_sel: BatchSelector::FixedSizeByBatchId { - batch_id: BatchId(rng.gen()), // Unrecognized batch ID - }, - agg_param: Vec::default(), - report_count: 0, - checksum: [0; 32], + let req = Test::leader_req( + &t.fixed_size_task_id, + &task_config, + None, + DapMediaType::AggregateShareReq, + AggregateShareReq { + batch_sel: BatchSelector::FixedSizeByBatchId { + batch_id: BatchId(rng.gen()), // Unrecognized batch ID }, - ) - .await; + agg_param: Vec::default(), + report_count: 0, + checksum: [0; 32], + }, + ); assert_matches!( helper::handle_agg_share_req(&*t.helper, &req) .await @@ -866,42 +734,6 @@ mod test { async_test_versions! { handle_agg_share_req_invalid_batch_sel } - async fn handle_coll_job_req_unauthorized_request(version: DapVersion) { - let mut rng = thread_rng(); - let t = Test::new(version); - let task_id = &t.time_interval_task_id; - let task_config = t.leader.unchecked_get_task_config(task_id).await; - let collect_job_id = CollectionJobId(rng.gen()); - let mut req = DapRequest { - version: task_config.version, - media_type: Some(DapMediaType::CollectReq), - task_id: Some(*task_id), - resource: DapResource::CollectionJob(collect_job_id), - payload: CollectionReq { - query: Query::default(), - agg_param: Vec::default(), - } - .get_encoded_with_param(&task_config.version) - .unwrap(), - ..Default::default() // Unauthorized request. - }; - - // Expect failure due to missing bearer token. - assert_matches!( - leader::handle_coll_job_req(&*t.leader, &req).await, - Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. })) - ); - - // Expect failure due to incorrect bearer token. - req.sender_auth = Some(BearerToken::from("incorrect auth token!".to_string())); - assert_matches!( - leader::handle_coll_job_req(&*t.leader, &req).await, - Err(DapError::Abort(DapAbort::UnauthorizedRequest { .. })) - ); - } - - async_test_versions! { handle_coll_job_req_unauthorized_request } - async fn handle_agg_job_req_failure_hpke_decrypt_error(version: DapVersion) { let t = Test::new(version); let task_id = &t.time_interval_task_id; @@ -1162,7 +994,7 @@ mod test { let task_config = t.leader.unchecked_get_task_config(task_id).await; // Collector: Create a CollectReq. 
- let req = t.collector_authorized_req( + let req = Test::collector_req( task_id, &task_config, DapMediaType::CollectReq, @@ -1248,7 +1080,7 @@ mod test { let task_config = t.leader.unchecked_get_task_config(task_id).await; // Collector: Create a CollectReq with a very large batch interval. - let req = t.collector_authorized_req( + let req = Test::collector_req( task_id, &task_config, DapMediaType::CollectReq, @@ -1273,7 +1105,7 @@ mod test { assert_matches!(err, DapError::Abort(DapAbort::BadRequest(s)) => assert_eq!(s, "batch interval too large".to_string())); // Collector: Create a CollectReq with a batch interval in the past. - let req = t.collector_authorized_req( + let req = Test::collector_req( task_id, &task_config, DapMediaType::CollectReq, @@ -1299,7 +1131,7 @@ mod test { assert_matches!(err, DapError::Abort(DapAbort::BadRequest(s)) => assert_eq!(s, "batch interval too far into past".to_string())); // Collector: Create a CollectReq with a batch interval in the future. - let req = t.collector_authorized_req( + let req = Test::collector_req( task_id, &task_config, DapMediaType::CollectReq, @@ -1333,7 +1165,7 @@ mod test { let task_config = t.leader.unchecked_get_task_config(task_id).await; // Collector: Create a CollectReq with a very large batch interval. - let req = t.collector_authorized_req( + let req = Test::collector_req( task_id, &task_config, DapMediaType::CollectReq, @@ -1419,7 +1251,7 @@ mod test { query: task_config.query_for_current_batch_window(t.now), agg_param: Vec::default(), }; - let req = t.collector_authorized_req( + let req = Test::collector_req( task_id, &task_config, DapMediaType::CollectReq, @@ -1472,7 +1304,7 @@ mod test { .leader .unchecked_get_task_config(&t.time_interval_task_id) .await; - let req = t.collector_authorized_req( + let req = Test::collector_req( &t.time_interval_task_id, &task_config, DapMediaType::CollectReq, @@ -1495,7 +1327,7 @@ mod test { .leader .unchecked_get_task_config(&t.fixed_size_task_id) .await; - let req = t.collector_authorized_req( + let req = Test::collector_req( &t.fixed_size_task_id, &task_config, DapMediaType::CollectReq, @@ -1663,7 +1495,6 @@ mod test { resource: DapResource::Undefined, payload: report.get_encoded_with_param(&version).unwrap(), taskprov: Some(taskprov_advertisement.clone()), - ..Default::default() }; leader::handle_upload_req(&*t.leader, &req).await.unwrap(); } diff --git a/crates/daphne/src/taskprov.rs b/crates/daphne/src/taskprov.rs index d1eac2a7b..1da09d471 100644 --- a/crates/daphne/src/taskprov.rs +++ b/crates/daphne/src/taskprov.rs @@ -106,8 +106,8 @@ fn malformed_task_config(task_id: &TaskId, detail: String) -> DapAbort { /// /// The `task_id` is the task ID indicated by the request; if this does not match the derived task /// ID, then we return `Err(DapError::Abort(DapAbort::UnrecognizedTask))`. -pub(crate) fn resolve_advertised_task_config( - req: &'_ DapRequest, +pub(crate) fn resolve_advertised_task_config( + req: &'_ DapRequest, verify_key_init: &[u8; 32], collector_hpke_config: &HpkeConfig, task_id: &TaskId, @@ -126,8 +126,8 @@ pub(crate) fn resolve_advertised_task_config( } /// Check for a taskprov extension in the report, and return it if found. 
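+/// (The advertisement is read from `req.taskprov`, which carries the
+/// base64url-encoded taskprov task configuration from the `dap-taskprov`
+/// header.)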
-fn get_taskprov_task_config( - req: &'_ DapRequest, +fn get_taskprov_task_config( + req: &'_ DapRequest, task_id: &TaskId, ) -> Result, DapAbort> { let taskprov_data = if let Some(ref taskprov_base64url) = req.taskprov { @@ -621,13 +621,12 @@ mod test { let task_id = compute_task_id(&taskprov_task_config_bytes); let taskprov_task_config_base64url = encode_base64url(&taskprov_task_config_bytes); - let req = DapRequest::<()> { + let req = DapRequest { version, media_type: None, // ignored by test task_id: Some(task_id), resource: DapResource::Undefined, // ignored by test payload: Vec::default(), // ignored by test - sender_auth: None, // ignored by test taskprov: Some(taskprov_task_config_base64url), }; @@ -677,13 +676,12 @@ mod test { let task_id = compute_task_id(&taskprov_task_config_bytes); let taskprov_task_config_base64url = encode_base64url(&taskprov_task_config_bytes); - let req = DapRequest::<()> { + let req = DapRequest { version, media_type: None, // ignored by test task_id: Some(task_id), resource: DapResource::Undefined, // ignored by test payload: Vec::default(), // ignored by test - sender_auth: None, // ignored by test taskprov: Some(taskprov_task_config_base64url), }; diff --git a/crates/daphne/src/testing/mod.rs b/crates/daphne/src/testing/mod.rs index e558d2131..7f7bf0e4f 100644 --- a/crates/daphne/src/testing/mod.rs +++ b/crates/daphne/src/testing/mod.rs @@ -8,7 +8,6 @@ pub mod report_generator; use crate::{ audit_log::AuditLog, - auth::{BearerToken, BearerTokenProvider}, constants::DapMediaType, fatal_error, hpke::{HpkeConfig, HpkeDecrypter, HpkeKemId, HpkeProvider, HpkeReceiverConfig}, @@ -23,7 +22,7 @@ use crate::{ aggregator::MergeAggShareError, helper, leader::{in_memory_leader::InMemoryLeaderState, WorkItem}, - DapAggregator, DapAuthorizedSender, DapHelper, DapLeader, DapReportInitializer, + DapAggregator, DapHelper, DapLeader, DapReportInitializer, }, taskprov, vdaf::VdafVerifyKey, @@ -577,8 +576,6 @@ pub struct InMemoryAggregator { pub(crate) global_config: DapGlobalConfig, tasks: Mutex>, pub hpke_receiver_config_list: Box<[HpkeReceiverConfig]>, - leader_token: BearerToken, - collector_token: Option, // Not set by Helper collector_hpke_config: HpkeConfig, // aggregation state @@ -591,8 +588,6 @@ pub struct InMemoryAggregator { // taskprov taskprov_vdaf_verify_key_init: [u8; 32], - taskprov_leader_token: BearerToken, - taskprov_collector_token: Option, // Not set by Helper // Leader: Reference to peer. Used to simulate HTTP requests from Leader to Helper, i.e., // implement `DapLeader::send_http_post()` for `InMemoryAggregator`. Not set by the Helper. 
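+    // With the bearer-token fields gone, a Helper-side test aggregator is now
+    // built as, e.g., `InMemoryAggregator::new_helper(tasks,
+    // hpke_receiver_config_list, global_config, collector_hpke_config,
+    // registry, taskprov_vdaf_verify_key_init)` (argument names illustrative).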
@@ -605,29 +600,21 @@ impl DeepSizeOf for InMemoryAggregator { global_config, tasks, hpke_receiver_config_list, - leader_token, - collector_token, leader_state_store, agg_store, collector_hpke_config, metrics: _, audit_log: _, taskprov_vdaf_verify_key_init, - taskprov_leader_token, - taskprov_collector_token, peer, } = self; global_config.deep_size_of_children(context) + tasks.deep_size_of_children(context) + hpke_receiver_config_list.deep_size_of_children(context) - + leader_token.deep_size_of_children(context) - + collector_token.deep_size_of_children(context) + leader_state_store.deep_size_of_children(context) + agg_store.deep_size_of_children(context) + collector_hpke_config.deep_size_of_children(context) + taskprov_vdaf_verify_key_init.deep_size_of_children(context) - + taskprov_leader_token.deep_size_of_children(context) - + taskprov_collector_token.deep_size_of_children(context) + peer.deep_size_of_children(context) } } @@ -638,26 +625,20 @@ impl InMemoryAggregator { tasks: impl IntoIterator, hpke_receiver_config_list: impl IntoIterator, global_config: DapGlobalConfig, - leader_token: BearerToken, collector_hpke_config: HpkeConfig, registry: &prometheus::Registry, taskprov_vdaf_verify_key_init: [u8; 32], - taskprov_leader_token: BearerToken, ) -> Self { Self { global_config, tasks: Mutex::new(tasks.into_iter().collect()), hpke_receiver_config_list: hpke_receiver_config_list.into_iter().collect(), - leader_token, - collector_token: None, leader_state_store: Default::default(), agg_store: Default::default(), collector_hpke_config, metrics: DaphnePromMetrics::register(registry).unwrap(), audit_log: MockAuditLog::default(), taskprov_vdaf_verify_key_init, - taskprov_leader_token, - taskprov_collector_token: None, peer: None, } } @@ -667,29 +648,21 @@ impl InMemoryAggregator { tasks: impl IntoIterator, hpke_receiver_config_list: impl IntoIterator, global_config: DapGlobalConfig, - leader_token: BearerToken, - collector_token: impl Into>, collector_hpke_config: HpkeConfig, registry: &prometheus::Registry, taskprov_vdaf_verify_key_init: [u8; 32], - taskprov_leader_token: BearerToken, - taskprov_collector_token: impl Into>, peer: Arc, ) -> Self { Self { global_config, tasks: Mutex::new(tasks.into_iter().collect()), hpke_receiver_config_list: hpke_receiver_config_list.into_iter().collect(), - leader_token, - collector_token: collector_token.into(), leader_state_store: Default::default(), agg_store: Default::default(), collector_hpke_config, metrics: DaphnePromMetrics::register(registry).unwrap(), audit_log: MockAuditLog::default(), taskprov_vdaf_verify_key_init, - taskprov_leader_token, - taskprov_collector_token: taskprov_collector_token.into(), peer: peer.into(), } } @@ -717,39 +690,6 @@ impl InMemoryAggregator { } } -#[async_trait] -impl BearerTokenProvider for InMemoryAggregator { - type WrappedBearerToken<'a> = &'a BearerToken; - - async fn get_leader_bearer_token_for<'s>( - &'s self, - _task_id: &'s TaskId, - task_config: &DapTaskConfig, - ) -> Result>, DapError> { - if task_config.method_is_taskprov() { - Ok(Some(&self.taskprov_leader_token)) - } else { - Ok(Some(&self.leader_token)) - } - } - - async fn get_collector_bearer_token_for<'s>( - &'s self, - _task_id: &'s TaskId, - task_config: &DapTaskConfig, - ) -> Result>, DapError> { - if task_config.method_is_taskprov() { - Ok(Some(self.taskprov_collector_token.as_ref().expect( - "InMemoryAggregator not configured with taskprov collector token", - ))) - } else { - Ok(Some(self.collector_token.as_ref().expect( - "InMemoryAggregator not 
configured with collector token", - ))) - } - } -} - #[async_trait] impl HpkeProvider for InMemoryAggregator { type WrappedHpkeConfig<'a> = &'a HpkeConfig; @@ -799,22 +739,6 @@ impl HpkeDecrypter for InMemoryAggregator { } } -#[async_trait] -impl DapAuthorizedSender for InMemoryAggregator { - async fn authorize( - &self, - task_id: &TaskId, - task_config: &DapTaskConfig, - media_type: &DapMediaType, - _payload: &[u8], - ) -> Result { - Ok(self - .authorize_with_bearer_token(task_id, task_config, media_type) - .await? - .clone()) - } -} - #[async_trait] impl DapReportInitializer for InMemoryAggregator { fn valid_report_time_range(&self) -> Range { @@ -841,20 +765,12 @@ impl DapReportInitializer for InMemoryAggregator { } #[async_trait] -impl DapAggregator for InMemoryAggregator { +impl DapAggregator for InMemoryAggregator { // The lifetimes on the traits ensure that we can return a reference to a task config stored by // the DapAggregator. (See DaphneWorkerConfig for an example.) For simplicity, InMemoryAggregator // clones the task config as needed. type WrappedDapTaskConfig<'a> = DapTaskConfig; - async fn unauthorized_reason( - &self, - task_config: &DapTaskConfig, - req: &DapRequest, - ) -> Result, DapError> { - self.bearer_token_authorized(task_config, req).await - } - async fn get_global_config(&self) -> Result { Ok(self.global_config.clone()) } @@ -882,7 +798,7 @@ impl DapAggregator for InMemoryAggregator { async fn taskprov_put( &self, - req: &DapRequest, + req: &DapRequest, task_config: DapTaskConfig, ) -> Result<(), DapError> { let task_id = req.task_id().map_err(DapError::Abort)?; @@ -1063,10 +979,10 @@ impl DapAggregator for InMemoryAggregator { } #[async_trait] -impl DapHelper for InMemoryAggregator {} +impl DapHelper for InMemoryAggregator {} #[async_trait] -impl DapLeader for InMemoryAggregator { +impl DapLeader for InMemoryAggregator { async fn put_report(&self, report: &Report, task_id: &TaskId) -> Result<(), DapError> { let task_config = self .get_task_config_for(task_id) @@ -1152,11 +1068,7 @@ impl DapLeader for InMemoryAggregator { .finish_collect_job(task_id, coll_job_id, collection) } - async fn send_http_post( - &self, - req: DapRequest, - _url: Url, - ) -> Result { + async fn send_http_post(&self, req: DapRequest, _url: Url) -> Result { match req.media_type { Some(DapMediaType::AggregationJobInitReq) => Ok(helper::handle_agg_job_req( &**self.peer.as_ref().expect("peer not configured"), @@ -1175,11 +1087,7 @@ impl DapLeader for InMemoryAggregator { } } - async fn send_http_put( - &self, - req: DapRequest, - _url: Url, - ) -> Result { + async fn send_http_put(&self, req: DapRequest, _url: Url) -> Result { if req.media_type == Some(DapMediaType::AggregationJobInitReq) { Ok(helper::handle_agg_job_req( &**self.peer.as_ref().expect("peer not configured"),
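
A rough sketch of how a service embedding these crates can enforce peer
authentication after this change (hypothetical helper function; the real
enforcement presumably lives in the daphne-server service layer, which this
excerpt does not show):

use daphne_service_utils::{bearer_token::BearerToken, config::PeerBearerToken};

/// Compare a presented token against the configured peer token. Equality on
/// `BearerToken` is constant-time via its `PartialEq` impl.
fn check_peer_token(cfg: &PeerBearerToken, presented: &BearerToken) -> Result<(), String> {
    let expected = match cfg {
        PeerBearerToken::Leader { expected_token }
        | PeerBearerToken::Collector { expected_token } => expected_token,
    };
    if expected == presented {
        Ok(())
    } else {
        Err("the indicated bearer token is incorrect".into())
    }
}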