diff --git a/socketio/src/asynchronous/client/builder.rs b/socketio/src/asynchronous/client/builder.rs index 44710e19..9b457b7a 100644 --- a/socketio/src/asynchronous/client/builder.rs +++ b/socketio/src/asynchronous/client/builder.rs @@ -17,6 +17,7 @@ use super::{ client::{Client, ReconnectSettings}, }; use crate::asynchronous::socket::Socket as InnerSocket; +use crate::AckId; /// A builder class for a `socket.io` socket. This handles setting up the client and /// configuring the callback, the namespace and metadata of the socket. If no @@ -190,6 +191,51 @@ impl ClientBuilder { + 'static + Send + Sync, + { + self.on.insert( + event.into(), + Callback::::new_no_ack(callback), + ); + self + } + + /// Registers a new callback for a certain [`crate::event::Event`] that expects the client to + /// ack. The event could either be one of the common events like `message`, `error`, `open`, + /// `close` or a custom event defined by a string, e.g. `onPayment` or `foo`. + /// + /// # Example + /// ```rust + /// use rust_socketio::{asynchronous::{ClientBuilder, Client}, AckId, Payload}; + /// use futures_util::FutureExt; + /// + /// #[tokio::main] + /// async fn main() { + /// let socket = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on("test", |payload: Payload, client: Client, ack: AckId| { + /// async move { + /// match payload { + /// Payload::Text(values) => println!("Received: {:#?}", values), + /// Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data), + /// // This is deprecated, use Payload::Text instead + /// Payload::String(str) => println!("Received: {}", str), + /// } + /// client.ack(ack, "received").await; + /// } + /// .boxed() + /// }) + /// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed()) + /// .connect() + /// .await; + /// } + /// + #[cfg(feature = "async-callbacks")] + pub fn on_with_ack, F>(mut self, event: T, callback: F) -> Self + where + F: for<'a> std::ops::FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, { self.on .insert(event.into(), Callback::::new(callback)); @@ -257,6 +303,41 @@ impl ClientBuilder { pub fn on_any(mut self, callback: F) -> Self where F: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync, + { + self.on_any = Some(Callback::::new_no_ack(callback)); + self + } + + /// Registers a Callback for all [`crate::event::Event::Custom`] and + /// [`crate::event::Event::Message`] that expect the client to ack. + /// + /// # Example + /// ```rust + /// use rust_socketio::{asynchronous::ClientBuilder, Payload}; + /// use futures_util::future::FutureExt; + /// + /// #[tokio::main] + /// async fn main() { + /// let client = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on_any(|event, payload, client, ack| { + /// async { + /// if let Payload::String(str) = payload { + /// println!("{}: {}", String::from(event), str); + /// } + /// client.ack(ack, "received").await; + /// }.boxed() + /// }) + /// .connect() + /// .await; + /// } + /// ``` + pub fn on_any_with_ack(mut self, callback: F) -> Self + where + F: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, { self.on_any = Some(Callback::::new(callback)); self diff --git a/socketio/src/asynchronous/client/callback.rs b/socketio/src/asynchronous/client/callback.rs index 3188b175..47f027a0 100644 --- a/socketio/src/asynchronous/client/callback.rs +++ b/socketio/src/asynchronous/client/callback.rs @@ -1,19 +1,27 @@ -use futures_util::future::BoxFuture; +use futures_util::{future::BoxFuture, FutureExt}; use std::{ fmt::Debug, + future::Future, ops::{Deref, DerefMut}, }; -use crate::{Event, Payload}; +use crate::{AckId, Event, Payload}; use super::client::{Client, ReconnectSettings}; /// Internal type, provides a way to store futures and return them in a boxed manner. -pub(crate) type DynAsyncCallback = - Box FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync>; +pub(crate) type DynAsyncCallback = Box< + dyn for<'a> FnMut(Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, +>; pub(crate) type DynAsyncAnyCallback = Box< - dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync, + dyn for<'a> FnMut(Event, Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Send + + Sync, >; pub(crate) type DynAsyncReconnectSettingsCallback = @@ -30,8 +38,10 @@ impl Debug for Callback { } impl Deref for Callback { - type Target = - dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send; + type Target = dyn for<'a> FnMut(Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Sync + + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -45,19 +55,34 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self where - T: for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send, + T: for<'a> FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()> + 'static + Sync + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |p, c, a| match a { + Some(a) => callback(p, c, a).boxed(), + None => std::future::ready(()).boxed(), + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self + where + T: FnMut(Payload, Client) -> Fut + Sync + Send + 'static, + Fut: Future + 'static + Send, + { + Callback { + inner: Box::new(move |p, c, _a| callback(p, c).boxed()), } } } impl Deref for Callback { - type Target = - dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send; + type Target = dyn for<'a> FnMut(Event, Payload, Client, Option) -> BoxFuture<'static, ()> + + 'static + + Sync + + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -71,12 +96,28 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self where - T: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send, + T: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()> + + 'static + + Sync + + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |e, p, c, a| match a { + Some(a) => callback(e, p, c, a).boxed(), + None => std::future::ready(()).boxed(), + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self + where + T: FnMut(Event, Payload, Client) -> Fut + Sync + Send + 'static, + Fut: Future + 'static + Send, + { + Callback { + inner: Box::new(move |e, p, c, _a| callback(e, p, c).boxed()), } } } diff --git a/socketio/src/asynchronous/client/client.rs b/socketio/src/asynchronous/client/client.rs index 72725263..01b44639 100644 --- a/socketio/src/asynchronous/client/client.rs +++ b/socketio/src/asynchronous/client/client.rs @@ -367,7 +367,7 @@ impl Client { id, time_started: Instant::now(), timeout, - callback: Callback::::new(callback), + callback: Callback::::new_no_ack(callback), }; // add the ack to the tuple of outstanding acks @@ -376,19 +376,33 @@ impl Client { self.socket.read().await.send(socket_packet).await } - async fn callback>(&self, event: &Event, payload: P) -> Result<()> { + pub async fn ack(&self, ack_id: AckId, data: D) -> Result<()> + where + D: Into, + { + let socket_packet = Packet::new_ack(data.into(), &self.nsp, ack_id); + + self.socket.read().await.send(socket_packet).await + } + + async fn callback>( + &self, + event: &Event, + payload: P, + ack_id: Option, + ) -> Result<()> { let mut builder = self.builder.write().await; let payload = payload.into(); if let Some(callback) = builder.on.get_mut(event) { - callback(payload.clone(), self.clone()).await; + callback(payload.clone(), self.clone(), ack_id).await; } // Call on_any for all common and custom events. match event { Event::Message | Event::Custom(_) => { if let Some(callback) = builder.on_any.as_mut() { - callback(event.clone(), payload, self.clone()).await; + callback(event.clone(), payload, self.clone(), ack_id).await; } } _ => (), @@ -411,6 +425,7 @@ impl Client { ack.callback.deref_mut()( Payload::from(payload.to_owned()), self.clone(), + None, ) .await; } @@ -419,6 +434,7 @@ impl Client { ack.callback.deref_mut()( Payload::Binary(payload.to_owned()), self.clone(), + None, ) .await; } @@ -446,8 +462,12 @@ impl Client { if let Some(attachments) = &packet.attachments { if let Some(binary_payload) = attachments.get(0) { - self.callback(&event, Payload::Binary(binary_payload.to_owned())) - .await?; + self.callback( + &event, + Payload::Binary(binary_payload.to_owned()), + packet.id, + ) + .await?; } } Ok(()) @@ -480,7 +500,7 @@ impl Client { }; // call the correct callback - self.callback(&event, payloads.to_vec()).await?; + self.callback(&event, payloads.to_vec(), packet.id).await?; } Ok(()) @@ -495,22 +515,22 @@ impl Client { match packet.packet_type { PacketId::Ack | PacketId::BinaryAck => { if let Err(err) = self.handle_ack(packet).await { - self.callback(&Event::Error, err.to_string()).await?; + self.callback(&Event::Error, err.to_string(), None).await?; return Err(err); } } PacketId::BinaryEvent => { if let Err(err) = self.handle_binary_event(packet).await { - self.callback(&Event::Error, err.to_string()).await?; + self.callback(&Event::Error, err.to_string(), None).await?; } } PacketId::Connect => { *(self.disconnect_reason.write().await) = DisconnectReason::default(); - self.callback(&Event::Connect, "").await?; + self.callback(&Event::Connect, "", None).await?; } PacketId::Disconnect => { *(self.disconnect_reason.write().await) = DisconnectReason::Server; - self.callback(&Event::Close, "").await?; + self.callback(&Event::Close, "", None).await?; } PacketId::ConnectError => { self.callback( @@ -520,12 +540,13 @@ impl Client { .data .as_ref() .unwrap_or(&String::from("\"No error message provided\"")), + None, ) .await?; } PacketId::Event => { if let Err(err) = self.handle_event(packet).await { - self.callback(&Event::Error, err.to_string()).await?; + self.callback(&Event::Error, err.to_string(), None).await?; } } } @@ -547,7 +568,7 @@ impl Client { None => None, Some(Err(err)) => { // call the error callback - match self.callback(&Event::Error, err.to_string()).await { + match self.callback(&Event::Error, err.to_string(), None).await { Err(callback_err) => Some((Err(callback_err), socket)), Ok(_) => Some((Err(err), socket)), } diff --git a/socketio/src/client/builder.rs b/socketio/src/client/builder.rs index 724971f0..7b6e53b9 100644 --- a/socketio/src/client/builder.rs +++ b/socketio/src/client/builder.rs @@ -1,7 +1,7 @@ use super::super::{event::Event, payload::Payload}; use super::callback::Callback; use super::client::Client; -use crate::RawClient; +use crate::{AckId, RawClient}; use native_tls::TlsConnector; use rust_engineio::client::ClientBuilder as EngineIoClientBuilder; use rust_engineio::header::{HeaderMap, HeaderValue}; @@ -173,6 +173,41 @@ impl ClientBuilder { pub fn on, F>(mut self, event: T, callback: F) -> Self where F: FnMut(Payload, RawClient) + 'static + Send, + { + let callback = Callback::::new_no_ack(callback); + // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held + self.on.lock().unwrap().insert(event.into(), callback); + self + } + + /// Registers a new callback for a certain [`crate::event::Event`] that expects the client to + /// ack. The event could either be one of the common events like `message`, `error`, `open`, + /// `close` or a custom event defined by a string, e.g. `onPayment` or `foo`. + /// + /// # Example + /// ```rust + /// use rust_socketio::{ClientBuilder, Payload}; + /// + /// let socket = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on("test", |payload: Payload, client, ack_id| { + /// match payload { + /// Payload::Text(values) => println!("Received: {:#?}", values), + /// Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data), + /// // This payload type is deprecated, use Payload::Text instead + /// Payload::String(str) => println!("Received: {}", str), + /// } + /// client.ack(ack_id, "received"); + /// }) + /// .on("error", |err, _| eprintln!("Error: {:#?}", err)) + /// .connect(); + /// + /// ``` + // While present implementation doesn't require mut, it's reasonable to require mutability. + #[allow(unused_mut)] + pub fn on_with_ack, F>(mut self, event: T, callback: F) -> Self + where + F: FnMut(Payload, RawClient, AckId) + 'static + Send, { let callback = Callback::::new(callback); // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held @@ -201,6 +236,36 @@ impl ClientBuilder { pub fn on_any(mut self, callback: F) -> Self where F: FnMut(Event, Payload, RawClient) + 'static + Send, + { + let callback = Some(Callback::::new_no_ack(callback)); + // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held + *self.on_any.lock().unwrap() = callback; + self + } + + /// Registers a Callback for all [`crate::event::Event::Custom`] and + /// [`crate::event::Event::Message`] that expects the client to ack. + /// + /// # Example + /// ```rust + /// use rust_socketio::{ClientBuilder, Payload}; + /// + /// let client = ClientBuilder::new("http://localhost:4200/") + /// .namespace("/admin") + /// .on_any(|event, payload, client, ack_id| { + /// if let Payload::String(str) = payload { + /// println!("{} {}", String::from(event), str); + /// } + /// client.ack(ack_id, "received") + /// }) + /// .connect(); + /// + /// ``` + // While present implementation doesn't require mut, it's reasonable to require mutability. + #[allow(unused_mut)] + pub fn on_any_with_ack(mut self, callback: F) -> Self + where + F: FnMut(Event, Payload, RawClient, AckId) + 'static + Send, { let callback = Some(Callback::::new(callback)); // SAFETY: Lock is held for such amount of time no code paths lead to a panic while lock is held diff --git a/socketio/src/client/callback.rs b/socketio/src/client/callback.rs index 1015ec03..2a9d7a7b 100644 --- a/socketio/src/client/callback.rs +++ b/socketio/src/client/callback.rs @@ -4,10 +4,11 @@ use std::{ }; use super::RawClient; -use crate::{Event, Payload}; +use crate::{AckId, Event, Payload}; -pub(crate) type SocketCallback = Box; -pub(crate) type SocketAnyCallback = Box; +pub(crate) type SocketCallback = Box) + 'static + Send>; +pub(crate) type SocketAnyCallback = + Box) + 'static + Send>; pub(crate) struct Callback { inner: T, @@ -22,7 +23,7 @@ impl Debug for Callback { } impl Deref for Callback { - type Target = dyn FnMut(Payload, RawClient) + 'static + Send; + type Target = dyn FnMut(Payload, RawClient, Option) + 'static + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -36,12 +37,25 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self + where + T: FnMut(Payload, RawClient, AckId) + 'static + Send, + { + Callback { + inner: Box::new(move |p, c, a| { + if let Some(a) = a { + callback(p, c, a) + } + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self where T: FnMut(Payload, RawClient) + 'static + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |p, c, _a| callback(p, c)), } } } @@ -55,7 +69,7 @@ impl Debug for Callback { } impl Deref for Callback { - type Target = dyn FnMut(Event, Payload, RawClient) + 'static + Send; + type Target = dyn FnMut(Event, Payload, RawClient, Option) + 'static + Send; fn deref(&self) -> &Self::Target { self.inner.as_ref() @@ -69,12 +83,25 @@ impl DerefMut for Callback { } impl Callback { - pub(crate) fn new(callback: T) -> Self + pub(crate) fn new(mut callback: T) -> Self + where + T: FnMut(Event, Payload, RawClient, AckId) + 'static + Send, + { + Callback { + inner: Box::new(move |e, p, c, a| { + if let Some(a) = a { + callback(e, p, c, a) + } + }), + } + } + + pub(crate) fn new_no_ack(mut callback: T) -> Self where T: FnMut(Event, Payload, RawClient) + 'static + Send, { Callback { - inner: Box::new(callback), + inner: Box::new(move |e, p, c, _a| callback(e, p, c)), } } } diff --git a/socketio/src/client/raw_client.rs b/socketio/src/client/raw_client.rs index 9fbb5ab3..65743299 100644 --- a/socketio/src/client/raw_client.rs +++ b/socketio/src/client/raw_client.rs @@ -149,7 +149,7 @@ impl RawClient { let _ = self.socket.send(disconnect_packet); self.socket.disconnect()?; - let _ = self.callback(&Event::Close, ""); // trigger on_close + let _ = self.callback(&Event::Close, "", None); // trigger on_close Ok(()) } @@ -211,7 +211,7 @@ impl RawClient { id, time_started: Instant::now(), timeout, - callback: Callback::::new(callback), + callback: Callback::::new_no_ack(callback), }; // add the ack to the tuple of outstanding acks @@ -221,11 +221,19 @@ impl RawClient { Ok(()) } + pub fn ack(&self, ack_id: AckId, data: D) -> Result<()> + where + D: Into, + { + let socket_packet = Packet::new_ack(data.into(), &self.nsp, ack_id); + self.socket.send(socket_packet) + } + pub(crate) fn poll(&self) -> Result> { loop { match self.socket.poll() { Err(err) => { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; return Err(err); } Ok(Some(packet)) => { @@ -246,7 +254,12 @@ impl RawClient { Iter { socket: self } } - fn callback>(&self, event: &Event, payload: P) -> Result<()> { + fn callback>( + &self, + event: &Event, + payload: P, + ack_id: Option, + ) -> Result<()> { let mut on = self.on.lock()?; let mut on_any = self.on_any.lock()?; let lock = on.deref_mut(); @@ -255,12 +268,12 @@ impl RawClient { let payload = payload.into(); if let Some(callback) = lock.get_mut(event) { - callback(payload.clone(), self.clone()); + callback(payload.clone(), self.clone(), ack_id); } match event { Event::Message | Event::Custom(_) => { if let Some(callback) = on_any_lock { - callback(event.clone(), payload, self.clone()) + callback(event.clone(), payload, self.clone(), ack_id) } } _ => {} @@ -284,12 +297,16 @@ impl RawClient { if ack.time_started.elapsed() < ack.timeout { if let Some(ref payload) = socket_packet.data { - ack.callback.deref_mut()(Payload::from(payload.to_owned()), self.clone()); + ack.callback.deref_mut()(Payload::from(payload.to_owned()), self.clone(), None); } if let Some(ref attachments) = socket_packet.attachments { if let Some(payload) = attachments.first() { - ack.callback.deref_mut()(Payload::Binary(payload.to_owned()), self.clone()); + ack.callback.deref_mut()( + Payload::Binary(payload.to_owned()), + self.clone(), + None, + ); } } } @@ -312,7 +329,11 @@ impl RawClient { if let Some(attachments) = &packet.attachments { if let Some(binary_payload) = attachments.first() { - self.callback(&event, Payload::Binary(binary_payload.to_owned()))?; + self.callback( + &event, + Payload::Binary(binary_payload.to_owned()), + packet.id, + )?; } } Ok(()) @@ -344,7 +365,7 @@ impl RawClient { }; // call the correct callback - self.callback(&event, payloads.to_vec())?; + self.callback(&event, payloads.to_vec(), packet.id)?; } Ok(()) @@ -359,20 +380,20 @@ impl RawClient { match packet.packet_type { PacketId::Ack | PacketId::BinaryAck => { if let Err(err) = self.handle_ack(packet) { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; return Err(err); } } PacketId::BinaryEvent => { if let Err(err) = self.handle_binary_event(packet) { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; } } PacketId::Connect => { - self.callback(&Event::Connect, "")?; + self.callback(&Event::Connect, "", None)?; } PacketId::Disconnect => { - self.callback(&Event::Close, "")?; + self.callback(&Event::Close, "", None)?; } PacketId::ConnectError => { self.callback( @@ -382,11 +403,12 @@ impl RawClient { .clone() .data .unwrap_or_else(|| String::from("\"No error message provided\"")), + None, )?; } PacketId::Event => { if let Err(err) = self.handle_event(packet) { - self.callback(&Event::Error, err.to_string())?; + self.callback(&Event::Error, err.to_string(), None)?; } } } diff --git a/socketio/src/packet.rs b/socketio/src/packet.rs index 48cc6312..689a64bf 100644 --- a/socketio/src/packet.rs +++ b/socketio/src/packet.rs @@ -88,6 +88,43 @@ impl Packet { } } } + + pub(crate) fn new_ack(payload: Payload, nsp: &str, id: AckId) -> Self { + match payload { + Payload::Text(data) => Packet::new( + PacketId::Ack, + nsp.to_owned(), + Some(serde_json::Value::Array(data).to_string()), + Some(id), + 0, + None, + ), + #[allow(deprecated)] + Payload::String(str_data) => { + let payload = if serde_json::from_str::(&str_data).is_ok() { + format!("[{str_data}]") + } else { + format!("[{str_data:?}]") + }; + Packet::new( + PacketId::Ack, + nsp.to_owned(), + Some(payload), + Some(id), + 0, + None, + ) + } + Payload::Binary(data) => Packet::new( + PacketId::BinaryAck, + nsp.to_owned(), + None, + Some(id), + 1, + Some(vec![data]), + ), + } + } } impl Default for Packet { @@ -605,7 +642,7 @@ mod test { #[test] fn test_illegal_packet_id() { let _sut = PacketId::try_from(42).expect_err("error!"); - assert!(matches!(Error::InvalidPacketId(42 as char), _sut)) + assert!(matches!(Error::InvalidPacketId(42u8 as char), _sut)) } #[test]