Skip to content

Commit

Permalink
feat: added EnforcerContext + enforce_with_context function
Browse files Browse the repository at this point in the history
  • Loading branch information
worapolw committed Nov 14, 2023
1 parent b429e4d commit aa85ca8
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 3 deletions.
12 changes: 10 additions & 2 deletions src/core_api.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
Adapter, Effector, EnforceArgs, Event, EventEmitter, Filter, Model, Result,
RoleManager, TryIntoAdapter, TryIntoModel,
enforcer::EnforceContext, Adapter, Effector, EnforceArgs, Event,
EventEmitter, Filter, Model, Result, RoleManager, TryIntoAdapter,
TryIntoModel,
};

#[cfg(feature = "watcher")]
Expand Down Expand Up @@ -64,6 +65,13 @@ pub trait CoreApi: Send + Sync {
Self: Sized;
fn set_effector(&mut self, e: Box<dyn Effector>);
fn enforce<ARGS: EnforceArgs>(&self, rvals: ARGS) -> Result<bool>
where
Self: Sized;
fn enforce_with_context<ARGS: EnforceArgs>(
&self,
ctx: EnforceContext,
rvals: ARGS,
) -> Result<bool>
where
Self: Sized;
fn enforce_mut<ARGS: EnforceArgs>(&mut self, rvals: ARGS) -> Result<bool>
Expand Down
218 changes: 217 additions & 1 deletion src/enforcer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
effector::{DefaultEffector, EffectKind, Effector},
emitter::{Event, EventData, EventEmitter},
error::{ModelError, PolicyError, RequestError},
get_or_err,
get_or_err, get_or_err_with_context,
management_api::MgmtApi,
model::{FunctionMap, Model},
rbac::{DefaultRoleManager, RoleManager},
Expand Down Expand Up @@ -74,6 +74,30 @@ pub struct Enforcer {
logger: Box<dyn Logger>,
}

pub struct EnforceContext {
pub r_type: String,
pub p_type: String,
pub e_type: String,
pub m_type: String,
}

impl EnforceContext {
pub fn new(suffix: &str) -> Self {
Self {
r_type: format!("r{}", suffix),
p_type: format!("p{}", suffix),
e_type: format!("e{}", suffix),
m_type: format!("m{}", suffix),
}
}
pub fn get_cache_key(&self) -> String {
format!(
"EnforceContext{{{}-{}-{}-{}}}",
&self.r_type, &self.p_type, &self.e_type, &self.m_type,
)
}
}

impl EventEmitter<Event> for Enforcer {
fn on(&mut self, e: Event, f: fn(&mut Self, EventData)) {
self.events.entry(e).or_insert_with(Vec::new).push(f)
Expand Down Expand Up @@ -198,6 +222,135 @@ impl Enforcer {
}))
}

pub(crate) fn private_enforce_with_context(
&self,
ctx: EnforceContext,
rvals: &[Dynamic],
) -> Result<(bool, Option<Vec<usize>>)> {
if !self.enabled {
return Ok((true, None));
}

let mut scope: Scope = Scope::new();
let r_ast = get_or_err_with_context!(
self,
"r",
ctx.r_type,
ModelError::R,
"request"
);
let p_ast = get_or_err_with_context!(
self,
"p",
ctx.p_type,
ModelError::P,
"policy"
);
let m_ast = get_or_err_with_context!(
self,
"m",
ctx.m_type,
ModelError::M,
"matcher"
);
let e_ast = get_or_err_with_context!(
self,
"e",
ctx.e_type,
ModelError::E,
"effector"
);

if r_ast.tokens.len() != rvals.len() {
return Err(RequestError::UnmatchRequestDefinition(
r_ast.tokens.len(),
rvals.len(),
)
.into());
}

for (rtoken, rval) in r_ast.tokens.iter().zip(rvals.iter()) {
scope.push_constant_dynamic(rtoken, rval.to_owned());
}

let policies = p_ast.get_policy();
let (policy_len, scope_len) = (policies.len(), scope.len());

let mut eft_stream =
self.eft.new_stream(&e_ast.value, max(policy_len, 1));
let m_ast_compiled = self
.engine
.compile_expression(&escape_eval(&m_ast.value))
.map_err(Into::<Box<EvalAltResult>>::into)?;

if policy_len == 0 {
for token in p_ast.tokens.iter() {
scope.push_constant(token, String::new());
}

let eval_result = self
.engine
.eval_ast_with_scope::<bool>(&mut scope, &m_ast_compiled)?;
let eft = if eval_result {
EffectKind::Allow
} else {
EffectKind::Indeterminate
};

eft_stream.push_effect(eft);

return Ok((eft_stream.next(), None));
}

for pvals in policies {
scope.rewind(scope_len);

if p_ast.tokens.len() != pvals.len() {
return Err(PolicyError::UnmatchPolicyDefinition(
p_ast.tokens.len(),
pvals.len(),
)
.into());
}
for (ptoken, pval) in p_ast.tokens.iter().zip(pvals.iter()) {
scope.push_constant(ptoken, pval.to_owned());
}

let eval_result = self
.engine
.eval_ast_with_scope::<bool>(&mut scope, &m_ast_compiled)?;
let eft = match p_ast.tokens.iter().position(|x| x == "p_eft") {
Some(j) if eval_result => {
let p_eft = &pvals[j];
if p_eft == "deny" {
EffectKind::Deny
} else if p_eft == "allow" {
EffectKind::Allow
} else {
EffectKind::Indeterminate
}
}
None if eval_result => EffectKind::Allow,
_ => EffectKind::Indeterminate,
};

if eft_stream.push_effect(eft) {
break;
}
}

Ok((eft_stream.next(), {
#[cfg(feature = "explain")]
{
eft_stream.explain()
}
#[cfg(not(feature = "explain"))]
{
None
}
}))
}

