Skip to content

Commit

Permalink
Run handle_payload in an independent tokio Task
Browse files Browse the repository at this point in the history
  • Loading branch information
vruello committed Nov 22, 2023
1 parent b768aed commit b37e4f1
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 51 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Ensure that openwecd shuts down gracefully even if hyper server is not responding
- Improve the logging of failed Kerberos authentications: missing authorization header warning is now in DEBUG level

### Fixed

- Fixed an issue that could result in an inconsistent state when a client unexpectedly closes an HTTP connection.

## [0.1.0] - 2023-05-30

Initial commit containing most of the desired features. The project is still under heavy development and production use without a backup solution should be avoided.
4 changes: 3 additions & 1 deletion openwec.conf.sample.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@
# - ip
# - port
# - principal
# - conn_status: 'X' (connection aborted before the response completed)
# '+' (connection may be kept alive after the response is sent)
# Default value is None, meaning "{X(ip)}:{X(port)} - {X(principal)} [{d}] \"{X(http_uri)}\" {X(http_status)} {X(response_time)}{n}"
# access_logs_pattern = None

Expand Down Expand Up @@ -239,4 +241,4 @@
# Server private key, corresponding to the certificate
# server_private_key = "/etc/server-key.pem"

## End of TLS configuration
## End of TLS configuration
215 changes: 172 additions & 43 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use std::{env, mem};
use subscription::{reload_subscriptions_task, Subscriptions};
use tls_listener::TlsListener;
use tokio::signal::unix::SignalKind;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, oneshot};
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
Expand Down Expand Up @@ -273,7 +273,7 @@ async fn handle_payload(
db: Db,
subscriptions: Subscriptions,
heartbeat_tx: mpsc::Sender<WriteHeartbeatMessage>,
request_data: RequestData,
request_data: &RequestData,
request_payload: Option<String>,
auth_ctx: &AuthenticationContext,
) -> Result<(StatusCode, Option<String>)> {
Expand Down Expand Up @@ -313,15 +313,35 @@ async fn handle_payload(
}
}

/// State of the client connection at the time a response is logged.
///
/// The textual form is inspired by `%X` of Apache httpd:
/// <https://httpd.apache.org/docs/current/mod/mod_log_config.html>
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionStatus {
    /// Connection aborted before the response completed.
    Aborted,
    /// Connection may be kept alive after the response is sent.
    Alive,
}

impl ConnectionStatus {
    /// Returns the single-character marker for this status:
    /// `"X"` for [`ConnectionStatus::Aborted`], `"+"` for [`ConnectionStatus::Alive`].
    ///
    /// The markers are string literals, so the returned slice is `'static`
    /// and is not tied to the lifetime of `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Aborted => "X",
            Self::Alive => "+",
        }
    }
}

