Skip to content

Commit

Permalink
subscription: make ClientFilter optional instead of Operation
Browse files Browse the repository at this point in the history
  • Loading branch information
MrAnno committed Dec 5, 2024
1 parent bcc98ec commit 2ff60b9
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 138 deletions.
62 changes: 35 additions & 27 deletions cli/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use common::{
encoding::decode_utf16le,
settings::Settings,
subscription::{
ContentFormat, FilesConfiguration, KafkaConfiguration, ClientFilterOperation,
ContentFormat, FilesConfiguration, KafkaConfiguration, ClientFilter, ClientFilterOperation,
RedisConfiguration, SubscriptionData, SubscriptionMachineState, SubscriptionOutput,
SubscriptionOutputDriver, SubscriptionOutputFormat, TcpConfiguration,
UnixDatagramConfiguration,
Expand All @@ -20,7 +20,7 @@ use std::{
};
use uuid::Uuid;

use anyhow::{anyhow, bail, ensure, Context, Result};
use anyhow::{anyhow, bail, ensure, Context, Ok, Result};
use clap::ArgMatches;
use log::{debug, info, warn};
use std::io::Write;
Expand Down Expand Up @@ -595,30 +595,39 @@ async fn delete(db: &Db, matches: &ArgMatches) -> Result<()> {
}

async fn edit_filter(subscription: &mut SubscriptionData, matches: &ArgMatches) -> Result<()> {
let mut filter = subscription.client_filter().clone();
match matches.subcommand() {
Some(("set", matches)) => {
let op_str = matches
.get_one::<String>("operation")
.ok_or_else(|| anyhow!("Missing operation argument"))?;
if let Some(("set", matches)) = matches.subcommand() {
let op_str = matches
.get_one::<String>("operation")
.ok_or_else(|| anyhow!("Missing operation argument"))?;

let op_opt = ClientFilterOperation::opt_from_str(op_str)?;
filter.set_operation(op_opt.clone());
if op_str.eq_ignore_ascii_case("none") {
subscription.set_client_filter(None);
return Ok(());
}

if let Some(op) = op_opt {
let mut princs = HashSet::new();
if let Some(identifiers) = matches.get_many::<String>("principals") {
for identifier in identifiers {
princs.insert(identifier.clone());
}
}
if op == ClientFilterOperation::Only && princs.is_empty() {
warn!("'{}' filter has been set without principals making this subscription apply to nothing.", op)
}
filter.set_targets(princs)?;
let op = op_str.parse()?;

let mut princs = HashSet::new();
if let Some(identifiers) = matches.get_many::<String>("principals") {
for identifier in identifiers {
princs.insert(identifier.clone());
}
}
Some(("princs", matches)) => match matches.subcommand() {
if op == ClientFilterOperation::Only && princs.is_empty() {
warn!("'{}' filter has been set without principals making this subscription apply to nothing.", op)
}

subscription.set_client_filter(Some(ClientFilter::new(op, princs)));
return Ok(());
}

if let Some(("princs", matches)) = matches.subcommand() {
let filter = subscription.client_filter().cloned();
let Some(mut filter) = filter else {
bail!("No filter is set");
};

match matches.subcommand() {
Some(("add", matches)) => {
filter.add_target(
matches
Expand Down Expand Up @@ -648,14 +657,13 @@ async fn edit_filter(subscription: &mut SubscriptionData, matches: &ArgMatches)
_ => {
bail!("Nothing to do");
}
},
_ => {
bail!("Nothing to do");
}

subscription.set_client_filter(Some(filter));
return Ok(())
}
subscription.set_client_filter(filter);

Ok(())
bail!("Nothing to do");
}
async fn outputs(subscription: &mut SubscriptionData, matches: &ArgMatches) -> Result<()> {
info!(
Expand Down
37 changes: 17 additions & 20 deletions common/src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub mod tests {
assert_eq!(toto.read_existing_events(), DEFAULT_READ_EXISTING_EVENTS);
assert_eq!(toto.content_format(), &DEFAULT_CONTENT_FORMAT);
assert_eq!(toto.ignore_channel_error(), DEFAULT_IGNORE_CHANNEL_ERROR);
assert_eq!(toto.client_filter().operation(), None);
assert_eq!(toto.client_filter(), None);
assert_eq!(toto.is_active(), false);
assert_eq!(toto.is_active_for("couscous"), false);
assert_eq!(toto.revision(), None);
Expand All @@ -194,10 +194,10 @@ pub mod tests {
.set_read_existing_events(true)
.set_content_format(ContentFormat::RenderedText)
.set_ignore_channel_error(false)
.set_client_filter(ClientFilter::from(
Some("Only".to_string()),
.set_client_filter(Some(ClientFilter::from(
"Only".to_string(),
Some("couscous,boulette".to_string()),
)?)
)?))
.set_outputs(vec![
SubscriptionOutput::new(
SubscriptionOutputFormat::Json,
Expand Down Expand Up @@ -227,11 +227,11 @@ pub mod tests {
assert_eq!(tata.content_format(), &ContentFormat::RenderedText);
assert_eq!(tata.ignore_channel_error(), false);
assert_eq!(
*tata.client_filter().operation().unwrap(),
*tata.client_filter().unwrap().operation(),
ClientFilterOperation::Only
);
assert_eq!(
tata.client_filter().targets(),
tata.client_filter().unwrap().targets(),
&HashSet::from(["couscous".to_string(), "boulette".to_string()])
);

Expand Down Expand Up @@ -272,8 +272,8 @@ pub mod tests {
.set_ignore_channel_error(true)
.set_revision(Some("1890".to_string()))
.set_data_locale(Some("fr-FR".to_string()));
let mut new_client_filter = tata.client_filter().clone();
new_client_filter.add_target("semoule")?;
let mut new_client_filter = tata.client_filter().cloned();
new_client_filter.as_mut().unwrap().add_target("semoule")?;
tata.set_client_filter(new_client_filter);

db.store_subscription(&tata).await?;
Expand All @@ -293,11 +293,11 @@ pub mod tests {
assert_eq!(tata2.content_format(), &ContentFormat::Raw);
assert_eq!(tata2.ignore_channel_error(), true);
assert_eq!(
*tata2.client_filter().operation().unwrap(),
*tata2.client_filter().unwrap().operation(),
ClientFilterOperation::Only
);
assert_eq!(
tata2.client_filter().targets(),
tata2.client_filter().unwrap().targets(),
&HashSet::from([
"couscous".to_string(),
"boulette".to_string(),
Expand All @@ -312,9 +312,9 @@ pub mod tests {

assert!(tata2.public_version()? != tata_save.public_version()?);

let mut new_client_filter = tata2.client_filter().clone();
new_client_filter.delete_target("couscous")?;
new_client_filter.set_operation(Some(ClientFilterOperation::Except));
let mut new_client_filter = tata2.client_filter().cloned();
new_client_filter.as_mut().unwrap().delete_target("couscous")?;
new_client_filter.as_mut().unwrap().set_operation(ClientFilterOperation::Except);
tata2.set_client_filter(new_client_filter);

db.store_subscription(&tata2).await?;
Expand All @@ -324,30 +324,27 @@ pub mod tests {
.await?
.unwrap();
assert_eq!(
*tata2_clone.client_filter().operation().unwrap(),
*tata2_clone.client_filter().unwrap().operation(),
ClientFilterOperation::Except
);
assert_eq!(
tata2_clone.client_filter().targets(),
tata2_clone.client_filter().unwrap().targets(),
&HashSet::from(["boulette".to_string(), "semoule".to_string()])
);

assert_eq!(tata2_clone.is_active_for("couscous"), true);
assert_eq!(tata2_clone.is_active_for("semoule"), false);
assert_eq!(tata2_clone.is_active_for("boulette"), false);

let mut new_client_filter = tata2_clone.client_filter().clone();
new_client_filter.set_operation(None);
tata2_clone.set_client_filter(new_client_filter);
tata2_clone.set_client_filter(None);

db.store_subscription(&tata2_clone).await?;

let tata2_clone_clone = db
.get_subscription_by_identifier(&tata.uuid_string())
.await?
.unwrap();
assert_eq!(tata2_clone_clone.client_filter().operation(), None);
assert_eq!(tata2_clone_clone.client_filter().targets(), &HashSet::new());
assert_eq!(tata2_clone_clone.client_filter(), None);
assert_eq!(tata2_clone_clone.is_active_for("couscous"), true);
assert_eq!(tata2_clone_clone.is_active_for("semoule"), true);
assert_eq!(tata2_clone_clone.is_active_for("boulette"), true);
Expand Down
20 changes: 11 additions & 9 deletions common/src/database/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,12 @@ fn row_to_subscription(row: &Row) -> Result<SubscriptionData> {
let max_time: i32 = row.try_get("max_time")?;
let max_elements: Option<i32> = row.try_get("max_elements")?;

let client_filter = ClientFilter::from(
row.try_get("princs_filter_op")?,
row.try_get("princs_filter_value")?,
)?;
let client_filter_op: Option<String> = row.try_get("princs_filter_op")?;

let client_filter = match client_filter_op {
Some(op) => Some(ClientFilter::from(op, row.try_get("princs_filter_value")?)?),
None => None
};

let mut subscription = SubscriptionData::new(row.try_get("name")?, row.try_get("query")?);
subscription
Expand Down Expand Up @@ -625,6 +627,9 @@ impl Database for PostgresDatabase {
};

let max_envelope_size: i32 = subscription.max_envelope_size().try_into()?;
let client_filter_op: Option<String> = subscription.client_filter().map(|f| f.operation().to_string());
let client_filter_value = subscription.client_filter().and_then(|f| f.targets_to_opt_string());

let count = self
.pool
.get()
Expand Down Expand Up @@ -674,11 +679,8 @@ impl Database for PostgresDatabase {
&subscription.read_existing_events(),
&subscription.content_format().to_string(),
&subscription.ignore_channel_error(),
&subscription
.client_filter()
.operation()
.map(|x| x.to_string()),
&subscription.client_filter().targets_to_opt_string(),
&client_filter_op,
&client_filter_value,
&serde_json::to_string(subscription.outputs())?.as_str(),
&subscription.locale(),
&subscription.data_locale()
Expand Down
15 changes: 12 additions & 3 deletions common/src/database/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,13 @@ fn row_to_subscription(row: &Row) -> Result<SubscriptionData> {
let query: String = row.get("query")?;

let content_format = ContentFormat::from_str(row.get::<&str, String>("content_format")?.as_ref())?;
let client_filter = ClientFilter::from(row.get("princs_filter_op")?, row.get("princs_filter_value")?)?;

let client_filter_op: Option<String> = row.get("princs_filter_op")?;

let client_filter = match client_filter_op {
Some(op) => Some(ClientFilter::from(op, row.get("princs_filter_value")?)?),
None => None
};

let mut subscription= SubscriptionData::new(&name, &query);
subscription.set_uuid(SubscriptionUuid(Uuid::parse_str(&uuid)?))
Expand Down Expand Up @@ -546,6 +552,9 @@ impl Database for SQLiteDatabase {

async fn store_subscription(&self, subscription: &SubscriptionData) -> Result<()> {
let subscription = subscription.clone();
let client_filter_op: Option<String> = subscription.client_filter().map(|f| f.operation().to_string());
let client_filter_value = subscription.client_filter().and_then(|f| f.targets_to_opt_string());

let count = self
.pool
.get()
Expand Down Expand Up @@ -600,8 +609,8 @@ impl Database for SQLiteDatabase {
":read_existing_events": subscription.read_existing_events(),
":content_format": subscription.content_format().to_string(),
":ignore_channel_error": subscription.ignore_channel_error(),
":princs_filter_op": subscription.client_filter().operation().map(|x| x.to_string()),
":princs_filter_value": subscription.client_filter().targets_to_opt_string(),
":princs_filter_op": client_filter_op,
":princs_filter_value": client_filter_value,
":outputs": serde_json::to_string(subscription.outputs())?,
":locale": subscription.locale(),
":data_locale": subscription.data_locale(),
Expand Down
16 changes: 5 additions & 11 deletions common/src/models/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl From<ClientFilterOperation> for crate::subscription::ClientFilterOperation
#[derive(Debug, Clone, Eq, PartialEq, Deserialize)]
#[serde(deny_unknown_fields)]
struct ClientFilter {
pub operation: Option<ClientFilterOperation>,
pub operation: ClientFilterOperation,
#[serde(alias = "cert_subjects", alias = "princs")]
pub targets: HashSet<String>,
}
Expand All @@ -186,11 +186,7 @@ impl TryFrom<ClientFilter> for crate::subscription::ClientFilter {
type Error = anyhow::Error;

fn try_from(value: ClientFilter) -> std::prelude::v1::Result<Self, Self::Error> {
let mut filter = crate::subscription::ClientFilter::empty();
let operation = value.operation.map(|op| op.into());
filter.set_operation(operation);
filter.set_targets(value.targets)?;
Ok(filter)
Ok(crate::subscription::ClientFilter::new(value.operation.into(), value.targets))
}
}

Expand Down Expand Up @@ -296,7 +292,7 @@ impl TryFrom<Subscription> for crate::subscription::SubscriptionData {
data.set_name(subscription.name.clone());
data.set_query(subscription.query.clone());
if let Some(filter) = subscription.filter {
data.set_client_filter(filter.try_into()?);
data.set_client_filter(Some(filter.try_into()?));
}

if subscription.outputs.is_empty() {
Expand Down Expand Up @@ -507,14 +503,12 @@ path = "/whatever/you/{ip}/want/{principal}/{ip:2}/{node}/end"

expected.set_outputs(outputs);

let mut filter = crate::subscription::ClientFilter::empty();
filter.set_operation(Some(crate::subscription::ClientFilterOperation::Only));
let mut targets = HashSet::new();
targets.insert("toto@windomain.local".to_string());
targets.insert("tutu@windomain.local".to_string());
filter.set_targets(targets)?;
let filter = crate::subscription::ClientFilter::new(crate::subscription::ClientFilterOperation::Only, targets);

expected.set_client_filter(filter);
expected.set_client_filter(Some(filter));

// The only difference between both subscriptions should be the
// internal version, so we set both the same value
Expand Down
Loading

0 comments on commit 2ff60b9

Please sign in to comment.