diff --git a/Cargo.lock b/Cargo.lock index 98ded35..3407bf9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -728,11 +728,12 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "rraft-py" -version = "0.2.0" +version = "0.2.13" dependencies = [ "bincode", "bytes", "fxhash", + "once_cell", "prost", "protobuf", "pyo3", diff --git a/Cargo.toml b/Cargo.toml index e29ce0d..91d3482 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rraft-py" -version = "0.2.0" +version = "0.2.9" authors = ["Lablup Inc."] license = "Apache-2.0" repository = "https://github.com/lablup/rraft-py" @@ -17,6 +17,7 @@ edition = "2021" prost = "0.11" protobuf = "2" pyo3 = { version = "0.19.2", features = ["extension-module", "multiple-pymethods"] } +pyo3-asyncio = { version = "0.19", features = ["attributes", "tokio-runtime"] } raft = { git = "https://github.com/jopemachine/raft-rs.git", branch = "feat/add-custom-deserializer", features = ["prost-codec", "default-logger"], default-features = false } slog = { version = "2.2", features = ["max_level_debug", "release_max_level_debug"] } slog-envlogger = "2.1.0" @@ -26,6 +27,8 @@ slog-async = "2.7.0" fxhash = "0.2.1" bincode = "1.3.3" bytes = "1.0" +tokio = { version = "1.32", features = ["sync", "macros"] } +once_cell = "1.7" [lib] name = "rraft" diff --git a/example/single_mem_node/use_threading.py b/example/single_mem_node/use_threading.py index 434c0ef..adbc905 100644 --- a/example/single_mem_node/use_threading.py +++ b/example/single_mem_node/use_threading.py @@ -164,6 +164,8 @@ def handle_committed_entries(committed_entries: List[EntryRef]): # Use another thread to propose a Raft request. send_propose(logger) + logger.mutex.acquire_lock_and(lambda: print('start!!!')) + t = now() timeout = 100 # Use a HashMap to hold the `propose` callbacks. diff --git a/requirements.txt b/requirements.txt index 2850225..3087cca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -maturin==1.2.2 +maturin==1.2.3 pytest==7.3.1 pytest-benchmark==4.0.0 \ No newline at end of file diff --git a/rraft.pyi b/rraft.pyi index 71133a3..bd15ea3 100644 --- a/rraft.pyi +++ b/rraft.pyi @@ -66,6 +66,24 @@ class __Decoder: def decode(v: bytes) -> Any: """ """ +def set_snapshot_data_deserializer(cb: Any) -> None: + """ """ + +def set_message_context_deserializer(cb: Any) -> None: + """ """ + +def set_confchange_context_deserializer(cb: Any) -> None: + """ """ + +def set_confchangev2_context_deserializer(cb: Any) -> None: + """ """ + +def set_entry_data_deserializer(cb: Any) -> None: + """ """ + +def set_entry_context_deserializer(cb: Any) -> None: + """ """ + class OverflowStrategy: """ """ @@ -218,6 +236,8 @@ class EntryType: def from_int(v: int) -> "EntryType": ... class __API_Logger: + mutex: "Mutex" + def info(self, s: str) -> None: """ Log info level record @@ -228,25 +248,25 @@ class __API_Logger: """ Log debug level record - See `log` for documentation. + See `slog_debug` for documentation. """ def trace(self, s: str) -> None: """ Log trace level record - See `log` for documentation. + See `slog_trace` for documentation. """ def crit(self, s: str) -> None: """ Log crit level record - See `log` for documentation. + See `slog_crit` for documentation. """ def error(self, s: str) -> None: """ Log error level record - See `log` for documentation. + See `slog_error` for documentation. """ class Logger(__API_Logger): @@ -255,6 +275,8 @@ class Logger(__API_Logger): def __init__( self, chan_size: int, overflow_strategy: "OverflowStrategy" ) -> None: ... + @staticmethod + def new_file_logger(log_path: str): ... def make_ref(self) -> "LoggerRef": ... class LoggerRef(__API_Logger): diff --git a/src/bindings/global.rs b/src/bindings/global.rs index 7163443..ce3431b 100644 --- a/src/bindings/global.rs +++ b/src/bindings/global.rs @@ -1,3 +1,5 @@ +use std::sync::{Arc, Mutex}; + use pyo3::prelude::*; use crate::raftpb_bindings::conf_change::new_conf_change_single as _new_conf_change_single; @@ -14,7 +16,7 @@ use raft::{ NO_LIMIT, }; -use crate::external_bindings::slog::PyLogger; +use crate::external_bindings::slog::{LoggerMode, PyLogger}; use crate::raftpb_bindings::message_type::PyMessageType; use crate::utils::reference::RefMutOwner; @@ -28,6 +30,8 @@ pub fn majority(total: usize) -> usize { pub fn default_logger() -> PyLogger { PyLogger { inner: RefMutOwner::new(_default_logger()), + mutex: Arc::new(Mutex::new(())), + mode: LoggerMode::Stdout, } } diff --git a/src/external_bindings/slog.rs b/src/external_bindings/slog.rs index 0c28157..588779c 100644 --- a/src/external_bindings/slog.rs +++ b/src/external_bindings/slog.rs @@ -1,9 +1,11 @@ -use pyo3::{intern, prelude::*, types::PyString}; -use slog::*; -use slog_async::OverflowStrategy; +use std::fs::OpenOptions; +use std::sync::{Arc, Mutex}; use crate::implement_type_conversion; use crate::utils::reference::{RefMutContainer, RefMutOwner}; +use pyo3::{intern, prelude::*, types::PyString}; +use slog::*; +use slog_async::OverflowStrategy; #[pyclass(name = "OverflowStrategy")] pub struct PyOverflowStrategy(pub OverflowStrategy); @@ -50,16 +52,26 @@ impl PyOverflowStrategy { } } +#[derive(Clone)] +pub enum LoggerMode { + File, + Stdout, +} + #[derive(Clone)] #[pyclass(name = "Logger")] pub struct PyLogger { pub inner: RefMutOwner, + pub mutex: Arc>, + pub mode: LoggerMode, } #[derive(Clone)] #[pyclass(name = "LoggerRef")] pub struct PyLoggerRef { pub inner: RefMutContainer, + pub mutex: Arc>, + pub mode: LoggerMode, } #[derive(FromPyObject)] @@ -76,6 +88,7 @@ impl PyLogger { pub fn new(chan_size: usize, overflow_strategy: &PyOverflowStrategy) -> Self { let decorator = slog_term::TermDecorator::new().build(); let drain = slog_term::FullFormat::new(decorator).build().fuse(); + let drain = slog_async::Async::new(drain) .chan_size(chan_size) .overflow_strategy(overflow_strategy.0) @@ -86,12 +99,39 @@ impl PyLogger { PyLogger { inner: RefMutOwner::new(logger), + mutex: Arc::new(Mutex::new(())), + mode: LoggerMode::Stdout, + } + } + + #[staticmethod] + pub fn new_file_logger(log_path: &PyString) -> Self { + let log_path = log_path.to_str().unwrap(); + let file = OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(log_path) + .unwrap(); + + let decorator = slog_term::PlainDecorator::new(file); + let drain = slog_term::FullFormat::new(decorator).build().fuse(); + let drain = slog_async::Async::new(drain).build().fuse(); + + let logger = slog::Logger::root(drain, o!()); + + PyLogger { + inner: RefMutOwner::new(logger), + mutex: Arc::new(Mutex::new(())), + mode: LoggerMode::File, } } pub fn make_ref(&mut self) -> PyLoggerRef { PyLoggerRef { inner: RefMutContainer::new(&mut self.inner), + mutex: self.mutex.clone(), + mode: self.mode.clone(), } } @@ -104,27 +144,87 @@ impl PyLogger { #[pymethods] impl PyLoggerRef { pub fn info(&mut self, s: &PyString) -> PyResult<()> { - self.inner - .map_as_ref(|inner| info!(inner, "{}", format!("{}", s))) + let print = || { + self.inner + .map_as_ref(|inner| info!(inner, "{}", format!("{}", s))) + }; + + match self.mode { + LoggerMode::Stdout => { + let _guard = self.mutex.lock().unwrap(); + print() + } + LoggerMode::File => { + print() + } + } } pub fn debug(&mut self, s: &PyString) -> PyResult<()> { - self.inner - .map_as_ref(|inner| debug!(inner, "{}", format!("{}", s))) + let print = || { + self.inner + .map_as_ref(|inner| debug!(inner, "{}", format!("{}", s))) + }; + + match self.mode { + LoggerMode::Stdout => { + let _guard = self.mutex.lock().unwrap(); + print() + } + LoggerMode::File => { + print() + } + } } pub fn trace(&mut self, s: &PyString) -> PyResult<()> { - self.inner - .map_as_ref(|inner| trace!(inner, "{}", format!("{}", s))) + let print = || { + self.inner + .map_as_ref(|inner| trace!(inner, "{}", format!("{}", s))) + }; + + match self.mode { + LoggerMode::Stdout => { + let _guard = self.mutex.lock().unwrap(); + print() + } + LoggerMode::File => { + print() + } + } } pub fn error(&mut self, s: &PyString) -> PyResult<()> { - self.inner - .map_as_ref(|inner| error!(inner, "{}", format!("{}", s))) + let print = || { + self.inner + .map_as_ref(|inner| error!(inner, "{}", format!("{}", s))) + }; + + match self.mode { + LoggerMode::Stdout => { + let _guard = self.mutex.lock().unwrap(); + print() + } + LoggerMode::File => { + print() + } + } } pub fn crit(&mut self, s: &PyString) -> PyResult<()> { - self.inner - .map_as_ref(|inner| crit!(inner, "{}", format!("{}", s))) + let print = || { + self.inner + .map_as_ref(|inner: &Logger| crit!(inner, "{}", format!("{}", s))) + }; + + match self.mode { + LoggerMode::Stdout => { + let _guard = self.mutex.lock().unwrap(); + print() + } + LoggerMode::File => { + print() + } + } } } diff --git a/src/lib.rs b/src/lib.rs index 728dcb2..5e89153 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,6 +110,36 @@ fn rraft(py: Python, m: &PyModule) -> PyResult<()> { m )?)?; + m.add_function(wrap_pyfunction!( + utils::deserializer::set_confchange_context_deserializer, + m + )?)?; + + m.add_function(wrap_pyfunction!( + utils::deserializer::set_confchangev2_context_deserializer, + m + )?)?; + + m.add_function(wrap_pyfunction!( + utils::deserializer::set_entry_context_deserializer, + m + )?)?; + + m.add_function(wrap_pyfunction!( + utils::deserializer::set_entry_data_deserializer, + m + )?)?; + + m.add_function(wrap_pyfunction!( + utils::deserializer::set_message_context_deserializer, + m + )?)?; + + m.add_function(wrap_pyfunction!( + utils::deserializer::set_snapshot_data_deserializer, + m + )?)?; + m.add( "DestroyedRefUsedError", py.get_type::(), diff --git a/src/raftpb_bindings/conf_change.rs b/src/raftpb_bindings/conf_change.rs index 06d357b..51e7b81 100644 --- a/src/raftpb_bindings/conf_change.rs +++ b/src/raftpb_bindings/conf_change.rs @@ -1,8 +1,8 @@ +use crate::implement_type_conversion; use crate::utils::{ errors::to_pyresult, reference::{RefMutContainer, RefMutOwner}, }; -use crate::{deserialize_bytes, implement_type_conversion}; use prost::Message as ProstMessage; use protobuf::Message as PbMessage; use pyo3::{intern, prelude::*, pyclass::CompareOp, types::PyBytes}; @@ -34,16 +34,6 @@ pub enum PyConfChangeMut<'p> { implement_type_conversion!(ConfChange, PyConfChangeMut); -pub fn format_confchange(cc: &ConfChange, py: Python) -> String { - format!( - "ConfChange {{ change_type: {change_type:?}, node_id: {node_id:?}, context: {context:?}, id: {id:?} }}", - change_type = cc.get_change_type(), - node_id = cc.get_node_id(), - id = cc.get_id(), - context = deserialize_bytes!(cc, "confchange_context_deserializer", context, py) - ) -} - #[pymethods] impl PyConfChange { #[new] @@ -73,8 +63,8 @@ impl PyConfChange { } } - pub fn __repr__(&self, py: Python) -> String { - format_confchange(&self.inner.inner, py) + pub fn __repr__(&self) -> String { + format!("{:?}", self.inner.inner) } pub fn __richcmp__(&self, py: Python, rhs: PyConfChangeMut, op: CompareOp) -> PyObject { @@ -95,8 +85,8 @@ impl PyConfChange { #[pymethods] impl PyConfChangeRef { - pub fn __repr__(&self, py: Python) -> PyResult { - self.inner.map_as_ref(|inner| format_confchange(inner, py)) + pub fn __repr__(&self) -> PyResult { + self.inner.map_as_ref(|inner| format!("{:?}", inner)) } pub fn __richcmp__( diff --git a/src/raftpb_bindings/conf_change_v2.rs b/src/raftpb_bindings/conf_change_v2.rs index 061a446..0ae8624 100644 --- a/src/raftpb_bindings/conf_change_v2.rs +++ b/src/raftpb_bindings/conf_change_v2.rs @@ -1,9 +1,9 @@ +use crate::implement_type_conversion; use crate::utils::{ errors::to_pyresult, reference::{RefMutContainer, RefMutOwner}, unsafe_cast::make_mut, }; -use crate::{deserialize_bytes, implement_type_conversion}; use prost::Message as ProstMessage; use protobuf::Message as PbMessage; use pyo3::{ @@ -20,15 +20,6 @@ use super::{ conf_change_transition::PyConfChangeTransition, }; -pub fn format_confchangev2(cc: &ConfChangeV2, py: Python) -> String { - format!( - "ConfChangeV2 {{ transition: {transition:?}, changes: {changes:?}, context: {context:?} }}", - transition = cc.transition(), - changes = cc.changes, - context = deserialize_bytes!(cc, "confchangev2_context_deserializer", context, py) - ) -} - #[derive(Clone)] #[pyclass(name = "ConfChangeV2")] pub struct PyConfChangeV2 { @@ -78,8 +69,8 @@ impl PyConfChangeV2 { } } - pub fn __repr__(&self, py: Python) -> String { - format_confchangev2(&self.inner.inner, py) + pub fn __repr__(&self) -> String { + format!("{:?}", self.inner.inner) } pub fn __richcmp__(&self, py: Python, rhs: PyConfChangeV2Mut, op: CompareOp) -> PyObject { @@ -106,9 +97,8 @@ impl Default for PyConfChangeV2 { #[pymethods] impl PyConfChangeV2Ref { - pub fn __repr__(&self, py: Python) -> PyResult { - self.inner - .map_as_ref(|inner| format_confchangev2(inner, py)) + pub fn __repr__(&self) -> PyResult { + self.inner.map_as_ref(|inner| format!("{:?}", inner)) } pub fn __richcmp__( diff --git a/src/raftpb_bindings/entry.rs b/src/raftpb_bindings/entry.rs index 5b6c9ac..7293889 100644 --- a/src/raftpb_bindings/entry.rs +++ b/src/raftpb_bindings/entry.rs @@ -1,6 +1,6 @@ +use crate::implement_type_conversion; use crate::utils::errors::to_pyresult; use crate::utils::reference::{RefMutContainer, RefMutOwner}; -use crate::{deserialize_bytes, implement_type_conversion}; use prost::Message as ProstMessage; use protobuf::Message as PbMessage; use pyo3::pyclass::CompareOp; @@ -30,18 +30,6 @@ pub enum PyEntryMut<'p> { implement_type_conversion!(Entry, PyEntryMut); -pub fn format_entry(entry: &Entry, py: Python) -> String { - format!( - "Entry {{ context: {context:}, data: {data:}, entry_type: {entry_type:?}, index: {index:}, sync_log: {sync_log:}, term: {term:} }}", - data=deserialize_bytes!(entry, "entry_data_deserializer", data, py), - context=deserialize_bytes!(entry, "entry_context_deserializer", context, py), - entry_type=entry.get_entry_type(), - index=entry.get_index(), - sync_log=entry.get_sync_log(), - term=entry.get_term(), - ) -} - #[pymethods] impl PyEntry { #[new] @@ -71,8 +59,8 @@ impl PyEntry { } } - pub fn __repr__(&self, py: Python) -> String { - format_entry(&self.inner.inner, py) + pub fn __repr__(&self) -> String { + format!("{:?}", self.inner.inner) } pub fn __richcmp__(&self, py: Python, rhs: PyEntryMut, op: CompareOp) -> PyObject { @@ -93,8 +81,8 @@ impl PyEntry { #[pymethods] impl PyEntryRef { - pub fn __repr__(&self, py: Python) -> PyResult { - self.inner.map_as_ref(|inner| format_entry(inner, py)) + pub fn __repr__(&self) -> PyResult { + self.inner.map_as_ref(|inner| format!("{:?}", inner)) } pub fn __richcmp__(&self, py: Python, rhs: PyEntryMut, op: CompareOp) -> PyResult { diff --git a/src/raftpb_bindings/message.rs b/src/raftpb_bindings/message.rs index 786affa..c61c1ff 100644 --- a/src/raftpb_bindings/message.rs +++ b/src/raftpb_bindings/message.rs @@ -7,15 +7,12 @@ use pyo3::{ types::{PyBytes, PyList}, }; +use crate::implement_type_conversion; use crate::utils::{ errors::to_pyresult, reference::{RefMutContainer, RefMutOwner}, unsafe_cast::make_mut, }; -use crate::{ - deserialize_bytes, implement_type_conversion, raftpb_bindings::entry::format_entry, - raftpb_bindings::snapshot::format_snapshot, -}; use raft::eraftpb::Message; use super::{ @@ -44,28 +41,6 @@ pub enum PyMessageMut<'p> { implement_type_conversion!(Message, PyMessageMut); -pub fn format_message(msg: &Message, py: Python) -> String { - format!( - "Message {{ msg_type: {msg_type:?}, to: {to:?}, from: {from:?}, term: {term:?}, log_term: {log_term:?}, index: {index:?}, entries: [{entries:?}], commit: {commit:?}, commit_term: {commit_term:?}, snapshot: {snapshot:?}, request_snapshot: {request_snapshot:?}, reject: {reject:?}, reject_hint: {reject_hint:?}, context: {context:?}, deprecated_priority: {deprecated_priority:?}, priority: {priority:?} }}", - msg_type=msg.get_msg_type(), - to=msg.get_to(), - from=msg.get_from(), - term=msg.get_term(), - log_term=msg.get_log_term(), - index=msg.get_index(), - entries=msg.get_entries().iter().map(|e| format_entry(e, py)).collect::>().join(", "), - commit=msg.get_commit(), - commit_term=msg.get_commit_term(), - snapshot=format_snapshot(msg.get_snapshot(), py), - request_snapshot=msg.get_request_snapshot(), - reject=msg.get_reject(), - reject_hint=msg.get_reject_hint(), - context=deserialize_bytes!(msg, "message_context_deserializer", context, py), - deprecated_priority=msg.get_deprecated_priority(), - priority=msg.get_priority(), - ) -} - #[pymethods] impl PyMessage { #[new] @@ -95,8 +70,8 @@ impl PyMessage { } } - pub fn __repr__(&self, py: Python) -> String { - format_message(&self.inner.inner, py) + pub fn __repr__(&self) -> String { + format!("{:?}", self.inner.inner) } pub fn __richcmp__(&self, py: Python, rhs: PyMessageMut, op: CompareOp) -> PyObject { @@ -117,8 +92,8 @@ impl PyMessage { #[pymethods] impl PyMessageRef { - pub fn __repr__(&self, py: Python) -> PyResult { - self.inner.map_as_ref(|inner| format_message(inner, py)) + pub fn __repr__(&self) -> PyResult { + self.inner.map_as_ref(|inner| format!("{:?}", inner)) } pub fn __richcmp__(&self, py: Python, rhs: PyMessageMut, op: CompareOp) -> PyResult { diff --git a/src/raftpb_bindings/snapshot.rs b/src/raftpb_bindings/snapshot.rs index d9a45f1..c61b5b7 100644 --- a/src/raftpb_bindings/snapshot.rs +++ b/src/raftpb_bindings/snapshot.rs @@ -2,11 +2,11 @@ use prost::Message as ProstMessage; use protobuf::Message as PbMessage; use pyo3::{intern, prelude::*, pyclass::CompareOp, types::PyBytes}; +use crate::implement_type_conversion; use crate::utils::{ errors::to_pyresult, reference::{RefMutContainer, RefMutOwner}, }; -use crate::{deserialize_bytes, implement_type_conversion}; use raft::eraftpb::Snapshot; use super::snapshot_metadata::{PySnapshotMetadataMut, PySnapshotMetadataRef}; @@ -31,14 +31,6 @@ pub enum PySnapshotMut<'p> { implement_type_conversion!(Snapshot, PySnapshotMut); -pub fn format_snapshot(snapshot: &Snapshot, py: Python) -> String { - format!( - "Snapshot {{ data: {data:?}, metadata: {metadata:?} }}", - data = deserialize_bytes!(snapshot, "snapshot_data_deserializer", data, py), - metadata = snapshot.metadata, - ) -} - #[pymethods] impl PySnapshot { #[new] @@ -68,8 +60,8 @@ impl PySnapshot { } } - pub fn __repr__(&self, py: Python) -> String { - format_snapshot(&self.inner, py) + pub fn __repr__(&self) -> String { + format!("{:?}", self.inner.inner) } pub fn __bool__(&self) -> bool { @@ -94,8 +86,8 @@ impl PySnapshot { #[pymethods] impl PySnapshotRef { - pub fn __repr__(&self, py: Python) -> PyResult { - self.inner.map_as_ref(|inner| format_snapshot(inner, py)) + pub fn __repr__(&self) -> PyResult { + self.inner.map_as_ref(|inner| format!("{:?}", inner)) } pub fn __richcmp__(&self, py: Python, rhs: PySnapshotMut, op: CompareOp) -> PyResult { diff --git a/src/utils/deserializer.rs b/src/utils/deserializer.rs index 9f9e214..af1ccb2 100644 --- a/src/utils/deserializer.rs +++ b/src/utils/deserializer.rs @@ -1,31 +1,61 @@ -use pyo3::{types::PyBytes, IntoPy, Python}; - -#[macro_export] -macro_rules! deserialize_bytes { - ($inner:ident, $deserializer_name: literal, $attr:ident, $py:ident) => {{ - if let Some(deserializer) = $py.eval($deserializer_name, None, None).ok() { - deserializer - .call((PyBytes::new($py, $inner.$attr.as_slice()),), None) - .unwrap() - .into_py($py) - .to_string() - } else { - format!("{:?}", $inner.$attr) - } - }}; -} +use std::sync::Mutex; -// IMPORTANT: UNCOMMENT ALL THE BELOW CODES when you want to use raft-rs's upstream repository. +use ::once_cell::sync::Lazy; +use pyo3::*; +use pyo3::{types::PyBytes, IntoPy, PyObject, Python}; use raft::derializer::{Bytes, CustomDeserializer}; pub struct MyDeserializer; // TODO: Refactor below codes to reduce code redundancy. +static ENTRY_CONTEXT_DESERIALIZE_CB: Lazy>> = Lazy::new(|| Mutex::new(None)); +static ENTRY_DATA_DESERIALIZE_CB: Lazy>> = Lazy::new(|| Mutex::new(None)); +static CONFCHANGEV2_CONTEXT_DESERIALIZE_CB: Lazy>> = + Lazy::new(|| Mutex::new(None)); +static CONFCHANGE_CONTEXT_DESERIALIZE_CB: Lazy>> = + Lazy::new(|| Mutex::new(None)); +static MESSAGE_CONTEXT_DESERIALIZER_CB: Lazy>> = + Lazy::new(|| Mutex::new(None)); +static SNAPSHOT_DATA_DESERIALIZER_CB: Lazy>> = + Lazy::new(|| Mutex::new(None)); + +#[pyfunction] +pub fn set_entry_context_deserializer(cb: PyObject) { + *ENTRY_CONTEXT_DESERIALIZE_CB.lock().unwrap() = Some(cb); +} + +#[pyfunction] +pub fn set_entry_data_deserializer(cb: PyObject) { + *ENTRY_DATA_DESERIALIZE_CB.lock().unwrap() = Some(cb); +} + +#[pyfunction] +pub fn set_confchangev2_context_deserializer(cb: PyObject) { + *CONFCHANGEV2_CONTEXT_DESERIALIZE_CB.lock().unwrap() = Some(cb); +} + +#[pyfunction] +pub fn set_confchange_context_deserializer(cb: PyObject) { + *CONFCHANGE_CONTEXT_DESERIALIZE_CB.lock().unwrap() = Some(cb); +} + +#[pyfunction] +pub fn set_message_context_deserializer(cb: PyObject) { + *MESSAGE_CONTEXT_DESERIALIZER_CB.lock().unwrap() = Some(cb); +} + +#[pyfunction] +pub fn set_snapshot_data_deserializer(cb: PyObject) { + *SNAPSHOT_DATA_DESERIALIZER_CB.lock().unwrap() = Some(cb); +} + impl CustomDeserializer for MyDeserializer { fn entry_context_deserialize(&self, v: &Bytes) -> String { fn deserialize(py: Python, data: &[u8]) -> String { - if let Some(deserializer) = py.eval("entry_context_deserializer", None, None).ok() { - deserializer - .call((PyBytes::new(py, data),), None) + let callback_lock = ENTRY_CONTEXT_DESERIALIZE_CB.lock().unwrap(); + + if let Some(callback) = &*callback_lock { + callback + .call(py, (PyBytes::new(py, data),), None) .unwrap() .into_py(py) .to_string() @@ -42,9 +72,11 @@ impl CustomDeserializer for MyDeserializer { fn entry_data_deserialize(&self, v: &Bytes) -> String { fn deserialize(py: Python, data: &[u8]) -> String { - if let Some(deserializer) = py.eval("entry_data_deserializer", None, None).ok() { - deserializer - .call((PyBytes::new(py, data),), None) + let callback_lock = ENTRY_DATA_DESERIALIZE_CB.lock().unwrap(); + + if let Some(callback) = &*callback_lock { + callback + .call(py, (PyBytes::new(py, data),), None) .unwrap() .into_py(py) .to_string() @@ -61,12 +93,11 @@ impl CustomDeserializer for MyDeserializer { fn confchangev2_context_deserialize(&self, v: &Bytes) -> String { fn deserialize(py: Python, data: &[u8]) -> String { - if let Some(deserializer) = py - .eval("confchangev2_context_deserializer", None, None) - .ok() - { - deserializer - .call((PyBytes::new(py, data),), None) + let callback_lock = CONFCHANGEV2_CONTEXT_DESERIALIZE_CB.lock().unwrap(); + + if let Some(callback) = &*callback_lock { + callback + .call(py, (PyBytes::new(py, data),), None) .unwrap() .into_py(py) .to_string() @@ -83,10 +114,11 @@ impl CustomDeserializer for MyDeserializer { fn confchange_context_deserialize(&self, v: &Bytes) -> String { fn deserialize(py: Python, data: &[u8]) -> String { - if let Some(deserializer) = py.eval("confchange_context_deserializer", None, None).ok() - { - deserializer - .call((PyBytes::new(py, data),), None) + let callback_lock = CONFCHANGE_CONTEXT_DESERIALIZE_CB.lock().unwrap(); + + if let Some(callback) = &*callback_lock { + callback + .call(py, (PyBytes::new(py, data),), None) .unwrap() .into_py(py) .to_string() @@ -103,9 +135,11 @@ impl CustomDeserializer for MyDeserializer { fn message_context_deserializer(&self, v: &Bytes) -> String { fn deserialize(py: Python, data: &[u8]) -> String { - if let Some(deserializer) = py.eval("message_context_deserializer", None, None).ok() { - deserializer - .call((PyBytes::new(py, data),), None) + let callback_lock = MESSAGE_CONTEXT_DESERIALIZER_CB.lock().unwrap(); + + if let Some(callback) = &*callback_lock { + callback + .call(py, (PyBytes::new(py, data),), None) .unwrap() .into_py(py) .to_string() @@ -122,9 +156,11 @@ impl CustomDeserializer for MyDeserializer { fn snapshot_data_deserializer(&self, v: &Bytes) -> String { fn deserialize(py: Python, data: &[u8]) -> String { - if let Some(deserializer) = py.eval("snapshot_data_deserializer", None, None).ok() { - deserializer - .call((PyBytes::new(py, data),), None) + let callback_lock = SNAPSHOT_DATA_DESERIALIZER_CB.lock().unwrap(); + + if let Some(callback) = &*callback_lock { + callback + .call(py, (PyBytes::new(py, data),), None) .unwrap() .into_py(py) .to_string()