From 7b02feec0951e39d73c162afe971e4442ddf6916 Mon Sep 17 00:00:00 2001 From: chanderlud Date: Wed, 6 Dec 2023 20:03:04 -0800 Subject: [PATCH] folder support stabilized --- Cargo.toml | 7 + build.rs | 6 + scratch.txt | 57 -------- src/items.proto | 46 +++++++ src/main.rs | 173 ++++++++++++++++++------ src/receiver/mod.rs | 226 +++++++++++++++++++------------ src/receiver/writer.rs | 82 +++++++++--- src/sender/mod.rs | 297 +++++++++++++++++++++++++++++++---------- src/sender/reader.rs | 11 +- 9 files changed, 627 insertions(+), 278 deletions(-) create mode 100644 build.rs delete mode 100644 scratch.txt create mode 100644 src/items.proto diff --git a/Cargo.toml b/Cargo.toml index 6678401..5bfe9e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ name = "cccp" version = "0.1.0" edition = "2021" +build = "build.rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -17,6 +18,12 @@ regex = "1.10" dirs = "5.0" rpassword = "7.3" indicatif = "0.17" +prost = "0.12" +prost-build = "0.12" +async-channel = "2.1" + +[build-dependencies] +prost-build = "0.12.3" [profile.release] opt-level = 3 diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..02fce9d --- /dev/null +++ b/build.rs @@ -0,0 +1,6 @@ +use std::io::Result; + +fn main() -> Result<()> { + prost_build::compile_protos(&["src/items.proto"], &["src/"])?; + Ok(()) +} diff --git a/scratch.txt b/scratch.txt deleted file mode 100644 index e9d0aa8..0000000 --- a/scratch.txt +++ /dev/null @@ -1,57 +0,0 @@ -how to get multi file support workin - -1. sender needs to be able to decide which files to send - - /a/path/to/a/directory -> every file and sub directory and sub file... - - /a/path/to/some/files/* -> every file in the directory - -2. sender needs to assign each file an id + tell the receiver what directories need to be created - - probably gotta bite the bullet and use a protobuf here - - directories should probably be constructed by the receiver? - - the receiver needs to have a map of file ids to destination file paths - -struct Message { - inner: MessageType -} - -enum MessageType { - Files { - directories: Vec<&str>, - files: HashMap // file path, file size - }, - Confirmations { - indexes: HashMap>, - }, - StartIndex { - id: u32, - index: u64, - }, - Start { - id: u32, - }, - End { - id: u32, // this should be sent by the receiver when the writer exits so the sender can clean up - } -} - -3. the reader needs some changes - - i guess multiple readers will be needed. the id's will need to be added to the packets too - - one queue can be used - - each reader will need its own read semaphore, the control server will need to keep track of them all (cry) - -4. the receiver needs to spawn writers and created queues for each id it receives + create the directories - - should be pretty easy - -5. the control stream is gonna be a bit more complicated now - - should probably communicate in protobufs - - needs to control the sender and receiver exiting. if the stream breaks it can exit both ends - - we really don't need to care about the reader or writer's results/completion unless they are Err - - -things to think about - - probably need a system to limit the number of concurrent files - this will include sending the ids to the receiver as they are started - so writers and queues are only created when needed - readers and send queues should be considered too ig - - i guess every time a new file gets started there needs to be a little handshake like - sender --[Start]--> receiver --[StartIndex]--> sender - and then it can start working \ No newline at end of file diff --git a/src/items.proto b/src/items.proto new file mode 100644 index 0000000..4429d03 --- /dev/null +++ b/src/items.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package cccp.items; + +message Message { + oneof message { + Files files = 1; + Confirmations confirmations = 2; + StartIndex start_index = 3; + Start start = 4; + End end = 5; + Done done = 6; + } +} + +message Files { + repeated string directories = 1; + map files = 2; // map for file details +} + +message FileDetail { + string file_path = 1; + uint64 file_size = 2; +} + +message Confirmations { + map indexes = 1; +} + +message ConfirmationIndexes { + repeated uint64 inner = 1; +} + +message StartIndex { + uint64 index = 1; +} + +message Start { + uint32 id = 1; +} + +message End { + uint32 id = 1; +} + +message Done {} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 62b1f68..d991282 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ #![feature(int_roundings)] +use async_channel::Receiver; use std::error; use std::fmt::{Display, Formatter}; use std::net::{IpAddr, SocketAddr}; @@ -17,12 +18,13 @@ use futures::stream::iter; use futures::{StreamExt, TryStreamExt}; use indicatif::{ProgressBar, ProgressStyle}; use log::{debug, error, info, warn, LevelFilter}; +use prost::Message; use regex::Regex; use rpassword::prompt_password; use simple_logging::{log_to_file, log_to_stderr}; use tokio::fs::File; -use tokio::io::AsyncReadExt; -use tokio::net::{TcpListener, TcpSocket, UdpSocket}; +use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; use tokio::time::{interval, sleep}; use tokio::{io, select}; @@ -37,11 +39,17 @@ const READ_BUFFER_SIZE: usize = 10_000_000; const WRITE_BUFFER_SIZE: usize = 5_000_000; const TRANSFER_BUFFER_SIZE: usize = 1024; const INDEX_SIZE: usize = std::mem::size_of::(); +const ID_SIZE: usize = std::mem::size_of::(); const MAX_RETRIES: usize = 10; const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5); // how long to wait for a job to be confirmed before requeuing it const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000); +const MAX_CONCURRENT_TRANSFERS: usize = 100; + +pub mod items { + include!(concat!(env!("OUT_DIR"), "/cccp.items.rs")); +} #[derive(Debug)] struct Error { @@ -50,14 +58,16 @@ struct Error { #[derive(Debug)] enum ErrorKind { - IoError(io::Error), - ParseError(std::net::AddrParseError), + Io(io::Error), + Parse(std::net::AddrParseError), + Decode(prost::DecodeError), + Join(tokio::task::JoinError), } impl From for Error { fn from(error: io::Error) -> Self { Self { - kind: ErrorKind::IoError(error), + kind: ErrorKind::Io(error), } } } @@ -65,7 +75,23 @@ impl From for Error { impl From for Error { fn from(error: std::net::AddrParseError) -> Self { Self { - kind: ErrorKind::ParseError(error), + kind: ErrorKind::Parse(error), + } + } +} + +impl From for Error { + fn from(error: prost::DecodeError) -> Self { + Self { + kind: ErrorKind::Decode(error), + } + } +} + +impl From for Error { + fn from(error: tokio::task::JoinError) -> Self { + Self { + kind: ErrorKind::Join(error), } } } @@ -73,11 +99,13 @@ impl From for Error { impl Termination for Error { fn report(self) -> ExitCode { ExitCode::from(match self.kind { - ErrorKind::IoError(error) => match error.kind() { + ErrorKind::Io(error) => match error.kind() { io::ErrorKind::NotFound => 1, _ => 2, }, - ErrorKind::ParseError(_) => 3, + ErrorKind::Parse(_) => 3, + ErrorKind::Decode(_) => 4, + ErrorKind::Join(_) => 5, }) } } @@ -257,15 +285,6 @@ impl FileLocation { fn is_local(&self) -> bool { self.host.is_none() || (self.host.is_some() && self.file_path.exists()) } - - fn is_dir(&self) -> bool { - self.file_path.is_dir() - } - - async fn file_size(&self) -> io::Result { - let metadata = tokio::fs::metadata(&self.file_path).await?; - Ok(metadata.len()) - } } #[derive(Clone, Debug)] @@ -328,8 +347,8 @@ async fn main() -> Result<()> { port_count, options.threads, port_count ); options.threads = port_count; - } else if port_count < 2 { - panic!("a minimum of two ports are required") + } else if port_count < 3 { + panic!("a minimum of three ports are required") } else if port_count > options.threads { warn!( "{} ports > {} threads. changing port range to {}-{}", @@ -417,25 +436,12 @@ async fn main() -> Result<()> { client.execute(&command).await }); - let bind = match options.bind_address { - Some(addr) => SocketAddr::new(addr, 0), - None => "0.0.0.0:0".parse().unwrap(), - }; - - // connect to the remote client on the first port in the range - let control_stream = loop { - let socket = TcpSocket::new_v4().unwrap(); - socket.bind(bind).unwrap(); - - let remote_socket = SocketAddr::new(remote_addr, options.start_port); - - if let Ok(stream) = socket.connect(remote_socket).await { - break stream; - } else { - // give the receiver time to start listening - sleep(Duration::from_millis(100)).await; - } - }; + // receiver -> sender stream + let rts_stream = + connect_stream(remote_addr, options.start_port, options.bind_address).await?; + // sender -> receiver stream + let str_stream = + connect_stream(remote_addr, options.start_port + 1, options.bind_address).await?; let display_handle = tokio::spawn({ let stats = stats.clone(); @@ -445,9 +451,9 @@ async fn main() -> Result<()> { let main_future = async { if sender { - sender::main(options, stats, control_stream, remote_addr).await + sender::main(options, stats, rts_stream, str_stream, remote_addr).await } else { - receiver::main(options, stats, control_stream, remote_addr).await + receiver::main(options, stats, rts_stream, str_stream, remote_addr).await } }; @@ -465,6 +471,8 @@ async fn main() -> Result<()> { 1 => error!("remote client failed, file not found"), 2 => error!("remote client failed, unknown IO error"), 3 => error!("remote client failed, parse error"), + 4 => error!("remote client failed, decode error"), + 5 => error!("remote client failed, join error"), _ => error!("remote client failed, unknown error"), } } @@ -480,13 +488,20 @@ async fn main() -> Result<()> { } } Mode::Remote(sender) => { + // receiver -> sender stream let listener = TcpListener::bind(("0.0.0.0", options.start_port)).await?; - let (control_stream, remote_addr) = listener.accept().await?; + let (rts_stream, remote_addr) = listener.accept().await?; + + // sender -> receiver stream + let listener = TcpListener::bind(("0.0.0.0", options.start_port + 1)).await?; + let (str_stream, _) = listener.accept().await?; + + let remote_addr = remote_addr.ip(); if sender { - sender::main(options, stats, control_stream, remote_addr.ip()).await?; + sender::main(options, stats, rts_stream, str_stream, remote_addr).await?; } else { - receiver::main(options, stats, control_stream, remote_addr.ip()).await?; + receiver::main(options, stats, rts_stream, str_stream, remote_addr).await?; }; } } @@ -495,7 +510,7 @@ async fn main() -> Result<()> { Ok(()) } -// opens the sockets that will be used to send data +/// opens the sockets that will be used to send data async fn socket_factory( start: u16, end: u16, @@ -520,6 +535,7 @@ async fn socket_factory( .await } +/// try to get an ssh key for authentication async fn ssh_key_auth() -> io::Result { // get the home directory of the current user let home_dir = dirs::home_dir().ok_or(io::Error::new( @@ -540,11 +556,13 @@ async fn ssh_key_auth() -> io::Result { Ok(AuthMethod::with_key(&key, None)) } +/// prompt the user for a password fn password_auth() -> io::Result { let password = prompt_password("password: ")?; Ok(AuthMethod::with_password(&password)) } +/// print a progress bar to stdout async fn print_progress(stats: TransferStats) { let bar = ProgressBar::new(100); let mut interval = interval(Duration::from_secs(1)); @@ -563,3 +581,70 @@ async fn print_progress(stats: TransferStats) { bar.set_position((progress * 100_f64) as u64); } } + +/// connect to a remote client on a given port +async fn connect_stream( + remote_addr: IpAddr, + port: u16, + bind_addr: Option, +) -> Result { + let bind = match bind_addr { + Some(addr) => SocketAddr::new(addr, 0), + None => "0.0.0.0:0".parse()?, + }; + + // connect to the remote client + loop { + let socket = TcpSocket::new_v4()?; + socket.bind(bind)?; + + let remote_socket = SocketAddr::new(remote_addr, port); + + if let Ok(stream) = socket.connect(remote_socket).await { + break Ok(stream); + } else { + // give the receiver time to start listening + sleep(Duration::from_millis(100)).await; + } + } +} + +/// write a `Message` to a writer +async fn write_message( + writer: &mut W, + message: &M, +) -> Result<()> { + let len = message.encoded_len(); + writer.write_u32(len as u32).await?; + + let mut buffer = Vec::with_capacity(len); + message.encode(&mut buffer).unwrap(); + + writer.write_all(&buffer).await?; + + Ok(()) +} + +/// read a `Message` from a reader +async fn read_message(reader: &mut R) -> Result { + let len = reader.read_u32().await? as usize; + + let mut buffer = vec![0; len]; + reader.read_exact(&mut buffer).await?; + + let message = M::decode(&buffer[..])?; + + Ok(message) +} + +/// send messages from a channel to a writer +async fn message_sender( + mut writer: W, + receiver: Receiver, +) -> Result<()> { + while let Ok(message) = receiver.recv().await { + write_message(&mut writer, &message).await?; + } + + Ok(()) +} diff --git a/src/receiver/mod.rs b/src/receiver/mod.rs index 2be7476..5611289 100644 --- a/src/receiver/mod.rs +++ b/src/receiver/mod.rs @@ -1,29 +1,32 @@ +use async_channel::Sender; use std::collections::HashMap; use std::mem; use std::net::IpAddr; +use std::path::PathBuf; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; use log::{debug, error, info, warn}; -use tokio::fs::metadata; -use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::fs::{create_dir_all, metadata}; use tokio::net::{TcpStream, UdpSocket}; +use tokio::select; use tokio::sync::Mutex; use tokio::task::JoinHandle; use tokio::time::{interval, timeout}; -use tokio::{io, select}; +use crate::items::{message, ConfirmationIndexes, Confirmations, Files, Message, StartIndex}; use crate::receiver::writer::{writer, SplitQueue}; use crate::{ - socket_factory, Options, Result, TransferStats, UnlimitedQueue, INDEX_SIZE, MAX_RETRIES, - RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, + read_message, socket_factory, write_message, Options, Result, TransferStats, UnlimitedQueue, + ID_SIZE, INDEX_SIZE, MAX_RETRIES, RECEIVE_TIMEOUT, TRANSFER_BUFFER_SIZE, }; mod writer; type WriterQueue = Arc; +type ConfirmationQueue = UnlimitedQueue<(u32, u64)>; #[derive(Clone)] struct Job { @@ -32,49 +35,33 @@ struct Job { } pub(crate) async fn main( - mut options: Options, + options: Options, stats: TransferStats, - mut control_stream: TcpStream, + rts_stream: TcpStream, + mut str_stream: TcpStream, remote_addr: IpAddr, ) -> Result<()> { - // TODO what if the source is a directory? - if options.destination.file_path.is_dir() { - info!("destination is a folder, reformatting path with target file"); - - options - .destination - .file_path - .push(options.source.file_path.iter().last().unwrap()) - } - info!("receiving {} -> {}", options.source, options.destination); - let partial_path = options.destination.file_path.with_extension("partial"); - - let start_index = if partial_path.exists() { - info!("partial file exists, resuming transfer"); - let metadata = metadata(&partial_path).await?; - // the file is written sequentially, so we can calculate the start index by rounding down to the nearest multiple of the transfer buffer size - metadata.len().div_floor(TRANSFER_BUFFER_SIZE as u64) * TRANSFER_BUFFER_SIZE as u64 - } else { - 0 - }; - - stats - .confirmed_data - .fetch_add(start_index as usize, Relaxed); + let files: Files = read_message(&mut str_stream).await?; + debug!("received files: {:?}", files); - // receive the file size from the remote client - let file_size = control_stream.read_u64().await?; - stats.total_data.store(file_size as usize, Relaxed); - debug!("received file size: {}", file_size); + // create the local directories needed to write the files + for dir in &files.directories { + let local_dir = options.destination.file_path.join(dir); + debug!("creating directory {:?}", local_dir); + create_dir_all(local_dir).await?; + } - // send the start index to the remote client - debug!("sending start index {}", start_index); - control_stream.write_u64(start_index).await?; + // set the total data to be received + for details in files.files.values() { + stats + .total_data + .fetch_add(details.file_size as usize, Relaxed); + } let sockets = socket_factory( - options.start_port + 1, // the first port is used for control messages + options.start_port + 2, // the first two ports are used for control messages and confirmations options.end_port, remote_addr, options.threads, @@ -84,22 +71,26 @@ pub(crate) async fn main( info!("opened sockets"); let writer_queue: WriterQueue = Default::default(); - let confirmation_queue: UnlimitedQueue<(u32, u64)> = Default::default(); + let confirmation_queue: ConfirmationQueue = Default::default(); - let writer_handle = tokio::spawn(writer( - options.destination.file_path, - partial_path.clone(), - writer_queue.clone(), - file_size, + // `message_sender` can now be used to send messages to the sender + let (message_sender, message_receiver) = async_channel::unbounded(); + tokio::spawn(crate::message_sender(rts_stream, message_receiver)); + + let confirmation_handle = tokio::spawn(send_confirmations( + message_sender.clone(), confirmation_queue.clone(), - start_index, - 0, + stats.confirmed_data.clone(), )); - let confirmation_handle = tokio::spawn(send_confirmations( - control_stream, + let controller_handle = tokio::spawn(controller( + str_stream, + files.clone(), + writer_queue.clone(), confirmation_queue, - stats.confirmed_data, + stats.confirmed_data.clone(), + options.destination.file_path, + message_sender, )); let handles: Vec<_> = sockets @@ -114,28 +105,28 @@ pub(crate) async fn main( }; select! { - result = confirmation_handle => error!("confirmation sender failed {:?}", result), - result = writer_handle => { - info!("writer finished with result {:?}", result); - }, - _ = receiver_future => info!("receiver(s) exited"), + result = confirmation_handle => result?, + result = controller_handle => result?, + _ = receiver_future => { warn!("receiver(s) exited"); Ok(()) }, } - Ok(()) + } async fn receiver(queue: WriterQueue, socket: UdpSocket) { - let mut buf = [0; INDEX_SIZE + TRANSFER_BUFFER_SIZE]; // buffer for receiving data + let mut buf = [0; ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE]; // buffer for receiving data let mut retries = 0; // counter to keep track of retries while retries < MAX_RETRIES { match timeout(RECEIVE_TIMEOUT, socket.recv(&mut buf)).await { Ok(Ok(read)) if read > 0 => { retries = 0; // reset retries - // TODO let id = ... - let id = 0; - let index = u64::from_be_bytes(buf[..INDEX_SIZE].try_into().unwrap()); - let data = buf[INDEX_SIZE..].try_into().unwrap(); + + let id = u32::from_be_bytes(buf[..ID_SIZE].try_into().unwrap()); + let index = + u64::from_be_bytes(buf[ID_SIZE..INDEX_SIZE + ID_SIZE].try_into().unwrap()); + let data = buf[INDEX_SIZE + ID_SIZE..].try_into().unwrap(); + queue.push(Job { data, index }, id).await; } Ok(Ok(_)) => warn!("0 byte read?"), // this should never happen @@ -144,14 +135,90 @@ async fn receiver(queue: WriterQueue, socket: UdpSocket) { } } +async fn controller( + mut str_stream: TcpStream, + mut files: Files, + writer_queue: WriterQueue, + confirmation_queue: ConfirmationQueue, + confirmed_data: Arc, + file_path: PathBuf, + message_sender: Sender, +) -> Result<()> { + loop { + let message: Message = read_message(&mut str_stream).await?; + + match message.message { + Some(message::Message::Start(message)) => { + debug!("received start message: {:?}", message); + + let details = files.files.remove(&message.id).unwrap(); + + writer_queue.push_queue(message.id).await; // create a queue for the writer + + let file_path = if file_path.is_dir() { + file_path.join(&details.file_path) + } else { + file_path.clone() + }; + + let partial_path = file_path.with_extension("partial"); + + let start_index = if partial_path.exists() { + info!("partial file exists, resuming transfer"); + let metadata = metadata(&partial_path).await?; + // the file is written sequentially, so we can calculate the start index by rounding down to the nearest multiple of the transfer buffer size + metadata.len().div_floor(TRANSFER_BUFFER_SIZE as u64) + * TRANSFER_BUFFER_SIZE as u64 + } else { + 0 + }; + + confirmed_data.fetch_add(start_index as usize, Relaxed); + + // send the start index to the remote client + write_message(&mut str_stream, &StartIndex { index: start_index }).await?; + + let file = writer::File { + file_size: details.file_size, + partial_path, + path: file_path, + }; + + // TODO the result of writer is lost + tokio::spawn(writer( + file, + writer_queue.clone(), + confirmation_queue.clone(), + start_index, + message.id, + message_sender.clone(), + )); + + debug!("started file {:?}", details); + } + Some(message::Message::Done(_)) => { + debug!("received done message"); + message_sender.close(); + break; + } + _ => { + error!("received {:?}", message); + break; + } + } + } + + Ok(()) +} + async fn send_confirmations( - mut control_stream: TcpStream, + sender: Sender, queue: UnlimitedQueue<(u32, u64)>, confirmed_data: Arc, -) -> io::Result<()> { +) -> Result<()> { let data: Arc>> = Default::default(); - let sender_handle: JoinHandle> = tokio::spawn({ + let sender_handle: JoinHandle> = tokio::spawn({ let data = data.clone(); async move { @@ -170,16 +237,23 @@ async fn send_confirmations( let confirmations = mem::take(&mut *data); drop(data); // release the lock on data - let map: HashMap> = HashMap::new(); + let map: HashMap = HashMap::new(); // group the confirmations by id let map = confirmations.into_iter().fold(map, |mut map, (id, index)| { - map.entry(id).or_default().push(index); + map.entry(id).or_default().inner.push(index); map }); - // TODO send confirmations - // write_indexes(&mut control_stream, &indexes).await?; + let message = Message { + message: Some(message::Message::Confirmations(Confirmations { + indexes: map, + })), + }; + sender + .send(message) + .await + .expect("failed to send confirmations"); } } }); @@ -198,19 +272,3 @@ async fn send_confirmations( _ = future => Ok(()) } } - -// writes an array of u64 values to the control stream -async fn write_indexes( - control_stream: &mut T, - data: &[u64], -) -> io::Result<()> { - let length = data.len() as u64; - control_stream.write_u64(length).await?; - - // send the array of u64 values - for value in data { - control_stream.write_u64(*value).await?; - } - - Ok(()) -} diff --git a/src/receiver/writer.rs b/src/receiver/writer.rs index 1090e86..e21c443 100644 --- a/src/receiver/writer.rs +++ b/src/receiver/writer.rs @@ -1,27 +1,30 @@ +use async_channel::{SendError, Sender}; use deadqueue::limited::Queue; use std::cmp::Ordering; use std::collections::{BTreeMap, HashMap}; use std::io::SeekFrom; use std::path::PathBuf; +use std::sync::Arc; use log::{debug, info}; use tokio::fs::{rename, OpenOptions}; use tokio::io::{self, AsyncSeekExt, AsyncWrite, AsyncWriteExt, BufWriter}; use tokio::sync::RwLock; -use crate::receiver::{Job, WriterQueue}; -use crate::{UnlimitedQueue, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE}; +use crate::items::{message, End, Message}; +use crate::receiver::{ConfirmationQueue, Job, WriterQueue}; +use crate::{TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE}; #[derive(Default)] pub(crate) struct SplitQueue { - inner: RwLock>>, + inner: RwLock>>>, } impl SplitQueue { pub(crate) async fn push_queue(&self, id: u32) { let mut inner = self.inner.write().await; - inner.insert(id, Queue::new(1_000)); + inner.insert(id, Arc::new(Queue::new(1_000))); } pub(crate) async fn pop_queue(&self, id: &u32) { @@ -39,9 +42,12 @@ impl SplitQueue { } pub(crate) async fn pop(&self, id: &u32) -> Option { - let inner = self.inner.read().await; + let queue = { + let inner = self.inner.read().await; + inner.get(id).cloned() + }; - if let Some(queue) = inner.get(id) { + if let Some(queue) = queue { Some(queue.pop().await) } else { None @@ -49,29 +55,46 @@ impl SplitQueue { } } +/// stores file details for writer +pub(crate) struct File { + pub(crate) path: PathBuf, + pub(crate) partial_path: PathBuf, + pub(crate) file_size: u64, +} + +impl File { + /// rename the partial file to the final name + async fn rename(&self) -> io::Result<()> { + rename(&self.partial_path, &self.path).await + } +} + pub(crate) async fn writer( - path: PathBuf, - partial_path: PathBuf, + file_details: File, writer_queue: WriterQueue, - file_size: u64, - confirmation_queue: UnlimitedQueue<(u32, u64)>, + confirmation_queue: ConfirmationQueue, mut position: u64, id: u32, + message_sender: Sender, ) -> io::Result<()> { let file = OpenOptions::new() .write(true) .create(true) - .open(&partial_path) + .open(&file_details.partial_path) .await?; let mut writer = BufWriter::with_capacity(WRITE_BUFFER_SIZE, file); writer.seek(SeekFrom::Start(position)).await?; // seek to the initial position - debug!("starting writer at position {}", position); + debug!( + "writer for {} starting at {}", + file_details.path.display(), + position + ); let mut cache: BTreeMap = BTreeMap::new(); - while position != file_size { + while position != file_details.file_size { let job = writer_queue.pop(&id).await.unwrap(); match job.index.cmp(&position) { @@ -85,20 +108,37 @@ pub(crate) async fn writer( } // if the chunk is at the current position, write it Ordering::Equal => { - write_data(&mut writer, &job.data, &mut position, file_size).await?; + write_data( + &mut writer, + &job.data, + &mut position, + file_details.file_size, + ) + .await?; confirmation_queue.push((id, job.index)); } } // write all concurrent chunks from the cache while let Some(job) = cache.remove(&position) { - write_data(&mut writer, &job.data, &mut position, file_size).await?; + write_data( + &mut writer, + &job.data, + &mut position, + file_details.file_size, + ) + .await?; } } info!("writer wrote all expected bytes"); - writer.flush().await?; - rename(&partial_path, path).await?; + + writer.flush().await?; // flush the writer + file_details.rename().await?; // rename the file + writer_queue.pop_queue(&id).await; // remove the queue + send_end_message(&message_sender, id) + .await + .expect("failed to send end message"); // send end message Ok(()) } @@ -117,3 +157,11 @@ async fn write_data( *position += len; // advance the position writer.write_all(&buffer[..len as usize]).await // write the data } + +async fn send_end_message(sender: &Sender, id: u32) -> Result<(), SendError> { + let end_message = Message { + message: Some(message::Message::End(End { id })), + }; + + sender.send(end_message).await +} diff --git a/src/sender/mod.rs b/src/sender/mod.rs index 37d051c..4beaaf0 100644 --- a/src/sender/mod.rs +++ b/src/sender/mod.rs @@ -1,64 +1,93 @@ -use std::collections::{BTreeMap, HashSet}; +use async_channel::{Receiver, Sender}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::net::IpAddr; +use std::path::{Path, PathBuf}; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::Relaxed; use std::sync::Arc; use std::time::Duration; -use log::{debug, error, info}; -use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; +use log::{debug, error, info, warn}; +use tokio::io::{self, AsyncReadExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::select; use tokio::sync::{Mutex, RwLock, Semaphore}; use tokio::time::{interval, Instant}; +use crate::items::{ + message, Confirmations, Done, End, FileDetail, Files, Message, Start, StartIndex, +}; +use crate::sender::reader::reader; use crate::{ - socket_factory, Options, Result, TransferStats, UnlimitedQueue, INDEX_SIZE, MAX_RETRIES, - REQUEUE_INTERVAL, TRANSFER_BUFFER_SIZE, + read_message, socket_factory, write_message, Options, Result, TransferStats, UnlimitedQueue, + ID_SIZE, INDEX_SIZE, MAX_CONCURRENT_TRANSFERS, MAX_RETRIES, REQUEUE_INTERVAL, + TRANSFER_BUFFER_SIZE, }; mod reader; type JobQueue = UnlimitedQueue; -type JobCache = Arc>>; +type JobCache = Arc>>; struct Job { - data: [u8; INDEX_SIZE + TRANSFER_BUFFER_SIZE], + data: [u8; ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE], index: u64, + id: u32, cached_at: Option, } pub(crate) async fn main( options: Options, stats: TransferStats, - mut control_stream: TcpStream, + rts_stream: TcpStream, + mut str_stream: TcpStream, remote_addr: IpAddr, ) -> Result<()> { info!("sending {} -> {}", options.source, options.destination); - if options.source.is_dir() { - for entry in options.source.file_path.read_dir()? { - if let Ok(entry) = entry { - println!("{:?}", entry); - entry.path().is_dir(); // lol - } + // collect the files and directories to send + let mut files = Vec::new(); + let mut dirs = Vec::new(); + files_and_dirs(&options.source.file_path, &mut files, &mut dirs)?; + + let mut file_map: HashMap = HashMap::with_capacity(files.len()); + + for (index, file) in files.into_iter().enumerate() { + let file_size = tokio::fs::metadata(&file).await?.len(); + stats.total_data.fetch_add(file_size as usize, Relaxed); + + if let Ok(file_path) = file.strip_prefix(&options.source.file_path) { + file_map.insert( + index as u32, + FileDetail { + file_path: file_path.to_string_lossy().to_string(), + file_size, + }, + ); } - - return Ok(()); } - let file_size = options.source.file_size().await?; - stats.total_data.store(file_size as usize, Relaxed); - // send the file size to the remote client - control_stream.write_u64(file_size).await?; + let directories = dirs + .into_iter() + .map(|dir| { + if let Ok(file_path) = dir.strip_prefix(&options.source.file_path) { + file_path.to_string_lossy().to_string() + } else { + dir.to_string_lossy().to_string() + } + }) + .collect(); + + let files = Files { + directories, + files: file_map, + }; - // receive the start index from the remote client - let start_index = control_stream.read_u64().await?; - stats.confirmed_data.store(start_index as usize, Relaxed); - debug!("received start index {}", start_index); + debug!("sending files: {:?}", files); + write_message(&mut str_stream, &files).await?; let sockets = socket_factory( - options.start_port + 1, // the first port is used for control messages + options.start_port + 2, // the first two ports are used for control messages and confirmations options.end_port, remote_addr, options.threads, @@ -74,27 +103,49 @@ pub(crate) async fn main( // a semaphore to control the send rate let send = Arc::new(Semaphore::new(0)); - // a semaphore which limits the number of jobs that the reader will add to the queue - let read = Arc::new(Semaphore::new(1_000)); + // a map of semaphores to control the reads for each file + let mut read_semaphores: HashMap> = Default::default(); - tokio::spawn(reader::reader( - options.source.file_path, - queue.clone(), - read.clone(), - start_index, - )); + // create a semaphore for each file + for id in files.files.keys() { + let read = Arc::new(Semaphore::new(1_000)); + read_semaphores.insert(*id, read); + } + + let read_semaphores = Arc::new(read_semaphores); + + let (confirmation_sender, confirmation_receiver) = async_channel::unbounded(); + let (end_sender, end_receiver) = async_channel::unbounded(); + + tokio::spawn(split_receiver(rts_stream, confirmation_sender, end_sender)); let confirmation_handle = tokio::spawn({ let cache = cache.clone(); let queue = queue.clone(); - let read = read.clone(); - - receive_confirmations(control_stream, cache, queue, stats.confirmed_data, read) + let read_semaphores = read_semaphores.clone(); + + receive_confirmations( + confirmation_receiver, + cache, + queue, + stats.confirmed_data.clone(), + read_semaphores, + ) }); let semaphore = send.clone(); tokio::spawn(add_permits_at_rate(semaphore, options.rate)); + let controller_handle = tokio::spawn(controller( + str_stream, + files, + queue.clone(), + read_semaphores, + stats.confirmed_data, + options.source.file_path, + end_receiver, + )); + let handles: Vec<_> = sockets .into_iter() .map(|socket| tokio::spawn(sender(queue.clone(), socket, cache.clone(), send.clone()))) @@ -107,14 +158,10 @@ pub(crate) async fn main( }; select! { - _ = sender_future => error!("senders exited"), - result = confirmation_handle => { - // the confirmation receiver never exits unless an error occurs - error!("confirmation receiver exited with result {:?}", result); - } + result = confirmation_handle => result?, + result = controller_handle => result?, + _ = sender_future => { warn!("senders exited"); Ok(()) }, } - - Ok(()) } async fn sender(queue: JobQueue, socket: UdpSocket, cache: JobCache, send: Arc) { @@ -132,7 +179,7 @@ async fn sender(queue: JobQueue, socket: UdpSocket, cache: JobCache, send: Arc>>, + confirmed_data: Arc, + file_path: PathBuf, + end_receiver: Receiver, +) -> Result<()> { + let mut id = 0; + let mut active = 0; + + loop { + while active < MAX_CONCURRENT_TRANSFERS { + match files.files.remove(&id) { + None => break, + Some(file_details) => { + let read = read_semaphores.get(&id).unwrap().clone(); + + let message = Message { + message: Some(message::Message::Start(Start { id })), + }; + write_message(&mut control_stream, &message).await?; + + let file_path = file_path.join(&file_details.file_path); + + let start_index: StartIndex = read_message(&mut control_stream).await?; + confirmed_data.store(start_index.index as usize, Relaxed); + + tokio::spawn(reader( + file_path, + job_queue.clone(), + read, + start_index.index, + id, + )); + + id += 1; + active += 1; + } + } + } + + debug!("started max files, waiting for end message"); + let end = end_receiver.recv().await.unwrap(); + debug!("received end message: {:?} | active {}", end, active); + active -= 1; + + if files.files.is_empty() && active == 0 { + break; + } + } + + debug!("all files completed, sending done message"); + + let message = Message { + message: Some(message::Message::Done(Done {})), + }; + write_message(&mut control_stream, &message).await?; + + Ok(()) +} + +async fn receive_confirmations( + confirmation_receiver: Receiver, cache: JobCache, queue: JobQueue, confirmed_data: Arc, - read: Arc, -) -> io::Result<()> { + read_semaphores: Arc>>, +) -> Result<()> { // this solves a problem where a confirmation is received after a job has already been requeued - let lost_confirmations: Arc>> = Default::default(); + let lost_confirmations: Arc>> = Default::default(); // this thread checks the cache for unconfirmed jobs that have been there for too long and requeues them tokio::spawn({ let cache = cache.clone(); let lost_confirmations = lost_confirmations.clone(); let confirmed_data = confirmed_data.clone(); - let read = read.clone(); + let read_semaphores = read_semaphores.clone(); let mut interval = interval(Duration::from_millis(100)); @@ -178,12 +288,13 @@ async fn receive_confirmations( let mut cache = cache.write().await; // requeue and remove entries - for index in keys_to_remove { - if let Some(mut unconfirmed) = cache.remove(&index) { - if lost_confirmations.contains(&index) { + for key in keys_to_remove { + if let Some(mut unconfirmed) = cache.remove(&key) { + if lost_confirmations.contains(&key) { // the job is not requeued because it was confirmed while outside the cache - lost_confirmations.remove(&index); + lost_confirmations.remove(&key); + let read = read_semaphores.get(&key.0).unwrap(); read.add_permits(1); confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); } else { @@ -196,26 +307,30 @@ async fn receive_confirmations( } }); - loop { - let confirmed_indexes = receive_indexes(&mut control_stream).await?; - + while let Ok(confirmations) = confirmation_receiver.recv().await { let mut lost_confirmations = lost_confirmations.lock().await; let mut cache = cache.write().await; - // process the array of u64 values - for index in confirmed_indexes { - if cache.remove(&index).is_none() { - // if the index is not in the cache, it was already requeued - lost_confirmations.insert(index); - } else { - read.add_permits(1); // add a permit to the reader - confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); + for (id, indexes) in confirmations.indexes { + let read = read_semaphores.get(&id).unwrap(); + + // process the array of indexes + for index in indexes.inner { + if cache.remove(&(id, index)).is_none() { + // if the index is not in the cache, it was already requeued + lost_confirmations.insert((id, index)); + } else { + read.add_permits(1); // add a permit to the reader + confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); + } } } } + + Ok(()) } -// adds leases to the semaphore at a given rate to control the send rate +/// adds leases to the semaphore at a given rate async fn add_permits_at_rate(semaphore: Arc, rate: u64) { let mut interval = interval(Duration::from_nanos(1_000_000_000 / rate)); @@ -230,14 +345,52 @@ async fn add_permits_at_rate(semaphore: Arc, rate: u64) { } } -async fn receive_indexes(control_stream: &mut TcpStream) -> io::Result> { - let length = control_stream.read_u64().await? as usize; // read the length of the array - let mut indexes = Vec::with_capacity(length); // create a vector with the capacity of the array +/// recursively collect all files and directories in a directory +fn files_and_dirs( + source: &Path, + files: &mut Vec, + dirs: &mut Vec, +) -> io::Result<()> { + if source.is_dir() { + for entry in source.read_dir()?.filter_map(std::result::Result::ok) { + let path = entry.path(); - for _ in 0..length { - let index = control_stream.read_u64().await?; // read the u64 value - indexes.push(index); + if path.is_dir() { + dirs.push(path.clone()); + files_and_dirs(&path, files, dirs)?; + } else { + files.push(path); + } + } + } else { + files.push(source.to_path_buf()); } - Ok(indexes) + Ok(()) +} + +/// split the message stream into `Confirmation` and `End` messages +async fn split_receiver( + mut reader: R, + confirmation_sender: Sender, + end_sender: Sender, +) -> Result<()> { + loop { + let message: Message = read_message(&mut reader).await?; + + match message.message { + Some(message::Message::Confirmations(confirmations)) => { + confirmation_sender + .send(confirmations) + .await + .expect("failed to send confirmations"); + } + Some(message::Message::End(end)) => { + end_sender.send(end).await.expect("failed to send end"); + } + _ => { + error!("received {:?}", message); + } + } + } } diff --git a/src/sender/reader.rs b/src/sender/reader.rs index fbe5e1b..0eb2d41 100644 --- a/src/sender/reader.rs +++ b/src/sender/reader.rs @@ -8,19 +8,21 @@ use tokio::io::{AsyncReadExt, AsyncSeekExt, BufReader, Result}; use tokio::sync::Semaphore; use crate::sender::{Job, JobQueue}; -use crate::{INDEX_SIZE, READ_BUFFER_SIZE, TRANSFER_BUFFER_SIZE}; +use crate::{ID_SIZE, INDEX_SIZE, READ_BUFFER_SIZE, TRANSFER_BUFFER_SIZE}; pub(crate) async fn reader( path: PathBuf, queue: JobQueue, read: Arc, mut index: u64, + id: u32, ) -> Result<()> { let file = File::open(path).await?; let mut reader = BufReader::with_capacity(READ_BUFFER_SIZE, file); reader.seek(SeekFrom::Start(index)).await?; - let mut buffer = [0; INDEX_SIZE + TRANSFER_BUFFER_SIZE]; + let mut buffer = [0; ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE]; + buffer[..ID_SIZE].copy_from_slice(&id.to_be_bytes()); debug!("starting reader at index {}", index); @@ -28,10 +30,10 @@ pub(crate) async fn reader( let permit = read.acquire().await.unwrap(); // write index to buffer - buffer[..INDEX_SIZE].copy_from_slice(&index.to_be_bytes()); + buffer[ID_SIZE..INDEX_SIZE + ID_SIZE].copy_from_slice(&index.to_be_bytes()); // read data into buffer after checksum and index - let read = reader.read(&mut buffer[INDEX_SIZE..]).await?; + let read = reader.read(&mut buffer[INDEX_SIZE + ID_SIZE..]).await?; // check if EOF if read == 0 { @@ -42,6 +44,7 @@ pub(crate) async fn reader( queue.push(Job { data: buffer, index, + id, cached_at: None, });