Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Various fixes #57

Merged
merged 6 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/rust-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ jobs:
sh up.sh
- name: Build
run: cargo build --verbose
- name: Clippy
run: cargo clippy -- -Dwarnings
- name: Run tests
run: |
# Env var to connect to the Postgres docker
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]

use std::env;

use common::{database::schema::Version, settings::DEFAULT_CONFIG_FILE};
Expand Down
7 changes: 3 additions & 4 deletions cli/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ pub fn confirm(message: &str) -> bool {
print!("{} [y/n] ", message);
io::stdout().flush().unwrap();
let mut input = String::new();
match io::stdin().read_line(&mut input) {
Ok(n) if n == 2 => return input.to_ascii_lowercase().trim() == "y",
_ => (),
};
if let Ok(2) = io::stdin().read_line(&mut input) {
return input.to_ascii_lowercase().trim() == "y";
}
}
false
}
Expand Down
14 changes: 8 additions & 6 deletions common/src/database/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@ pub struct PostgresDatabase {

impl PostgresDatabase {
pub async fn new(settings: &Postgres) -> Result<PostgresDatabase> {
let mut config = Config::default();
config.host = Some(settings.host().to_string());
config.port = Some(settings.port());
config.user = Some(settings.user().to_string());
config.password = Some(settings.password().to_string());
config.dbname = Some(settings.dbname().to_string());
let mut config = Config {
host: Some(settings.host().to_string()),
port: Some(settings.port()),
user: Some(settings.user().to_string()),
password: Some(settings.password().to_string()),
dbname: Some(settings.dbname().to_string()),
..Default::default()
};

let pool = if *settings.ssl_mode() == PostgresSslMode::Disable {
config
Expand Down
2 changes: 2 additions & 0 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]

pub mod bookmark;
pub mod database;
pub mod encoding;
Expand Down
9 changes: 2 additions & 7 deletions common/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,19 +98,14 @@ impl SQLite {
}
}

#[derive(Debug, Deserialize, Clone, PartialEq, Eq)]
#[derive(Debug, Deserialize, Clone, PartialEq, Eq, Default)]
pub enum PostgresSslMode {
Disable,
#[default]
Prefer,
Require,
}

impl Default for PostgresSslMode {
fn default() -> Self {
PostgresSslMode::Prefer
}
}

#[derive(Debug, Deserialize, Clone)]
pub struct Postgres {
host: String,
Expand Down
1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ hex = "0.4.3"
redis = { version = "0.23.3", features = ["tokio-comp", "aio"]}
log4rs = "1.2.0"
log-mdc = "0.1.0"
tokio-util = "0.7.10"
24 changes: 6 additions & 18 deletions server/src/heartbeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@ use common::{
heartbeat::{HeartbeatKey, HeartbeatValue, HeartbeatsCache},
};
use log::{debug, error, info};
use tokio::{
select,
sync::{mpsc, oneshot},
time,
};
use tokio::{select, sync::mpsc, time};
use tokio_util::sync::CancellationToken;

pub async fn store_heartbeat(
heartbeat_tx: mpsc::Sender<WriteHeartbeatMessage>,
Expand Down Expand Up @@ -62,7 +59,7 @@ pub async fn heartbeat_task(
db: Db,
interval: u64,
mut task_rx: mpsc::Receiver<WriteHeartbeatMessage>,
mut task_exit_rx: oneshot::Receiver<oneshot::Sender<()>>,
cancellation_token: CancellationToken,
) {
info!("Heartbeat task started");
let mut interval = time::interval(Duration::from_secs(interval));
Expand Down Expand Up @@ -117,26 +114,17 @@ pub async fn heartbeat_task(
info!("Heartbeat cache flushed and cleared");
}
},
sender = &mut task_exit_rx => {
_ = cancellation_token.cancelled() => {
info!("Heartbeat task received killed signal");
if !heartbeats.is_empty() {
info!("Flush heartbeat cache before killing the task");
if let Err(e) = db.store_heartbeats(&heartbeats).await {
error!("Could not store heartbeats in database: {:?}", e);
}
}

match sender {
Ok(sender) => {
if let Err(e) = sender.send(()) {
error!("Failed to respond to kill order: {:?}", e);
}
},
Err(e) => {
error!("Could not respond to kill order: {:?}", e);
}
}
break;
}
}
}
info!("Heartbeat task exited");
}
47 changes: 17 additions & 30 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(clippy::too_many_arguments)]

mod event;
mod formatter;
mod heartbeat;
Expand Down Expand Up @@ -50,9 +52,10 @@ use std::{env, mem};
use subscription::{reload_subscriptions_task, Subscriptions};
use tls_listener::TlsListener;
use tokio::signal::unix::SignalKind;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::mpsc;
use tokio_rustls::server::TlsStream;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;

use crate::logging::ACCESS_LOGGER;
use crate::tls::{make_config, subject_from_cert};
Expand Down Expand Up @@ -651,16 +654,9 @@ pub async fn run(settings: Settings, verbosity: u8) {
let interval = settings.server().db_sync_interval();
let update_task_db = db.clone();
let update_task_subscriptions = subscriptions.clone();
let (update_task_subscription_exit_tx, update_task_subscription_exit_rx) = oneshot::channel();
// Launch a task responsible for updating subscriptions
tokio::spawn(async move {
reload_subscriptions_task(
update_task_db,
update_task_subscriptions,
interval,
update_task_subscription_exit_rx,
)
.await
reload_subscriptions_task(update_task_db, update_task_subscriptions, interval).await
});

// To reduce database load, heartbeats are not saved immediately.
Expand All @@ -674,11 +670,15 @@ pub async fn run(settings: Settings, verbosity: u8) {
let update_task_db = db.clone();
let (heartbeat_tx, heartbeat_rx) =
mpsc::channel(settings.server().heartbeats_queue_size() as usize);
let (heartbeat_exit_tx, heartbeat_exit_rx) = oneshot::channel();

// We use a CancellationToken to tell the task to shutdown, so
// that it is able to store cached heartbeats.
let heartbeat_ct = CancellationToken::new();
let cloned_heartbaat_ct = heartbeat_ct.clone();

// Launch the task responsible for managing heartbeats
tokio::spawn(async move {
heartbeat_task(update_task_db, interval, heartbeat_rx, heartbeat_exit_rx).await
let heartbeat_task = tokio::spawn(async move {
heartbeat_task(update_task_db, interval, heartbeat_rx, cloned_heartbaat_ct).await
});

// Set KRB5_KTNAME env variable if necessary (i.e. if at least one collector uses
Expand Down Expand Up @@ -752,23 +752,10 @@ pub async fn run(settings: Settings, verbosity: u8) {

info!("HTTP server has been shutdown.");

let (task_ended_tx, task_ended_rx) = oneshot::channel();
if let Err(e) = heartbeat_exit_tx.send(task_ended_tx) {
error!("Failed to shutdown heartbeat task: {:?}", e);
};
if let Err(e) = task_ended_rx.await {
error!("Failed to wait for heartbeat task shutdown: {:?}", e);
}

info!("Heartbeat task has been terminated.");

let (task_ended_tx, task_ended_rx) = oneshot::channel();
if let Err(e) = update_task_subscription_exit_tx.send(task_ended_tx) {
error!("Failed to shutdown update subscription task: {:?}", e);
// Signal the task that we want to shutdown
heartbeat_ct.cancel();
// Wait for the task to shutdown gracefully
if let Err(e) = heartbeat_task.await {
error!("Failed to wait for heartbeat task to shutdown: {:?}", e)
}
if let Err(e) = task_ended_rx.await {
error!("Failed to wait for heartbeat task shutdown: {:?}", e);
}

info!("Subscription update task has been terminated.");
}
12 changes: 6 additions & 6 deletions server/src/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use common::{
settings::{Collector, Server},
};
use http::status::StatusCode;
use log::{debug, error, info, warn};
use log::{debug, error, warn};
use std::{collections::HashMap, sync::Arc};
use tokio::{sync::mpsc, task::JoinSet};

Expand Down Expand Up @@ -70,7 +70,7 @@ async fn handle_enumerate(
}
};

info!(
debug!(
"Received Enumerate request from {}:{} ({}) with URI {}",
request_data.remote_addr.ip(),
request_data.remote_addr.port(),
Expand Down Expand Up @@ -261,7 +261,7 @@ async fn handle_heartbeat(
};

if !subscription.data().is_active_for(request_data.principal()) {
info!(
debug!(
"Received Heartbeat from {}:{} ({}) for subscription {} ({}) but the principal is not allowed to use the subscription.",
request_data.remote_addr().ip(),
request_data.remote_addr().port(),
Expand All @@ -272,7 +272,7 @@ async fn handle_heartbeat(
return Ok(Response::err(StatusCode::FORBIDDEN));
}

info!(
debug!(
"Received Heartbeat from {}:{} ({:?}) for subscription {} ({})",
request_data.remote_addr().ip(),
request_data.remote_addr().port(),
Expand Down Expand Up @@ -324,7 +324,7 @@ async fn handle_events(
};

if !subscription.data().is_active_for(request_data.principal()) {
info!(
debug!(
"Received Events from {}:{} ({}) for subscription {} ({}) but the principal is not allowed to use this subscription.",
request_data.remote_addr().ip(),
request_data.remote_addr().port(),
Expand All @@ -335,7 +335,7 @@ async fn handle_events(
return Ok(Response::err(StatusCode::FORBIDDEN));
}

info!(
debug!(
"Received Events from {}:{} ({}) for subscription {} ({})",
request_data.remote_addr().ip(),
request_data.remote_addr().port(),
Expand Down
Loading