Skip to content

Commit

Permalink
Use expiration time when updating KV storage
Browse files Browse the repository at this point in the history
  • Loading branch information
cjpatton authored and lbaquerofierro committed Jul 22, 2024
1 parent 48db881 commit cd0f036
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 19 deletions.
13 changes: 11 additions & 2 deletions crates/daphne-server/src/roles/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,11 @@ impl DapAggregator<DaphneAuth> for crate::App {
};
if let Err(e) = self
.kv()
.put::<kv::prefix::TaskprovOptInParam>(task_id, param.clone())
.put_with_expiration::<kv::prefix::TaskprovOptInParam>(
task_id,
param.clone(),
param.not_before,
)
.await
{
tracing::warn!(error = ?e, "failed to store taskprov opt in param");
Expand All @@ -274,10 +278,15 @@ impl DapAggregator<DaphneAuth> for crate::App {
task_config: DapTaskConfig,
) -> Result<(), DapError> {
let task_id = req.task_id().map_err(DapError::Abort)?;
let expiration_time = task_config.not_after;

if self.service_config.role.is_leader() || req.taskprov.is_none() {
self.kv()
.put::<kv::prefix::TaskConfig>(task_id, task_config)
.put_with_expiration::<kv::prefix::TaskConfig>(
task_id,
task_config,
expiration_time,
)
.await
.map_err(|e| fatal_error!(err = ?e))?;
} else {
Expand Down
3 changes: 2 additions & 1 deletion crates/daphne-server/src/roles/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ mod test_utils {

if self
.kv()
.put_if_not_exists::<kv::prefix::TaskConfig>(
.put_if_not_exists_with_expiration::<kv::prefix::TaskConfig>(
&cmd.task_id,
DapTaskConfig {
version,
Expand All @@ -230,6 +230,7 @@ mod test_utils {
method: Default::default(),
num_agg_span_shards: NonZeroUsize::new(4).unwrap(),
},
cmd.task_expiration,
)
.await
.map_err(|e| fatal_error!(err = ?e))?
Expand Down
93 changes: 82 additions & 11 deletions crates/daphne-server/src/storage_proxy_connection/kv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use crate::StorageProxyConfig;

use super::{status_http_1_0_to_reqwest_0_11, Error};
pub(crate) use cache::Cache;
use daphne::messages::Time;
use daphne_service_utils::http_headers::STORAGE_PROXY_PUT_KV_EXPIRATION;

pub(crate) struct Kv<'h> {
config: &'h StorageProxyConfig,
Expand Down Expand Up @@ -256,24 +258,59 @@ impl<'h> Kv<'h> {
skip_all,
fields(key, prefix = std::any::type_name::<P>()),
)]
pub async fn put<P>(&self, key: &P::Key, value: P::Value) -> Result<(), Error>
pub async fn put_internal<P>(
&self,
key: &P::Key,
value: P::Value,
expiration: Option<Time>,
) -> Result<(), Error>
where
P: KvPrefix,
P::Key: std::fmt::Debug,
P::Value: Serialize,
{
let key = Self::to_key::<P>(key);
tracing::debug!(key, "PUT");
self.http

let mut request = self
.http
.post(self.config.url.join(&key).unwrap())
.bearer_auth(&self.config.auth_token)
.body(serde_json::to_vec(&value).unwrap())
.send()
.await?
.error_for_status()?;
.body(serde_json::to_vec(&value).unwrap());

if let Some(expiration) = expiration {
request = request.header(STORAGE_PROXY_PUT_KV_EXPIRATION, expiration.to_string());
}

request.send().await?.error_for_status()?;

self.cache.write().await.put::<P>(key, Some(value.into()));
Ok(())
}

pub async fn put_with_expiration<P>(
&self,
key: &P::Key,
value: P::Value,
expiration: Time,
) -> Result<(), Error>
where
P: KvPrefix,
P::Key: std::fmt::Debug,
P::Value: Serialize,
{
self.put_internal::<P>(key, value, Some(expiration)).await
}

pub async fn put<P>(&self, key: &P::Key, value: P::Value) -> Result<(), Error>
where
P: KvPrefix,
P::Key: std::fmt::Debug,
P::Value: Serialize,
{
self.put_internal::<P>(key, value, None).await
}

/// Stores a value in kv if it doesn't already exist.
///
/// If the value already exists, returns the passed in value inside the Ok variant.
Expand All @@ -282,10 +319,11 @@ impl<'h> Kv<'h> {
skip_all,
fields(key, prefix = std::any::type_name::<P>()),
)]
pub async fn put_if_not_exists<P>(
pub async fn put_if_not_exists_internal<P>(
&self,
key: &P::Key,
value: P::Value,
expiration: Option<Time>,
) -> Result<Option<P::Value>, Error>
where
P: KvPrefix,
Expand All @@ -294,13 +332,18 @@ impl<'h> Kv<'h> {
let key = Self::to_key::<P>(key);

tracing::debug!(key, "PUT if not exists");
let response = self

let mut request = self
.http
.put(self.config.url.join(&key).unwrap())
.bearer_auth(&self.config.auth_token)
.body(serde_json::to_vec(&value).unwrap())
.send()
.await?;
.body(serde_json::to_vec(&value).unwrap());

if let Some(expiration) = expiration {
request = request.header(STORAGE_PROXY_PUT_KV_EXPIRATION, expiration.to_string());
}

let response = request.send().await?;

if response.status() == status_http_1_0_to_reqwest_0_11(StatusCode::CONFLICT) {
Ok(Some(value))
Expand All @@ -311,6 +354,34 @@ impl<'h> Kv<'h> {
}
}

pub async fn put_if_not_exists_with_expiration<P>(
&self,
key: &P::Key,
value: P::Value,
expiration: Time,
) -> Result<Option<P::Value>, Error>
where
P: KvPrefix,
P::Key: std::fmt::Debug,
P::Value: Serialize,
{
self.put_if_not_exists_internal::<P>(key, value, Some(expiration))
.await
}

pub async fn put_if_not_exists<P>(
&self,
key: &P::Key,
value: P::Value,
) -> Result<Option<P::Value>, Error>
where
P: KvPrefix,
P::Key: std::fmt::Debug,
P::Value: Serialize,
{
self.put_if_not_exists_internal::<P>(key, value, None).await
}

#[tracing::instrument(skip_all, fields(key, prefix = std::any::type_name::<P>()))]
pub async fn only_cache_put<P>(&self, key: &P::Key, value: P::Value)
where
Expand Down
1 change: 1 addition & 0 deletions crates/daphne-service-utils/src/http_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
pub const HPKE_SIGNATURE: &str = "x-hpke-config-signature";
pub const DAP_AUTH_TOKEN: &str = "dap-auth-token";
pub const DAP_TASKPROV: &str = "dap-taskprov";
pub const STORAGE_PROXY_PUT_KV_EXPIRATION: &str = "x-daphne-storage-proxy-kv-put-expiration";
32 changes: 27 additions & 5 deletions crates/daphne-worker/src/storage_proxy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ mod metrics;

use std::{sync::OnceLock, time::Duration};

pub use self::metrics::Metrics;
use daphne::auth::BearerToken;
use daphne::messages::Time;
use daphne_service_utils::durable_requests::{
DurableRequest, ObjectIdFrom, DO_PATH_PREFIX, KV_PATH_PREFIX,
};
use daphne_service_utils::http_headers::STORAGE_PROXY_PUT_KV_EXPIRATION;
use prometheus::Registry;
use tracing::warn;
use url::Url;
use worker::{js_sys::Uint8Array, Delay, Env, Request, RequestInit, Response};

pub use self::metrics::Metrics;

const KV_BINDING_DAP_CONFIG: &str = "DAP_CONFIG";

struct RequestContext<'e> {
Expand Down Expand Up @@ -197,6 +197,17 @@ async fn storage_purge(ctx: &RequestContext<'_>) -> worker::Result<Response> {
Response::empty()
}

fn parse_expiration_header(ctx: &RequestContext) -> Result<Option<Time>, worker::Error> {
let expiration_header = ctx.req.headers().get(STORAGE_PROXY_PUT_KV_EXPIRATION)?;
expiration_header
.map(|expiration| {
expiration.parse::<Time>().map_err(|e| {
worker::Error::RustError(format!("Failed to parse expiration header: {e:?}"))
})
})
.transpose()
}

/// Handle a kv request.
async fn handle_kv_request(ctx: &mut RequestContext<'_>, key: &str) -> worker::Result<Response> {
match ctx.req.method() {
Expand All @@ -209,12 +220,17 @@ async fn handle_kv_request(ctx: &mut RequestContext<'_>, key: &str) -> worker::R
}
}
worker::Method::Post => {
let expiration_unix_timestamp = parse_expiration_header(ctx)?;

match ctx
.env
.kv(KV_BINDING_DAP_CONFIG)?
.put_bytes(key, &ctx.req.bytes().await?)
{
Ok(put) => {
Ok(mut put) => {
if let Some(expiration_unix_timestamp) = expiration_unix_timestamp {
put = put.expiration(expiration_unix_timestamp);
}
if let Err(error) = put.execute().await {
tracing::warn!(
?error,
Expand All @@ -232,7 +248,10 @@ async fn handle_kv_request(ctx: &mut RequestContext<'_>, key: &str) -> worker::R

Response::empty()
}

worker::Method::Put => {
let expiration_unix_timestamp = parse_expiration_header(ctx)?;

let kv = ctx.env.kv(KV_BINDING_DAP_CONFIG)?;
if kv
.list()
Expand All @@ -246,7 +265,10 @@ async fn handle_kv_request(ctx: &mut RequestContext<'_>, key: &str) -> worker::R
Response::error(String::new(), 409 /* Conflict */)
} else {
match kv.put_bytes(key, &ctx.req.bytes().await?) {
Ok(put) => {
Ok(mut put) => {
if let Some(expiration_unix_timestamp) = expiration_unix_timestamp {
put = put.expiration(expiration_unix_timestamp);
}
if let Err(error) = put.execute().await {
tracing::warn!(
?error,
Expand Down

0 comments on commit cd0f036

Please sign in to comment.