fn log_response(
addr: &SocketAddr,
method: &str,
uri: &str,
start: &Instant,
status: StatusCode,
principal: &str,
conn_status: ConnectionStatus,
) {
let duration: f32 = start.elapsed().as_micros() as f32;

// MDC is thread related, so it should be safe to use it in a non-async
// function.
log_mdc::insert("http_status", status.as_str());
Expand All @@ -331,12 +351,20 @@ fn log_response(
log_mdc::insert("ip", addr.ip().to_string());
log_mdc::insert("port", addr.port().to_string());
log_mdc::insert("principal", principal);
log_mdc::insert("conn_status", conn_status.as_str());

// Empty message, logging pattern should use MDC
info!(target: ACCESS_LOGGER, "");
log_mdc::clear();
}

/// Builds an HTTP response that carries only `status` and an empty body.
///
/// Panics with "Failed to build HTTP response" if the response builder
/// reports an error.
fn build_error_response(status: StatusCode) -> Response<Body> {
    let empty_body = Body::empty();
    let built = Response::builder().status(status).body(empty_body);
    built.expect("Failed to build HTTP response")
}

async fn handle(
server: ServerSettings,
collector: Collector,
Expand Down Expand Up @@ -365,11 +393,16 @@ async fn handle(
Ok((principal, builder)) => (principal, builder),
Err(_) => {
let status = StatusCode::UNAUTHORIZED;
log_response(&addr, &method, &uri, &start, status, "-");
return Ok(Response::builder()
.status(status)
.body(Body::empty())
.expect("Failed to build HTTP response"));
log_response(
&addr,
&method,
&uri,
&start,
status,
"-",
ConnectionStatus::Alive,
);
return Ok(build_error_response(status));
}
};

Expand All @@ -380,11 +413,16 @@ async fn handle(
Err(e) => {
error!("Failed to compute request data: {:?}", e);
let status = StatusCode::NOT_FOUND;
log_response(&addr, &method, &uri, &start, status, &principal);
return Ok(Response::builder()
.status(status)
.body(Body::empty())
.expect("Failed to build HTTP response"));
log_response(
&addr,
&method,
&uri,
&start,
status,
&principal,
ConnectionStatus::Alive,
);
return Ok(build_error_response(status));
}
};

Expand All @@ -394,11 +432,16 @@ async fn handle(
Err(e) => {
error!("Failed to retrieve request payload: {:?}", e);
let status = StatusCode::BAD_REQUEST;
log_response(&addr, &method, &uri, &start, status, &principal);
return Ok(Response::builder()
.status(status)
.body(Body::empty())
.expect("Failed to build HTTP response"));
log_response(
&addr,
&method,
&uri,
&start,
status,
&principal,
ConnectionStatus::Alive,
);
return Ok(build_error_response(status));
}
};

Expand All @@ -408,27 +451,100 @@ async fn handle(
);

// Handle request payload, and retrieves response payload
let (status, response_payload) = match handle_payload(
&server,
&collector,
db,
subscriptions,
heartbeat_tx,
request_data,
request_payload,
&auth_ctx,
)
.await
{
Ok((status, response_payload)) => (status, response_payload),
Err(e) => {
error!("Failed to compute a response payload to request: {:?}", e);
//
// It seems that Hyper can abort the Service future at any time (for example if the client
// closes the connection), meaning that any ".await" can be a dead end.
// We want to ensure that the payload handling cannot be aborted unexpectedly resulting
// in an inconsistent state.
// To achieve that, the handle_payload function is executed in an independent Tokio task.
//
// In practice, Windows clients appear to close connections to their configured WEC server
// when they (re-)apply group policies.

// handle_payload task result will be returned using a oneshot channel
let (tx, rx) = oneshot::channel();

// The following variables need to be cloned because they are moved in the spawned closure
let auth_ctx_cloned = auth_ctx.clone();
let method_cloned = method.clone();
let uri_cloned = uri.clone();
let principal_cloned = principal.clone();

tokio::spawn(async move {
let res = handle_payload(
&server,
&collector,
db,
subscriptions,
heartbeat_tx,
&request_data,
request_payload,
&auth_ctx_cloned,
)
.await;
if let Err(e) = &res {
error!(
"Failed to compute a response payload to request (from {}:{}): {:?}",
request_data.remote_addr().ip(),
request_data.remote_addr().port(),
e
);
}
if let Err(value) = tx.send(res) {
debug!(
"Could not send handle_payload result to handling Service for {}:{} (receiver dropped). Result was: {:?}",
request_data.remote_addr().ip(),
request_data.remote_addr().port(),
value
);
// Log this response with conn_status = Aborted
let status = match value {
Ok((status, _)) => status,
Err(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
log_response(
request_data.remote_addr(),
&method_cloned,
&uri_cloned,
&start,
status,
&principal_cloned,
ConnectionStatus::Aborted,
);
}
});

// Wait for the handle_payload task to answer using the oneshot channel
let (status, response_payload) = match rx.await {
Ok(Ok((status, response_payload))) => (status, response_payload),
Ok(Err(_)) => {
// Ok(Err(_)): the handle_payload task returned an Err
let status = StatusCode::INTERNAL_SERVER_ERROR;
log_response(&addr, &method, &uri, &start, status, &principal);
return Ok(Response::builder()
.status(status)
.body(Body::empty())
.expect("Failed to build HTTP response"));
log_response(
&addr,
&method,
&uri,
&start,
status,
&principal,
ConnectionStatus::Alive,
);
return Ok(build_error_response(status));
}
Err(_) => {
// Err(_): the handle_payload task "sender" has been dropped (should not happen)
error!("handle_payload task sender has been dropped. Maybe the task panicked?");
let status = StatusCode::INTERNAL_SERVER_ERROR;
log_response(
&addr,
&method,
&uri,
&start,
status,
&principal,
ConnectionStatus::Alive,
);
return Ok(build_error_response(status));
}
};

Expand All @@ -445,15 +561,28 @@ async fn handle(
Err(e) => {
error!("Failed to build HTTP response: {:?}", e);
let status = StatusCode::INTERNAL_SERVER_ERROR;
log_response(&addr, &method, &uri, &start, status, &principal);
return Ok(Response::builder()
.status(status)
.body(Body::empty())
.expect("Failed to build HTTP response"));
log_response(
&addr,
&method,
&uri,
&start,
status,
&principal,
ConnectionStatus::Alive,
);
return Ok(build_error_response(status));
}
};

log_response(&addr, &method, &uri, &start, response.status(), &principal);
log_response(
&addr,
&method,
&uri,
&start,
response.status(),
&principal,
ConnectionStatus::Alive,
);
Ok(response)
}

Expand Down
8 changes: 4 additions & 4 deletions server/src/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,21 +441,21 @@ pub async fn handle_message(
db: Db,
subscriptions: Subscriptions,
heartbeat_tx: mpsc::Sender<WriteHeartbeatMessage>,
request_data: RequestData,
request_data: &RequestData,
message: &Message,
auth_ctx: &AuthenticationContext,
) -> Result<Response> {
let action = message.action()?;
debug!("Received {} request", action);

if action == ACTION_ENUMERATE {
handle_enumerate(collector, &db, subscriptions, &request_data, auth_ctx)
handle_enumerate(collector, &db, subscriptions, request_data, auth_ctx)
.await
.context("Failed to handle Enumerate action")
} else if action == ACTION_END || action == ACTION_SUBSCRIPTION_END {
Ok(Response::err(StatusCode::OK))
} else if action == ACTION_HEARTBEAT {
handle_heartbeat(subscriptions, heartbeat_tx, &request_data, message)
handle_heartbeat(subscriptions, heartbeat_tx, request_data, message)
.await
.context("Failed to handle Heartbeat action")
} else if action == ACTION_EVENTS {
Expand All @@ -464,7 +464,7 @@ pub async fn handle_message(
&db,
subscriptions,
heartbeat_tx,
&request_data,
request_data,
message,
)
.await
Expand Down
6 changes: 3 additions & 3 deletions server/src/outputs/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub struct WriteFileMessage {

async fn handle_message(
file_handles: &mut HashMap<PathBuf, File>,
message: &mut WriteFileMessage,
message: &WriteFileMessage,
) -> Result<()> {
let parent = message
.path
Expand Down Expand Up @@ -67,8 +67,8 @@ pub async fn run(mut task_rx: mpsc::Receiver<WriteFileMessage>, task_ct: Cancell
let mut file_handles: HashMap<PathBuf, File> = HashMap::new();
loop {
tokio::select! {
Some(mut message) = task_rx.recv() => {
let result = handle_message(&mut file_handles, &mut message).await;
Some(message) = task_rx.recv() => {
let result = handle_message(&mut file_handles, &message).await;
if let Err(e) = message
.resp
.send(result) {
Expand Down

0 comments on commit b37e4f1

Please sign in to comment.