Skip to content

Commit

Permalink
Parameterize DapRequest by its payload type
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Sep 30, 2024
1 parent 3358f91 commit 570dfc7
Show file tree
Hide file tree
Showing 11 changed files with 375 additions and 253 deletions.
29 changes: 22 additions & 7 deletions crates/daphne-server/src/roles/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@ use daphne::{
fatal_error,
messages::{BatchId, BatchSelector, Collection, CollectionJobId, Report, TaskId},
roles::{leader::WorkItem, DapAggregator, DapLeader},
DapAggregationParam, DapCollectionJob, DapError, DapRequest, DapResponse,
DapAggregationParam, DapCollectionJob, DapError, DapRequest, DapResponse, DapVersion,
};
use daphne_service_utils::http_headers;
use http::StatusCode;
use prio::codec::ParameterizedEncode;
use tracing::{error, info};
use url::Url;

Expand Down Expand Up @@ -97,22 +98,31 @@ impl DapLeader for crate::App {
self.test_leader_state.lock().await.enqueue_work(items)
}

async fn send_http_post(&self, req: DapRequest, url: Url) -> Result<DapResponse, DapError> {
async fn send_http_post<M>(&self, req: DapRequest<M>, url: Url) -> Result<DapResponse, DapError>
where
M: Send + ParameterizedEncode<DapVersion>,
{
self.send_http(req, Method::POST, url).await
}

async fn send_http_put(&self, req: DapRequest, url: Url) -> Result<DapResponse, DapError> {
async fn send_http_put<M>(&self, req: DapRequest<M>, url: Url) -> Result<DapResponse, DapError>
where
M: Send + ParameterizedEncode<DapVersion>,
{
self.send_http(req, Method::PUT, url).await
}
}

impl crate::App {
async fn send_http(
async fn send_http<M>(
&self,
mut req: DapRequest,
req: DapRequest<M>,
method: Method,
url: Url,
) -> Result<DapResponse, DapError> {
) -> Result<DapResponse, DapError>
where
M: Send + ParameterizedEncode<DapVersion>,
{
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
let content_type = req
.media_type
Expand Down Expand Up @@ -168,10 +178,15 @@ impl crate::App {
);
}

let (req, payload) = req.take_payload();
let req_builder = self
.http
.request(method, url.clone())
.body(std::mem::take(&mut req.payload))
.body(
payload
.get_encoded_with_param(&req.version)
.map_err(|e| DapAbort::from_codec_error(e, req.task_id))?,
)
.headers(headers);

let start = Instant::now();
Expand Down
147 changes: 120 additions & 27 deletions crates/daphne-server/src/router/extractor.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,104 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use super::{AxumDapResponse, DaphneService};
use std::io::Cursor;

use axum::{
async_trait,
body::HttpBody,
body::{Bytes, HttpBody},
extract::{FromRequest, FromRequestParts, Path},
};
use daphne::{
constants::DapMediaType,
error::DapAbort,
fatal_error,
messages::{request::DapRequestMeta, AggregationJobId, CollectionJobId, TaskId},
DapRequest, DapResource, DapVersion,
messages::{
AggregateShareReq, AggregationJobId, AggregationJobInitReq, CollectionJobId, CollectionReq,
Report, TaskId,
},
DapRequest, DapRequestMeta, DapResource, DapVersion,
};
use daphne_service_utils::{bearer_token::BearerToken, http_headers, metrics};
use http::{header::CONTENT_TYPE, HeaderMap, Request};
use prio::codec::ParameterizedDecode;
use serde::Deserialize;

use super::{AxumDapResponse, DaphneService};

pub trait DecodeFromDapHttpBody: Sized {
fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result<Self, DapAbort>;
}

pub enum AggregationJobReq {
Init(AggregationJobInitReq),
// TODO: support continue requests
}

impl DecodeFromDapHttpBody for AggregationJobReq {
fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result<Self, DapAbort> {
let mut cursor = Cursor::new(bytes.as_ref());
let media_type = meta.get_checked_media_type([DapMediaType::AggregationJobInitReq])?;
match media_type {
DapMediaType::AggregationJobInitReq => Ok(Self::Init(
AggregationJobInitReq::decode_with_param(&meta.version, &mut cursor)
.map_err(|e| DapAbort::from_codec_error(e, meta.task_id))?,
)),
// we list out all the variants so we remember to come here when we implement
// AggregateShareContReq
DapMediaType::AggregateShareReq
| DapMediaType::Report
| DapMediaType::CollectReq
| DapMediaType::HpkeConfigList
| DapMediaType::AggregationJobResp
| DapMediaType::AggregateShare
| DapMediaType::Collection => {
unreachable!("get_checked_media_type already filtered these out")
}
}
}
}

impl DecodeFromDapHttpBody for AggregateShareReq {
fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result<Self, DapAbort> {
let mut cursor = Cursor::new(bytes.as_ref());
meta.get_checked_media_type([DapMediaType::AggregateShareReq])?;
Self::decode_with_param(&meta.version, &mut cursor)
.map_err(|e| DapAbort::from_codec_error(e, meta.task_id))
}
}

impl DecodeFromDapHttpBody for Report {
fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result<Self, DapAbort> {
let mut cursor = Cursor::new(bytes.as_ref());
meta.get_checked_media_type([DapMediaType::Report])?;
Self::decode_with_param(&meta.version, &mut cursor)
.map_err(|e| DapAbort::from_codec_error(e, meta.task_id))
}
}

impl DecodeFromDapHttpBody for CollectionReq {
fn decode_from_http_body(bytes: Bytes, meta: &DapRequestMeta) -> Result<Self, DapAbort> {
let mut cursor = Cursor::new(bytes.as_ref());
meta.get_checked_media_type([DapMediaType::CollectReq])?;
Self::decode_with_param(&meta.version, &mut cursor)
.map_err(|e| DapAbort::from_codec_error(e, meta.task_id))
}
}

impl DecodeFromDapHttpBody for () {
fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result<Self, DapAbort> {
Ok(())
}
}

/// An axum extractor capable of parsing a [`DapRequest`].
#[derive(Debug)]
pub(super) struct UnauthenticatedDapRequestExtractor(pub DapRequest);
pub(super) struct UnauthenticatedDapRequestExtractor<P>(pub DapRequest<P>);

#[async_trait]
impl<S, B> FromRequest<S, B> for UnauthenticatedDapRequestExtractor
impl<S, B, P> FromRequest<S, B, P> for UnauthenticatedDapRequestExtractor<P>
where
P: DecodeFromDapHttpBody,
S: DaphneService + Send + Sync,
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
Expand Down Expand Up @@ -76,8 +150,8 @@ where
None
};

// TODO(mendess): this is very eager, we could redesign DapResponse later to allow for
// streaming of data.
// TODO(mendess): this allocates needlessly, if prio supported some kind of
// AsyncParameterizedDecode we could avoid this allocation
let payload = hyper::body::to_bytes(body).await;

let Ok(payload) = payload else {
Expand Down Expand Up @@ -107,17 +181,19 @@ where
(task_id, resource)
};

let request = DapRequest {
meta: DapRequestMeta {
version,
task_id,
resource,
media_type,
taskprov: extract_header_as_string(&parts.headers, http_headers::DAP_TASKPROV),
},
payload: payload.to_vec(),
let meta = DapRequestMeta {
version,
task_id,
taskprov: extract_header_as_string(&parts.headers, http_headers::DAP_TASKPROV),
media_type,
resource,
};

let payload = P::decode_from_http_body(payload, &meta)
.map_err(|e| AxumDapResponse::new_error(e, state.server_metrics()))?;

let request = DapRequest { meta, payload };

Ok(UnauthenticatedDapRequestExtractor(request))
}
}
Expand All @@ -126,11 +202,12 @@ where
///
/// This extractor asserts that the request is authenticated.
#[derive(Debug)]
pub(super) struct DapRequestExtractor(pub DapRequest);
pub(super) struct DapRequestExtractor<P>(pub DapRequest<P>);

#[async_trait]
impl<S, B> FromRequest<S, B> for DapRequestExtractor
impl<S, B, P> FromRequest<S, B, P> for DapRequestExtractor<P>
where
P: DecodeFromDapHttpBody + Send + Sync,
S: DaphneService + Send + Sync,
B: HttpBody + Send + 'static,
<B as HttpBody>::Data: Send,
Expand Down Expand Up @@ -213,8 +290,9 @@ mod test {
use daphne::{
async_test_versions,
constants::DapMediaType,
error::DapAbort,
messages::{AggregationJobId, Base64Encode, CollectionJobId, TaskId},
DapRequest, DapResource, DapSender, DapVersion,
DapRequestMeta, DapResource, DapSender, DapVersion,
};
use daphne_service_utils::{
bearer_token::BearerToken, http_headers, metrics::DaphnePromServiceMetrics,
Expand All @@ -224,12 +302,27 @@ mod test {
use tokio::sync::mpsc::{self, Sender};
use tower::ServiceExt;

use crate::router::DapRequestExtractor;

use super::UnauthenticatedDapRequestExtractor;
use super::DecodeFromDapHttpBody;

const BEARER_TOKEN: &str = "test-token";

// in these tests we don't care about the body, so we can type alias it away
#[derive(Debug)]
struct EmptyBody;

impl DecodeFromDapHttpBody for EmptyBody {
fn decode_from_http_body(
_: axum::body::Bytes,
_: &DapRequestMeta,
) -> Result<Self, DapAbort> {
Ok(Self)
}
}

type DapRequest = daphne::DapRequest<EmptyBody>;
type UnauthenticatedDapRequestExtractor = super::UnauthenticatedDapRequestExtractor<EmptyBody>;
type DapRequestExtractor = super::DapRequestExtractor<EmptyBody>;

/// Return a function that will parse a request using the [`UnAuthenticatedDapRequestExtractor`] and return
/// the parsed request.
///
Expand Down Expand Up @@ -277,16 +370,16 @@ mod test {

async fn handler(
State(ch): State<Arc<Channel>>,
UnauthenticatedDapRequestExtractor(req): UnauthenticatedDapRequestExtractor,
req: UnauthenticatedDapRequestExtractor,
) -> impl IntoResponse {
ch.send(req).await.unwrap();
ch.send(req.0).await.unwrap();
}

async fn auth_handler(
State(ch): State<Arc<Channel>>,
DapRequestExtractor(req): DapRequestExtractor,
req: DapRequestExtractor,
) -> impl IntoResponse {
ch.send(req).await.unwrap();
ch.send(req.0).await.unwrap();
}

let (tx, mut rx) = mpsc::channel(1);
Expand Down
22 changes: 9 additions & 13 deletions crates/daphne-server/src/router/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@ use axum::{
routing::{post, put},
};
use daphne::{
constants::DapMediaType,
error::DapAbort,
messages::AggregateShareReq,
roles::{helper, DapHelper},
};
use http::StatusCode;

use crate::{roles::fetch_replay_protection_override, App};

use super::{AxumDapResponse, DapRequestExtractor, DaphneService};
use super::{extractor::AggregationJobReq, AxumDapResponse, DapRequestExtractor, DaphneService};

pub(super) fn add_helper_routes<B>(router: super::Router<App, B>) -> super::Router<App, B>
where
Expand All @@ -43,13 +42,14 @@ where
)]
async fn agg_job(
State(app): State<Arc<App>>,
DapRequestExtractor(req): DapRequestExtractor,
DapRequestExtractor(req): DapRequestExtractor<AggregationJobReq>,
) -> AxumDapResponse {
match req.media_type {
Some(DapMediaType::AggregationJobInitReq) => {
let (req, payload) = req.take_payload();
match payload {
AggregationJobReq::Init(agg_init_req) => {
let resp = helper::handle_agg_job_init_req(
&*app,
&req,
req.map(|()| agg_init_req),
fetch_replay_protection_override(app.kv()).await,
)
.await;
Expand All @@ -59,10 +59,6 @@ async fn agg_job(
StatusCode::CREATED,
)
}
m => AxumDapResponse::new_error(
DapAbort::BadRequest(format!("unexpected media type: {m:?}")),
app.server_metrics(),
),
}
}

Expand All @@ -76,13 +72,13 @@ async fn agg_job(
)]
async fn agg_share<A>(
State(app): State<Arc<A>>,
DapRequestExtractor(req): DapRequestExtractor,
DapRequestExtractor(req): DapRequestExtractor<AggregateShareReq>,
) -> AxumDapResponse
where
A: DapHelper + DaphneService + Send + Sync,
{
AxumDapResponse::from_result(
helper::handle_agg_share_req(&*app, &req).await,
helper::handle_agg_share_req(&*app, req).await,
app.server_metrics(),
)
}
Loading

0 comments on commit 570dfc7

Please sign in to comment.