diff --git a/data/Cargo.toml b/data/Cargo.toml index 5d89cb5..1e89017 100644 --- a/data/Cargo.toml +++ b/data/Cargo.toml @@ -10,12 +10,10 @@ license = "MIT/Apache-2.0" shared = { path = "../shared", package = "shared", default-features = false, features = ["marshal"] } sctp = { path = "../sctp", package = "sctp" } -tokio = { version = "1.19", features = ["full"] } bytes = "1" log = "0.4.16" thiserror = "1.0" [dev-dependencies] -tokio-test = "0.4.0" # must match the min version of the `tokio` crate above env_logger = "0.9.0" chrono = "0.4.23" diff --git a/data/src/data_channel/mod.rs b/data/src/data_channel/mod.rs index 1def160..683baf7 100644 --- a/data/src/data_channel/mod.rs +++ b/data/src/data_channel/mod.rs @@ -1,26 +1,12 @@ -#[cfg(test)] -mod data_channel_test; - -use std::borrow::Borrow; -use std::future::Future; -use std::net::Shutdown; -use std::pin::Pin; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::task::{Context, Poll}; -use std::{fmt, io}; - -use bytes::{Buf, Bytes}; -use sctp::association::Association; -use sctp::chunk::chunk_payload_data::PayloadProtocolIdentifier; -use sctp::stream::*; -use shared::marshal::*; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +//#[cfg(test)] +//mod data_channel_test; -use crate::message::message_channel_ack::*; -use crate::message::message_channel_open::*; -use crate::message::*; +use crate::message::{message_channel_ack::*, message_channel_open::*, *}; +use bytes::{Buf, BytesMut}; +use sctp::{PayloadProtocolIdentifier, ReliabilityType}; use shared::error::{Error, Result}; +use shared::marshal::*; +use std::collections::VecDeque; const RECEIVE_MTU: usize = 8192; @@ -35,71 +21,47 @@ pub struct Config { pub protocol: String, } +/// Transmit is used to data sent over over SCTP +#[derive(Debug, Default, Clone)] +pub struct Transmit { + pub association_handle: usize, + pub stream_id: u16, + pub ppi: PayloadProtocolIdentifier, + pub unordered: bool, + pub reliability_type: ReliabilityType, + pub payload: BytesMut, +} + /// DataChannel represents a data channel #[derive(Debug, Default, Clone)] pub struct DataChannel { - pub config: Config, - stream: Arc, + config: Config, + association_handle: usize, + stream_id: u16, + transmits: VecDeque, // stats - messages_sent: Arc, - messages_received: Arc, - bytes_sent: Arc, - bytes_received: Arc, + messages_sent: usize, + messages_received: usize, + bytes_sent: usize, + bytes_received: usize, } impl DataChannel { - pub fn new(stream: Arc, config: Config) -> Self { + fn new(config: Config, association_handle: usize, stream_id: u16) -> Self { Self { config, - stream, + association_handle, + stream_id, + transmits: VecDeque::new(), ..Default::default() } } /// Dial opens a data channels over SCTP - pub async fn dial( - association: &Arc, - identifier: u16, - config: Config, - ) -> Result { - let stream = association - .open_stream(identifier, PayloadProtocolIdentifier::Binary) - .await?; - - Self::client(stream, config).await - } + pub fn dial(config: Config, association_handle: usize, stream_id: u16) -> Result { + let mut data_channel = DataChannel::new(config.clone(), association_handle, stream_id); - /// Accept is used to accept incoming data channels over SCTP - pub async fn accept( - association: &Arc, - config: Config, - existing_channels: &[T], - ) -> Result - where - T: Borrow, - { - let stream = association - .accept_stream() - .await - .ok_or(Error::ErrStreamClosed)?; - - for channel in existing_channels.iter().map(|ch| ch.borrow()) { - if channel.stream_identifier() == stream.stream_identifier() { - let ch = channel.to_owned(); - ch.stream - .set_default_payload_type(PayloadProtocolIdentifier::Binary); - return Ok(ch); - } - } - - stream.set_default_payload_type(PayloadProtocolIdentifier::Binary); - - Self::server(stream, config).await - } - - /// Client opens a data channel over an SCTP stream - pub async fn client(stream: Arc, config: Config) -> Result { if !config.negotiated { let msg = Message::DataChannelOpen(DataChannelOpen { channel_type: config.channel_type, @@ -110,24 +72,34 @@ impl DataChannel { }) .marshal()?; - stream - .write_sctp(&msg, PayloadProtocolIdentifier::Dcep) - .await?; - } - Ok(DataChannel::new(stream, config)) - } + let (unordered, reliability_type) = data_channel.get_reliability_params(); - /// Server accepts a data channel over an SCTP stream - pub async fn server(stream: Arc, mut config: Config) -> Result { - let mut buf = vec![0u8; RECEIVE_MTU]; + data_channel.transmits.push_back(Transmit { + association_handle, + stream_id, + ppi: PayloadProtocolIdentifier::Dcep, + unordered, + reliability_type, + payload: msg, + }); + } - let (n, ppi) = stream.read_sctp(&mut buf).await?; + Ok(data_channel) + } + /// Accept is used to accept incoming data channels over SCTP + pub fn accept( + mut config: Config, + association_handle: usize, + stream_id: u16, + ppi: PayloadProtocolIdentifier, + buf: &[u8], + ) -> Result { if ppi != PayloadProtocolIdentifier::Dcep { return Err(Error::InvalidPayloadProtocolIdentifier(ppi as u8)); } - let mut read_buf = &buf[..n]; + let mut read_buf = buf; let msg = Message::unmarshal(&mut read_buf)?; if let Message::DataChannelOpen(dco) = msg { @@ -140,99 +112,87 @@ impl DataChannel { return Err(Error::InvalidMessageType(msg.message_type() as u8)); }; - let data_channel = DataChannel::new(stream, config); + let mut data_channel = DataChannel::new(config, association_handle, stream_id); - data_channel.write_data_channel_ack().await?; - data_channel.commit_reliability_params(); + data_channel.write_data_channel_ack()?; Ok(data_channel) } /// Read reads a packet of len(p) bytes as binary data. - /// - /// See [`sctp::stream::Stream::read_sctp`]. - pub async fn read(&self, buf: &mut [u8]) -> Result { - self.read_data_channel(buf).await.map(|(n, _)| n) + pub fn read(&mut self, ppi: PayloadProtocolIdentifier, buf: &[u8]) -> Result { + self.read_data_channel(ppi, buf).map(|(b, _)| b) } /// ReadDataChannel reads a packet of len(p) bytes. It returns the number of bytes read and /// `true` if the data read is a string. - /// - /// See [`sctp::stream::Stream::read_sctp`]. - pub async fn read_data_channel(&self, buf: &mut [u8]) -> Result<(usize, bool)> { - loop { - //TODO: add handling of cancel read_data_channel - let (mut n, ppi) = match self.stream.read_sctp(buf).await { - Ok((0, PayloadProtocolIdentifier::Unknown)) => { - // The incoming stream was reset or the reading half was shutdown - return Ok((0, false)); - } - Ok((n, ppi)) => (n, ppi), - Err(err) => { - // Shutdown the stream and send the reset request to the remote. - self.close().await?; - return Err(err.into()); - } - }; - - let mut is_string = false; - match ppi { - PayloadProtocolIdentifier::Dcep => { - let mut data = &buf[..n]; - match self.handle_dcep(&mut data).await { - Ok(()) => {} - Err(err) => { - log::error!("Failed to handle DCEP: {:?}", err); - } + pub fn read_data_channel( + &mut self, + ppi: PayloadProtocolIdentifier, + buf: &[u8], + ) -> Result<(BytesMut, bool)> { + let mut is_string = false; + match ppi { + PayloadProtocolIdentifier::Dcep => { + let mut data_buf = buf; + match self.handle_dcep(&mut data_buf) { + Ok(()) => {} + Err(err) => { + log::error!("Failed to handle DCEP: {:?}", err); + return Err(err); } - continue; - } - PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => { - is_string = true; } - _ => {} - }; + } + PayloadProtocolIdentifier::String | PayloadProtocolIdentifier::StringEmpty => { + is_string = true; + } + _ => {} + }; - match ppi { - PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => { - n = 0; - } - _ => {} - }; + let data = match ppi { + PayloadProtocolIdentifier::StringEmpty | PayloadProtocolIdentifier::BinaryEmpty => { + BytesMut::new() + } + _ => BytesMut::from(buf), + }; - self.messages_received.fetch_add(1, Ordering::SeqCst); - self.bytes_received.fetch_add(n, Ordering::SeqCst); + self.messages_received += 1; + self.bytes_received += 1; - return Ok((n, is_string)); - } + Ok((data, is_string)) } /// MessagesSent returns the number of messages sent pub fn messages_sent(&self) -> usize { - self.messages_sent.load(Ordering::SeqCst) + self.messages_sent } /// MessagesReceived returns the number of messages received pub fn messages_received(&self) -> usize { - self.messages_received.load(Ordering::SeqCst) + self.messages_received } /// BytesSent returns the number of bytes sent pub fn bytes_sent(&self) -> usize { - self.bytes_sent.load(Ordering::SeqCst) + self.bytes_sent } /// BytesReceived returns the number of bytes received pub fn bytes_received(&self) -> usize { - self.bytes_received.load(Ordering::SeqCst) + self.bytes_received + } + + /// association_handle returns the association handle + pub fn association_handle(&self) -> usize { + self.association_handle } /// StreamIdentifier returns the Stream identifier associated to the stream. pub fn stream_identifier(&self) -> u16 { - self.stream.stream_identifier() + self.stream_id } - async fn handle_dcep(&self, data: &mut B) -> Result<()> + fn handle_dcep(&mut self, data: &mut B) -> Result<()> where B: Buf, { @@ -243,11 +203,11 @@ impl DataChannel { // Note: DATA_CHANNEL_OPEN message is handled inside Server() method. // Therefore, the message will not reach here. log::debug!("Received DATA_CHANNEL_OPEN"); - let _ = self.write_data_channel_ack().await?; + self.write_data_channel_ack()?; } Message::DataChannelAck(_) => { log::debug!("Received DATA_CHANNEL_ACK"); - self.commit_reliability_params(); + //self.commit_reliability_params(); } }; @@ -255,12 +215,12 @@ impl DataChannel { } /// Write writes len(p) bytes from p as binary data - pub async fn write(&self, data: &Bytes) -> Result { - self.write_data_channel(data, false).await + pub fn write(&mut self, data: &[u8]) -> Result { + self.write_data_channel(data, false) } /// WriteDataChannel writes len(p) bytes from p - pub async fn write_data_channel(&self, data: &Bytes, is_string: bool) -> Result { + pub fn write_data_channel(&mut self, data: &[u8], is_string: bool) -> Result { let data_len = data.len(); // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-12#section-6.6 @@ -277,30 +237,51 @@ impl DataChannel { (true, _) => PayloadProtocolIdentifier::String, }; + let (unordered, reliability_type) = self.get_reliability_params(); + let n = if data_len == 0 { - let _ = self - .stream - .write_sctp(&Bytes::from_static(&[0]), ppi) - .await?; + self.transmits.push_back(Transmit { + association_handle: self.association_handle, + stream_id: self.stream_id, + ppi, + unordered, + reliability_type, + payload: BytesMut::from(&[0][..]), + }); + 0 } else { - let n = self.stream.write_sctp(data, ppi).await?; - self.bytes_sent.fetch_add(n, Ordering::SeqCst); - n + self.transmits.push_back(Transmit { + association_handle: self.association_handle, + stream_id: self.stream_id, + ppi, + unordered, + reliability_type, + payload: BytesMut::from(data), + }); + + self.bytes_sent += data.len(); + data.len() }; - self.messages_sent.fetch_add(1, Ordering::SeqCst); + self.messages_sent += 1; Ok(n) } - async fn write_data_channel_ack(&self) -> Result { + fn write_data_channel_ack(&mut self) -> Result<()> { let ack = Message::DataChannelAck(DataChannelAck {}).marshal()?; - Ok(self - .stream - .write_sctp(&ack, PayloadProtocolIdentifier::Dcep) - .await?) + let (unordered, reliability_type) = self.get_reliability_params(); + self.transmits.push_back(Transmit { + association_handle: self.association_handle, + stream_id: self.stream_id, + ppi: PayloadProtocolIdentifier::Dcep, + unordered, + reliability_type, + payload: ack, + }); + Ok(()) } - + /* /// Close closes the DataChannel and the underlying SCTP stream. pub async fn close(&self) -> Result<()> { // https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.7 @@ -339,9 +320,9 @@ impl DataChannel { /// number of bytes of outgoing data buffered is lower than the threshold. pub fn on_buffered_amount_low(&self, f: OnBufferedAmountLowFn) { self.stream.on_buffered_amount_low(f) - } + }*/ - fn commit_reliability_params(&self) { + fn get_reliability_params(&self) -> (bool, ReliabilityType) { let (unordered, reliability_type) = match self.config.channel_type { ChannelType::Reliable => (false, ReliabilityType::Reliable), ChannelType::ReliableUnordered => (true, ReliabilityType::Reliable), @@ -351,331 +332,6 @@ impl DataChannel { ChannelType::PartialReliableTimedUnordered => (true, ReliabilityType::Timed), }; - self.stream.set_reliability_params( - unordered, - reliability_type, - self.config.reliability_parameter, - ); - } -} - -/// Default capacity of the temporary read buffer used by [`PollStream`]. -const DEFAULT_READ_BUF_SIZE: usize = 8192; - -/// State of the read `Future` in [`PollStream`]. -enum ReadFut { - /// Nothing in progress. - Idle, - /// Reading data from the underlying stream. - Reading(Pin>> + Send>>), - /// Finished reading, but there's unread data in the temporary buffer. - RemainingData(Vec), -} - -impl ReadFut { - /// Gets a mutable reference to the future stored inside `Reading(future)`. - /// - /// # Panics - /// - /// Panics if `ReadFut` variant is not `Reading`. - fn get_reading_mut(&mut self) -> &mut Pin>> + Send>> { - match self { - ReadFut::Reading(ref mut fut) => fut, - _ => panic!("expected ReadFut to be Reading"), - } - } -} - -/// A wrapper around around [`DataChannel`], which implements [`AsyncRead`] and -/// [`AsyncWrite`]. -/// -/// Both `poll_read` and `poll_write` calls allocate temporary buffers, which results in an -/// additional overhead. -pub struct PollDataChannel { - data_channel: Arc, - - read_fut: ReadFut, - write_fut: Option> + Send>>>, - shutdown_fut: Option> + Send>>>, - - read_buf_cap: usize, -} - -impl PollDataChannel { - /// Constructs a new `PollDataChannel`. - /// - /// # Examples - /// - /// ``` - /// use data::data_channel::{DataChannel, PollDataChannel, Config}; - /// use sctp::stream::Stream; - /// use std::sync::Arc; - /// - /// let dc = Arc::new(DataChannel::new(Arc::new(Stream::default()), Config::default())); - /// let poll_dc = PollDataChannel::new(dc); - /// ``` - pub fn new(data_channel: Arc) -> Self { - Self { - data_channel, - read_fut: ReadFut::Idle, - write_fut: None, - shutdown_fut: None, - read_buf_cap: DEFAULT_READ_BUF_SIZE, - } - } - - /// Get back the inner data_channel. - pub fn into_inner(self) -> Arc { - self.data_channel - } - - /// Obtain a clone of the inner data_channel. - pub fn clone_inner(&self) -> Arc { - self.data_channel.clone() - } - - /// MessagesSent returns the number of messages sent - pub fn messages_sent(&self) -> usize { - self.data_channel.messages_sent() - } - - /// MessagesReceived returns the number of messages received - pub fn messages_received(&self) -> usize { - self.data_channel.messages_received() - } - - /// BytesSent returns the number of bytes sent - pub fn bytes_sent(&self) -> usize { - self.data_channel.bytes_sent() - } - - /// BytesReceived returns the number of bytes received - pub fn bytes_received(&self) -> usize { - self.data_channel.bytes_received() - } - - /// StreamIdentifier returns the Stream identifier associated to the stream. - pub fn stream_identifier(&self) -> u16 { - self.data_channel.stream_identifier() - } - - /// BufferedAmount returns the number of bytes of data currently queued to be - /// sent over this stream. - pub fn buffered_amount(&self) -> usize { - self.data_channel.buffered_amount() - } - - /// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing - /// data that is considered "low." Defaults to 0. - pub fn buffered_amount_low_threshold(&self) -> usize { - self.data_channel.buffered_amount_low_threshold() - } - - /// Set the capacity of the temporary read buffer (default: 8192). - pub fn set_read_buf_capacity(&mut self, capacity: usize) { - self.read_buf_cap = capacity - } -} - -impl AsyncRead for PollDataChannel { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - if buf.remaining() == 0 { - return Poll::Ready(Ok(())); - } - - let fut = match self.read_fut { - ReadFut::Idle => { - // read into a temporary buffer because `buf` has an unonymous lifetime, which can - // be shorter than the lifetime of `read_fut`. - let data_channel = self.data_channel.clone(); - let mut temp_buf = vec![0; self.read_buf_cap]; - self.read_fut = ReadFut::Reading(Box::pin(async move { - data_channel.read(temp_buf.as_mut_slice()).await.map(|n| { - temp_buf.truncate(n); - temp_buf - }) - })); - self.read_fut.get_reading_mut() - } - ReadFut::Reading(ref mut fut) => fut, - ReadFut::RemainingData(ref mut data) => { - let remaining = buf.remaining(); - let len = std::cmp::min(data.len(), remaining); - buf.put_slice(&data[..len]); - if data.len() > remaining { - // ReadFut remains to be RemainingData - data.drain(..len); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(())); - } - }; - - loop { - match fut.as_mut().poll(cx) { - Poll::Pending => return Poll::Pending, - // retry immediately upon empty data or incomplete chunks - // since there's no way to setup a waker. - Poll::Ready(Err(Error::Sctp(sctp::Error::ErrTryAgain))) => {} - // EOF has been reached => don't touch buf and just return Ok - Poll::Ready(Err(Error::Sctp(sctp::Error::ErrEof))) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Ok(())); - } - Poll::Ready(Err(e)) => { - self.read_fut = ReadFut::Idle; - return Poll::Ready(Err(e.into())); - } - Poll::Ready(Ok(mut temp_buf)) => { - let remaining = buf.remaining(); - let len = std::cmp::min(temp_buf.len(), remaining); - buf.put_slice(&temp_buf[..len]); - if temp_buf.len() > remaining { - temp_buf.drain(..len); - self.read_fut = ReadFut::RemainingData(temp_buf); - } else { - self.read_fut = ReadFut::Idle; - } - return Poll::Ready(Ok(())); - } - } - } - } -} - -impl AsyncWrite for PollDataChannel { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - if buf.is_empty() { - return Poll::Ready(Ok(0)); - } - - if let Some(fut) = self.write_fut.as_mut() { - match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - let data_channel = self.data_channel.clone(); - let bytes = Bytes::copy_from_slice(buf); - self.write_fut = - Some(Box::pin(async move { data_channel.write(&bytes).await })); - Poll::Ready(Err(e.into())) - } - // Given the data is buffered, it's okay to ignore the number of written bytes. - // - // TODO: In the long term, `data_channel.write` should be made sync. Then we could - // remove the whole `if` condition and just call `data_channel.write`. - Poll::Ready(Ok(_)) => { - let data_channel = self.data_channel.clone(); - let bytes = Bytes::copy_from_slice(buf); - self.write_fut = - Some(Box::pin(async move { data_channel.write(&bytes).await })); - Poll::Ready(Ok(buf.len())) - } - } - } else { - let data_channel = self.data_channel.clone(); - let bytes = Bytes::copy_from_slice(buf); - let fut = self - .write_fut - .insert(Box::pin(async move { data_channel.write(&bytes).await })); - - match fut.as_mut().poll(cx) { - // If it's the first time we're polling the future, `Poll::Pending` can't be - // returned because that would mean the `PollDataChannel` is not ready for writing. - // And this is not true since we've just created a future, which is going to write - // the buf to the underlying stream. - // - // It's okay to return `Poll::Ready` if the data is buffered (this is what the - // buffered writer and `File` do). - Poll::Pending => Poll::Ready(Ok(buf.len())), - Poll::Ready(Err(e)) => { - self.write_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(n)) => { - self.write_fut = None; - Poll::Ready(Ok(n)) - } - } - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.write_fut.as_mut() { - Some(fut) => match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - self.write_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(_)) => { - self.write_fut = None; - Poll::Ready(Ok(())) - } - }, - None => Poll::Ready(Ok(())), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.as_mut().poll_flush(cx) { - Poll::Pending => return Poll::Pending, - Poll::Ready(_) => {} - } - - let fut = match self.shutdown_fut.as_mut() { - Some(fut) => fut, - None => { - let data_channel = self.data_channel.clone(); - self.shutdown_fut.get_or_insert(Box::pin(async move { - data_channel - .stream - .shutdown(Shutdown::Write) - .await - .map_err(Error::Sctp) - })) - } - }; - - match fut.as_mut().poll(cx) { - Poll::Pending => Poll::Pending, - Poll::Ready(Err(e)) => { - self.shutdown_fut = None; - Poll::Ready(Err(e.into())) - } - Poll::Ready(Ok(_)) => { - self.shutdown_fut = None; - Poll::Ready(Ok(())) - } - } - } -} - -impl Clone for PollDataChannel { - fn clone(&self) -> PollDataChannel { - PollDataChannel::new(self.clone_inner()) - } -} - -impl fmt::Debug for PollDataChannel { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("PollDataChannel") - .field("data_channel", &self.data_channel) - .field("read_buf_cap", &self.read_buf_cap) - .finish() - } -} - -impl AsRef for PollDataChannel { - fn as_ref(&self) -> &DataChannel { - &self.data_channel + (unordered, reliability_type) } } diff --git a/data/src/lib.rs b/data/src/lib.rs index 71d76b4..f4b2e47 100644 --- a/data/src/lib.rs +++ b/data/src/lib.rs @@ -1,5 +1,5 @@ #![warn(rust_2018_idioms)] #![allow(dead_code)] -//pub mod data_channel; +pub mod data_channel; pub mod message; diff --git a/shared/src/marshal/mod.rs b/shared/src/marshal/mod.rs index 05034e6..b5a6267 100644 --- a/shared/src/marshal/mod.rs +++ b/shared/src/marshal/mod.rs @@ -1,4 +1,4 @@ -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Buf, BytesMut}; use crate::error::{Error, Result}; @@ -9,7 +9,7 @@ pub trait MarshalSize { pub trait Marshal: MarshalSize { fn marshal_to(&self, buf: &mut [u8]) -> Result; - fn marshal(&self) -> Result { + fn marshal(&self) -> Result { let l = self.marshal_size(); let mut buf = BytesMut::with_capacity(l); buf.resize(l, 0); @@ -19,7 +19,7 @@ pub trait Marshal: MarshalSize { "marshal_to output size {n}, but expect {l}" ))) } else { - Ok(buf.freeze()) + Ok(buf) } } }