Skip to content

Commit

Permalink
Add ackable events
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Sep 16, 2024
1 parent 3748c5b commit b5935fb
Show file tree
Hide file tree
Showing 7 changed files with 348 additions and 54 deletions.
81 changes: 81 additions & 0 deletions socketio/src/asynchronous/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use super::{
client::{Client, ReconnectSettings},
};
use crate::asynchronous::socket::Socket as InnerSocket;
use crate::AckId;

/// A builder class for a `socket.io` socket. This handles setting up the client and
/// configuring the callback, the namespace and metadata of the socket. If no
Expand Down Expand Up @@ -190,6 +191,51 @@ impl ClientBuilder {
+ 'static
+ Send
+ Sync,
{
self.on.insert(
event.into(),
Callback::<DynAsyncCallback>::new_no_ack(callback),
);
self
}

/// Registers a new callback for a certain [`crate::event::Event`] that expects the client to
/// ack. The event could either be one of the common events like `message`, `error`, `open`,
/// `close` or a custom event defined by a string, e.g. `onPayment` or `foo`.
///
/// # Example
/// ```rust
/// use rust_socketio::{asynchronous::{ClientBuilder, Client}, AckId, Payload};
/// use futures_util::FutureExt;
///
/// #[tokio::main]
/// async fn main() {
/// let socket = ClientBuilder::new("http://localhost:4200/")
/// .namespace("/admin")
/// .on("test", |payload: Payload, client: Client, ack: AckId| {
/// async move {
/// match payload {
/// Payload::Text(values) => println!("Received: {:#?}", values),
/// Payload::Binary(bin_data) => println!("Received bytes: {:#?}", bin_data),
/// // This is deprecated, use Payload::Text instead
/// Payload::String(str) => println!("Received: {}", str),
/// }
/// client.ack(ack, "received").await;
/// }
/// .boxed()
/// })
/// .on("error", |err, _| async move { eprintln!("Error: {:#?}", err) }.boxed())
/// .connect()
/// .await;
/// }
///
#[cfg(feature = "async-callbacks")]
pub fn on_with_ack<T: Into<Event>, F>(mut self, event: T, callback: F) -> Self
where
F: for<'a> std::ops::FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
{
self.on
.insert(event.into(), Callback::<DynAsyncCallback>::new(callback));
Expand Down Expand Up @@ -257,6 +303,41 @@ impl ClientBuilder {
pub fn on_any<F>(mut self, callback: F) -> Self
where
F: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync,
{
self.on_any = Some(Callback::<DynAsyncAnyCallback>::new_no_ack(callback));
self
}

/// Registers a Callback for all [`crate::event::Event::Custom`] and
/// [`crate::event::Event::Message`] that expect the client to ack.
///
/// # Example
/// ```rust
/// use rust_socketio::{asynchronous::ClientBuilder, Payload};
/// use futures_util::future::FutureExt;
///
/// #[tokio::main]
/// async fn main() {
/// let client = ClientBuilder::new("http://localhost:4200/")
/// .namespace("/admin")
/// .on_any(|event, payload, client, ack| {
/// async {
/// if let Payload::String(str) = payload {
/// println!("{}: {}", String::from(event), str);
/// }
/// client.ack(ack, "received").await;
/// }.boxed()
/// })
/// .connect()
/// .await;
/// }
/// ```
pub fn on_any_with_ack<F>(mut self, callback: F) -> Self
where
F: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
{
self.on_any = Some(Callback::<DynAsyncAnyCallback>::new(callback));
self
Expand Down
71 changes: 56 additions & 15 deletions socketio/src/asynchronous/client/callback.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
use futures_util::future::BoxFuture;
use futures_util::{future::BoxFuture, FutureExt};
use std::{
fmt::Debug,
future::Future,
ops::{Deref, DerefMut},
};

use crate::{Event, Payload};
use crate::{AckId, Event, Payload};

use super::client::{Client, ReconnectSettings};

/// Internal type, provides a way to store futures and return them in a boxed manner.
pub(crate) type DynAsyncCallback =
Box<dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync>;
pub(crate) type DynAsyncCallback = Box<
dyn for<'a> FnMut(Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
>;

pub(crate) type DynAsyncAnyCallback = Box<
dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Send + Sync,
dyn for<'a> FnMut(Event, Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Send
+ Sync,
>;

pub(crate) type DynAsyncReconnectSettingsCallback =
Expand All @@ -30,8 +38,10 @@ impl<T> Debug for Callback<T> {
}

impl Deref for Callback<DynAsyncCallback> {
type Target =
dyn for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send;
type Target = dyn for<'a> FnMut(Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send;

fn deref(&self) -> &Self::Target {
self.inner.as_ref()
Expand All @@ -45,19 +55,34 @@ impl DerefMut for Callback<DynAsyncCallback> {
}

impl Callback<DynAsyncCallback> {
pub(crate) fn new<T>(callback: T) -> Self
pub(crate) fn new<T>(mut callback: T) -> Self
where
T: for<'a> FnMut(Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send,
T: for<'a> FnMut(Payload, Client, AckId) -> BoxFuture<'static, ()> + 'static + Sync + Send,
{
Callback {
inner: Box::new(callback),
inner: Box::new(move |p, c, a| match a {
Some(a) => callback(p, c, a).boxed(),
None => std::future::ready(()).boxed(),
}),
}
}

pub(crate) fn new_no_ack<T, Fut>(mut callback: T) -> Self
where
T: FnMut(Payload, Client) -> Fut + Sync + Send + 'static,
Fut: Future<Output = ()> + 'static + Send,
{
Callback {
inner: Box::new(move |p, c, _a| callback(p, c).boxed()),
}
}
}

impl Deref for Callback<DynAsyncAnyCallback> {
type Target =
dyn for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send;
type Target = dyn for<'a> FnMut(Event, Payload, Client, Option<AckId>) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send;

fn deref(&self) -> &Self::Target {
self.inner.as_ref()
Expand All @@ -71,12 +96,28 @@ impl DerefMut for Callback<DynAsyncAnyCallback> {
}

impl Callback<DynAsyncAnyCallback> {
pub(crate) fn new<T>(callback: T) -> Self
pub(crate) fn new<T>(mut callback: T) -> Self
where
T: for<'a> FnMut(Event, Payload, Client) -> BoxFuture<'static, ()> + 'static + Sync + Send,
T: for<'a> FnMut(Event, Payload, Client, AckId) -> BoxFuture<'static, ()>
+ 'static
+ Sync
+ Send,
{
Callback {
inner: Box::new(callback),
inner: Box::new(move |e, p, c, a| match a {
Some(a) => callback(e, p, c, a).boxed(),
None => std::future::ready(()).boxed(),
}),
}
}

pub(crate) fn new_no_ack<T, Fut>(mut callback: T) -> Self
where
T: FnMut(Event, Payload, Client) -> Fut + Sync + Send + 'static,
Fut: Future<Output = ()> + 'static + Send,
{
Callback {
inner: Box::new(move |e, p, c, _a| callback(e, p, c).boxed()),
}
}
}
Expand Down
47 changes: 34 additions & 13 deletions socketio/src/asynchronous/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ impl Client {
id,
time_started: Instant::now(),
timeout,
callback: Callback::<DynAsyncCallback>::new(callback),
callback: Callback::<DynAsyncCallback>::new_no_ack(callback),
};

// add the ack to the tuple of outstanding acks
Expand All @@ -376,19 +376,33 @@ impl Client {
self.socket.read().await.send(socket_packet).await
}

async fn callback<P: Into<Payload>>(&self, event: &Event, payload: P) -> Result<()> {
pub async fn ack<D>(&self, ack_id: AckId, data: D) -> Result<()>
where
D: Into<Payload>,
{
let socket_packet = Packet::new_ack(data.into(), &self.nsp, ack_id);

self.socket.read().await.send(socket_packet).await
}

async fn callback<P: Into<Payload>>(
&self,
event: &Event,
payload: P,
ack_id: Option<AckId>,
) -> Result<()> {
let mut builder = self.builder.write().await;
let payload = payload.into();

if let Some(callback) = builder.on.get_mut(event) {
callback(payload.clone(), self.clone()).await;
callback(payload.clone(), self.clone(), ack_id).await;
}

// Call on_any for all common and custom events.
match event {
Event::Message | Event::Custom(_) => {
if let Some(callback) = builder.on_any.as_mut() {
callback(event.clone(), payload, self.clone()).await;
callback(event.clone(), payload, self.clone(), ack_id).await;
}
}
_ => (),
Expand All @@ -411,6 +425,7 @@ impl Client {
ack.callback.deref_mut()(
Payload::from(payload.to_owned()),
self.clone(),
None,
)
.await;
}
Expand All @@ -419,6 +434,7 @@ impl Client {
ack.callback.deref_mut()(
Payload::Binary(payload.to_owned()),
self.clone(),
None,
)
.await;
}
Expand Down Expand Up @@ -446,8 +462,12 @@ impl Client {

if let Some(attachments) = &packet.attachments {
if let Some(binary_payload) = attachments.get(0) {
self.callback(&event, Payload::Binary(binary_payload.to_owned()))
.await?;
self.callback(
&event,
Payload::Binary(binary_payload.to_owned()),
packet.id,
)
.await?;
}
}
Ok(())
Expand Down Expand Up @@ -480,7 +500,7 @@ impl Client {
};

// call the correct callback
self.callback(&event, payloads.to_vec()).await?;
self.callback(&event, payloads.to_vec(), packet.id).await?;
}

Ok(())
Expand All @@ -495,22 +515,22 @@ impl Client {
match packet.packet_type {
PacketId::Ack | PacketId::BinaryAck => {
if let Err(err) = self.handle_ack(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
return Err(err);
}
}
PacketId::BinaryEvent => {
if let Err(err) = self.handle_binary_event(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
}
}
PacketId::Connect => {
*(self.disconnect_reason.write().await) = DisconnectReason::default();
self.callback(&Event::Connect, "").await?;
self.callback(&Event::Connect, "", None).await?;
}
PacketId::Disconnect => {
*(self.disconnect_reason.write().await) = DisconnectReason::Server;
self.callback(&Event::Close, "").await?;
self.callback(&Event::Close, "", None).await?;
}
PacketId::ConnectError => {
self.callback(
Expand All @@ -520,12 +540,13 @@ impl Client {
.data
.as_ref()
.unwrap_or(&String::from("\"No error message provided\"")),
None,
)
.await?;
}
PacketId::Event => {
if let Err(err) = self.handle_event(packet).await {
self.callback(&Event::Error, err.to_string()).await?;
self.callback(&Event::Error, err.to_string(), None).await?;
}
}
}
Expand All @@ -547,7 +568,7 @@ impl Client {
None => None,
Some(Err(err)) => {
// call the error callback
match self.callback(&Event::Error, err.to_string()).await {
match self.callback(&Event::Error, err.to_string(), None).await {
Err(callback_err) => Some((Err(callback_err), socket)),
Ok(_) => Some((Err(err), socket)),
}
Expand Down
Loading

0 comments on commit b5935fb

Please sign in to comment.