pub(crate) fn register_g_functions(&mut self) -> Result<()> {
if let Some(ast_map) = self.model.get_model().get("g") {
for (fname, ast) in ast_map {
Expand Down Expand Up @@ -426,6 +579,69 @@ impl CoreApi for Enforcer {

Ok(authorized)
}
/// Enforce decides whether a "subject" can access a "object" with the operation "action",
/// input parameters are usually: (sub, obj, act).
/// this function will add suffix to each model eg. r2, p2, e2, m2, g2,
///
/// # Examples
/// ```
/// use casbin::prelude::*;
/// #[cfg(feature = "runtime-async-std")]
/// #[async_std::main]
/// async fn main() -> Result<()> {
/// let mut e = Enforcer::new("examples/basic_model.conf", "examples/basic_policy.csv").await?;
/// assert_eq!(true, e.enforce_with_index(2, ("alice", "data1", "read"))?);
/// Ok(())
/// }
///
/// #[cfg(feature = "runtime-tokio")]
/// #[tokio::main]
/// async fn main() -> Result<()> {
/// let mut e = Enforcer::new("examples/basic_model.conf", "examples/basic_policy.csv").await?;
/// assert_eq!(true, e.enforce_with_index(2, ("alice", "data1", "read"))?);
///
/// Ok(())
/// }
/// #[cfg(all(not(feature = "runtime-async-std"), not(feature = "runtime-tokio")))]
/// fn main() {}
/// ```
fn enforce_with_context<ARGS: EnforceArgs>(
&self,
ctx: EnforceContext,
rvals: ARGS,
) -> Result<bool> {
let rvals = rvals.try_into_vec()?;
#[allow(unused_variables)]
let (authorized, indices) =
self.private_enforce_with_context(ctx, &rvals)?;

#[cfg(feature = "logging")]
{
self.logger.print_enforce_log(
rvals.iter().map(|x| x.to_string()).collect(),
authorized,
false,
);

#[cfg(feature = "explain")]
if let Some(indices) = indices {
let all_rules = get_or_err!(self, "p", ModelError::P, "policy")
.get_policy();

let rules: Vec<String> = indices
.into_iter()
.filter_map(|y| {
all_rules.iter().nth(y).map(|x| x.join(", "))
})
.collect();

self.logger.print_explain_log(rules);
}
}

Ok(authorized)
}

fn enforce_mut<ARGS: EnforceArgs>(&mut self, rvals: ARGS) -> Result<bool> {
self.enforce(rvals)
Expand Down
23 changes: 23 additions & 0 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,29 @@ macro_rules! get_or_err {
}};
}

#[macro_export]
macro_rules! get_or_err_with_context {
($this:ident, $key:expr, $ctx:expr, $err:expr, $msg:expr) => {{
$this
.get_model()
.get_model()
.get($key)
.ok_or_else(|| {
$crate::error::Error::from($err(format!(
"Missing {} definition in conf file",
$msg
)))
})?
.get(&format!("{}{}", $key, $ctx))
.ok_or_else(|| {
$crate::error::Error::from($err(format!(
"Missing {} section in conf file",
$msg
)))
})?
}};
}

#[macro_export]
macro_rules! register_g_function {
($enforcer:ident, $fname:ident, $ast:ident) => {{
Expand Down

0 comments on commit aa85ca8

Please sign in to comment.