diff --git a/src/convert.rs b/src/convert.rs index ab39cd20..624c0800 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -17,6 +17,10 @@ use std::{ pub trait TryIntoModel: Send + Sync { async fn try_into_model(self) -> Result>; } +#[async_trait] +pub trait TryIntoModelFromStr:Send + Sync { + async fn try_into_model_from_str(self) -> Result>; +} #[async_trait] pub trait TryIntoAdapter: Send + Sync { @@ -36,6 +40,12 @@ impl TryIntoModel for &'static str { } } } +#[async_trait] +impl TryIntoModelFromStr for &'static str { + async fn try_into_model_from_str(self)->Result>{ + Ok(Box::new(DefaultModel::from_str(self).await?)) + } +} #[async_trait] impl TryIntoModel for Option diff --git a/src/core_api.rs b/src/core_api.rs index 352bf432..a9e29bf1 100644 --- a/src/core_api.rs +++ b/src/core_api.rs @@ -1,9 +1,8 @@ use crate::{ enforcer::EnforceContext, model::OperatorFunction, Adapter, Effector, EnforceArgs, Event, EventEmitter, Filter, Model, Result, RoleManager, - TryIntoAdapter, TryIntoModel, + TryIntoAdapter, TryIntoModel,convert::TryIntoModelFromStr }; - #[cfg(feature = "watcher")] use crate::Watcher; @@ -24,12 +23,24 @@ pub trait CoreApi: Send + Sync { m: M, a: A, ) -> Result + where + Self: Sized; + async fn new_raw_from_str( + m: M, + a: A, + ) -> Result where Self: Sized; async fn new( m: M, a: A, ) -> Result + where + Self: Sized; + async fn new_from_str( + m: M, + a: A, + ) -> Result where Self: Sized; fn add_function(&mut self, fname: &str, f: OperatorFunction); diff --git a/src/enforcer.rs b/src/enforcer.rs index 7f87eae1..e4d2c443 100644 --- a/src/enforcer.rs +++ b/src/enforcer.rs @@ -1,6 +1,6 @@ use crate::{ adapter::{Adapter, Filter}, - convert::{EnforceArgs, TryIntoAdapter, TryIntoModel}, + convert::{EnforceArgs, TryIntoAdapter, TryIntoModel, TryIntoModelFromStr}, core_api::CoreApi, effector::{DefaultEffector, EffectKind, Effector}, emitter::{Event, EventData, EventEmitter}, @@ -426,6 +426,50 @@ impl CoreApi for Enforcer { Ok(e) } + async fn new_raw_from_str( + m: M, + a: A, + ) -> Result { + let model = m.try_into_model_from_str().await?; + let adapter = a.try_into_adapter().await?; + let fm = FunctionMap::default(); + let eft = Box::new(DefaultEffector); + let rm = Arc::new(RwLock::new(DefaultRoleManager::new(10))); + + let mut engine = Engine::new_raw(); + + engine.register_global_module(CASBIN_PACKAGE.as_shared_module()); + + for (key, &func) in fm.get_functions() { + Self::register_function(&mut engine, key, func); + } + + let mut e = Self { + model, + adapter, + fm, + eft, + rm, + enabled: true, + auto_save: true, + auto_build_role_links: true, + #[cfg(feature = "watcher")] + auto_notify_watcher: true, + #[cfg(feature = "watcher")] + watcher: None, + events: HashMap::new(), + engine, + #[cfg(feature = "logging")] + logger: Box::new(DefaultLogger::default()), + }; + + #[cfg(any(feature = "logging", feature = "watcher"))] + e.on(Event::PolicyChange, notify_logger_and_watcher); + + e.register_g_functions()?; + + Ok(e) + } #[inline] async fn new( @@ -440,6 +484,18 @@ impl CoreApi for Enforcer { } Ok(e) } + async fn new_from_str( + m: M, + a: A, + ) -> Result { + let mut e = Self::new_raw_from_str(m, a).await?; + + // Do not initialize the full policy when using a filtered adapter + if !e.adapter.is_filtered() { + e.load_policy().await?; + } + Ok(e) + } #[inline] fn add_function(&mut self, fname: &str, f: OperatorFunction) { @@ -1301,6 +1357,49 @@ mod tests { e.model = e2.model; assert_eq!(true, e.enforce(("root", "data1", "read")).unwrap()); } + #[tokio::test] + async fn test_get_and_set_model_from_str() { + let model_str = r#" + [request_definition] + r = sub, obj, act + + [policy_definition] + p = sub, obj, act + + [policy_effect] + e = some(where (p.eft == allow)) + + [matchers] + m = r.sub == p.sub && r.obj == p.obj && r.act == p.act + "#; + + let m1 = DefaultModel::from_str(model_str).await.unwrap(); + let adapter1 = FileAdapter::new("examples/basic_policy.csv"); + let mut e = Enforcer::new(m1, adapter1).await.unwrap(); + + assert_eq!(false, e.enforce(("root", "data1", "read")).unwrap()); + + let model_str_with_root = r#" + [request_definition] + r = sub, obj, act + + [policy_definition] + p = sub, obj, act + + [policy_effect] + e = some(where (p.eft == allow)) + + [matchers] + m = r.sub == p.sub && r.obj == p.obj && r.act == p.act || r.sub == "root" # root is the super user + "#; + + let m2 = DefaultModel::from_str(model_str_with_root).await.unwrap(); + let adapter2 = FileAdapter::new("examples/basic_policy.csv"); + let e2 = Enforcer::new(m2, adapter2).await.unwrap(); + + e.model = e2.model; + assert_eq!(true, e.enforce(("root", "data1", "read")).unwrap()); + } #[cfg(not(target_arch = "wasm32"))] #[cfg_attr(