From 71f6b5012f7bf3dade3d4375f04336ca5d5f4d52 Mon Sep 17 00:00:00 2001 From: chanderlud Date: Thu, 7 Dec 2023 17:03:07 -0800 Subject: [PATCH] stable folder support, optimizations started --- Cargo.toml | 5 +-- src/main.rs | 49 +++++++++++++++++---------- src/receiver/mod.rs | 49 +++++++++++++++++++-------- src/receiver/writer.rs | 21 ++++++------ src/sender/mod.rs | 77 ++++++++++++++++++------------------------ 5 files changed, 111 insertions(+), 90 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 5bfe9e6..37c241e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ build = "build.rs" [dependencies] clap = { version = "4.4", features = ["derive"] } -tokio = { version = "1.34", features = ["macros", "fs", "io-util"] } +tokio = { version = "1.34", default-features = false, features = ["macros", "fs", "io-util"] } futures = "0.3" log = { version = "0.4", features = ["std"] } async-ssh2-tokio = "0.8" @@ -20,7 +20,8 @@ rpassword = "7.3" indicatif = "0.17" prost = "0.12" prost-build = "0.12" -async-channel = "2.1" +bytesize = "1.3.0" +kanal = "0.1.0-pre8" [build-dependencies] prost-build = "0.12.3" diff --git a/src/main.rs b/src/main.rs index dee3747..efb3555 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,5 @@ #![feature(int_roundings)] -use async_channel::Receiver; use std::error; use std::fmt::{Display, Formatter}; use std::net::{IpAddr, SocketAddr}; @@ -13,10 +12,12 @@ use std::sync::Arc; use std::time::Duration; use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod}; +use bytesize::ByteSize; use clap::Parser; use futures::stream::iter; use futures::{StreamExt, TryStreamExt}; use indicatif::{ProgressBar, ProgressStyle}; +use kanal::{AsyncReceiver, SendError}; use log::{debug, error, info, warn, LevelFilter}; use prost::Message; use regex::Regex; @@ -35,13 +36,16 @@ mod sender; type UnlimitedQueue = Arc>; type Result = std::result::Result; -const READ_BUFFER_SIZE: usize = 10_000_000; -const WRITE_BUFFER_SIZE: usize = 5_000_000; +// read buffer must be a multiple of the transfer buffer to prevent a nasty little bug +const READ_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; +const WRITE_BUFFER_SIZE: usize = TRANSFER_BUFFER_SIZE * 100; const TRANSFER_BUFFER_SIZE: usize = 1024; const INDEX_SIZE: usize = std::mem::size_of::(); const ID_SIZE: usize = std::mem::size_of::(); const MAX_RETRIES: usize = 10; const RECEIVE_TIMEOUT: Duration = Duration::from_secs(5); +// UDP header + ID + INDEX + DATA +const PACKET_SIZE: usize = 8 + ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE; // how long to wait for a job to be confirmed before requeuing it const REQUEUE_INTERVAL: Duration = Duration::from_millis(1_000); @@ -62,6 +66,7 @@ enum ErrorKind { Parse(std::net::AddrParseError), Decode(prost::DecodeError), Join(tokio::task::JoinError), + Send(kanal::SendError), } impl From for Error { @@ -96,6 +101,14 @@ impl From for Error { } } +impl From for Error { + fn from(error: SendError) -> Self { + Self { + kind: ErrorKind::Send(error), + } + } +} + impl Termination for Error { fn report(self) -> ExitCode { ExitCode::from(match self.kind { @@ -106,6 +119,7 @@ impl Termination for Error { ErrorKind::Parse(_) => 3, ErrorKind::Decode(_) => 4, ErrorKind::Join(_) => 5, + ErrorKind::Send(_) => 6, }) } } @@ -115,7 +129,7 @@ struct Options { #[clap( short, long = "mode", - help = "local or remote", + hide = true, // the user does not need to set this default_value = "local" )] mode: Mode, @@ -162,10 +176,10 @@ struct Options { #[clap( short, long = "rate", - help = "the rate to send data at (bytes per second)", - default_value = "1000000" + help = "the rate to send data at [b, kb, mb, gb, tb]", + default_value = "1mb" )] - rate: u64, + rate: ByteSize, #[clap(help = "where to get the data from")] source: FileLocation, @@ -179,7 +193,7 @@ impl Options { let mode = if sender { "rr" } else { "rs" }; format!( - "cccp --mode {} --start-port {} --end-port {} --threads {} --log-level {} --rate {} \"{}\" \"{}\"", + "cccp --mode {} --start-port {} --end-port {} --threads {} --log-level {} --rate \"{}\" \"{}\" \"{}\"", mode, self.start_port, self.end_port, @@ -190,6 +204,10 @@ impl Options { self.destination ) } + + fn pps(&self) -> u64 { + self.rate.0 / PACKET_SIZE as u64 + } } #[derive(Debug, Clone, PartialEq)] @@ -361,14 +379,8 @@ async fn main() -> Result<()> { } if options.destination.host.is_none() && options.source.host.is_none() { - panic!("at least one host must be specified") + panic!("either the source or destination must be remote"); } - - // UDP header + INDEX + DATA - let packet_size = (8 + INDEX_SIZE + TRANSFER_BUFFER_SIZE) as u64; - let pps_rate = options.rate / packet_size; - debug!("{} byte/s -> {} packet/s", options.rate, pps_rate); - options.rate = pps_rate; } let sender = options.source.is_local(); @@ -451,9 +463,10 @@ async fn main() -> Result<()> { let main_future = async { if sender { - sender::main(options, stats, rts_stream, str_stream, remote_addr).await + sender::main(options, stats.clone(), rts_stream, str_stream, remote_addr).await } else { - receiver::main(options, stats, rts_stream, str_stream, remote_addr).await + receiver::main(options, stats.clone(), rts_stream, str_stream, remote_addr) + .await } }; @@ -640,7 +653,7 @@ async fn read_message(reader: &mu /// send messages from a channel to a writer async fn message_sender( mut writer: W, - receiver: Receiver, + receiver: AsyncReceiver, ) -> Result<()> { while let Ok(message) = receiver.recv().await { write_message(&mut writer, &message).await?; diff --git a/src/receiver/mod.rs b/src/receiver/mod.rs index 1fbe3a7..559a6d0 100644 --- a/src/receiver/mod.rs +++ b/src/receiver/mod.rs @@ -1,4 +1,4 @@ -use async_channel::Sender; +use kanal::AsyncSender; use std::collections::HashMap; use std::mem; use std::net::IpAddr; @@ -79,7 +79,7 @@ pub(crate) async fn main( let confirmation_queue: ConfirmationQueue = Default::default(); // `message_sender` can now be used to send messages to the sender - let (message_sender, message_receiver) = async_channel::unbounded(); + let (message_sender, message_receiver) = kanal::unbounded_async(); tokio::spawn(crate::message_sender(rts_stream, message_receiver)); let confirmation_handle = tokio::spawn(send_confirmations( @@ -145,7 +145,7 @@ async fn controller( confirmation_queue: ConfirmationQueue, confirmed_data: Arc, file_path: PathBuf, - message_sender: Sender, + message_sender: AsyncSender, ) -> Result<()> { loop { let message: Message = read_message(&mut str_stream).await?; @@ -164,7 +164,14 @@ async fn controller( file_path.clone() }; - let partial_path = file_path.with_extension("partial"); + // append partial extension to the existing extension, if there is one + let partial_extension = if let Some(extension) = file_path.extension() { + extension.to_str().unwrap().to_owned() + ".partial" + } else { + ".partial".to_string() + }; + + let partial_path = file_path.with_extension(partial_extension); let start_index = if partial_path.exists() { info!("partial file exists, resuming transfer"); @@ -187,15 +194,29 @@ async fn controller( path: file_path, }; - // TODO the result of writer is lost - tokio::spawn(writer( - file, - writer_queue.clone(), - confirmation_queue.clone(), - start_index, - message.id, - message_sender.clone(), - )); + tokio::spawn({ + let writer_queue = writer_queue.clone(); + let confirmation_queue = confirmation_queue.clone(); + let message_sender = message_sender.clone(); + + async move { + let path = file.path.clone(); + + let result = writer( + file, + writer_queue, + confirmation_queue, + start_index, + message.id, + message_sender, + ) + .await; + + if let Err(error) = result { + error!("writer for {} failed: {:?}", path.display(), error); + } + } + }); debug!("started file {:?}", details); } @@ -215,7 +236,7 @@ async fn controller( } async fn send_confirmations( - sender: Sender, + sender: AsyncSender, queue: UnlimitedQueue<(u32, u64)>, confirmed_data: Arc, ) -> Result<()> { diff --git a/src/receiver/writer.rs b/src/receiver/writer.rs index e21c443..04f74ea 100644 --- a/src/receiver/writer.rs +++ b/src/receiver/writer.rs @@ -1,7 +1,7 @@ -use async_channel::{SendError, Sender}; use deadqueue::limited::Queue; +use kanal::AsyncSender; use std::cmp::Ordering; -use std::collections::{BTreeMap, HashMap}; +use std::collections::HashMap; use std::io::SeekFrom; use std::path::PathBuf; use std::sync::Arc; @@ -13,7 +13,7 @@ use tokio::sync::RwLock; use crate::items::{message, End, Message}; use crate::receiver::{ConfirmationQueue, Job, WriterQueue}; -use crate::{TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE}; +use crate::{Result, TRANSFER_BUFFER_SIZE, WRITE_BUFFER_SIZE}; #[derive(Default)] pub(crate) struct SplitQueue { @@ -75,8 +75,8 @@ pub(crate) async fn writer( confirmation_queue: ConfirmationQueue, mut position: u64, id: u32, - message_sender: Sender, -) -> io::Result<()> { + message_sender: AsyncSender, +) -> Result<()> { let file = OpenOptions::new() .write(true) .create(true) @@ -92,7 +92,7 @@ pub(crate) async fn writer( position ); - let mut cache: BTreeMap = BTreeMap::new(); + let mut cache: HashMap = HashMap::new(); while position != file_details.file_size { let job = writer_queue.pop(&id).await.unwrap(); @@ -136,9 +136,7 @@ pub(crate) async fn writer( writer.flush().await?; // flush the writer file_details.rename().await?; // rename the file writer_queue.pop_queue(&id).await; // remove the queue - send_end_message(&message_sender, id) - .await - .expect("failed to send end message"); // send end message + send_end_message(&message_sender, id).await?; Ok(()) } @@ -158,10 +156,11 @@ async fn write_data( writer.write_all(&buffer[..len as usize]).await // write the data } -async fn send_end_message(sender: &Sender, id: u32) -> Result<(), SendError> { +async fn send_end_message(sender: &AsyncSender, id: u32) -> Result<()> { let end_message = Message { message: Some(message::Message::End(End { id })), }; - sender.send(end_message).await + sender.send(end_message).await?; + Ok(()) } diff --git a/src/sender/mod.rs b/src/sender/mod.rs index 7bb1dc3..e59a301 100644 --- a/src/sender/mod.rs +++ b/src/sender/mod.rs @@ -1,5 +1,5 @@ -use async_channel::{Receiver, Sender}; -use std::collections::{BTreeMap, HashMap, HashSet}; +use kanal::{AsyncReceiver, AsyncSender}; +use std::collections::{HashMap, HashSet}; use std::net::IpAddr; use std::path::{Path, PathBuf}; use std::sync::atomic::AtomicUsize; @@ -27,7 +27,7 @@ use crate::{ mod reader; type JobQueue = UnlimitedQueue; -type JobCache = Arc>>; +type JobCache = Arc>>; struct Job { data: [u8; ID_SIZE + INDEX_SIZE + TRANSFER_BUFFER_SIZE], @@ -65,10 +65,12 @@ pub(crate) async fn main( .to_path_buf(); } + let file_path = file.to_string_lossy().replace('\\', "/"); + file_map.insert( index as u32, FileDetail { - file_path: file.to_string_lossy().to_string(), + file_path, file_size, }, ); @@ -83,6 +85,7 @@ pub(crate) async fn main( dir.to_string_lossy().to_string() } }) + .map(|dir| dir.replace('\\', "/")) .collect(); let manifest = Manifest { @@ -110,44 +113,31 @@ pub(crate) async fn main( // a semaphore to control the send rate let send = Arc::new(Semaphore::new(0)); - // a map of semaphores to control the reads for each file - let mut read_semaphores: HashMap> = Default::default(); - - // create a semaphore for each file - for id in manifest.files.keys() { - let read = Arc::new(Semaphore::new(1_000)); - read_semaphores.insert(*id, read); - } - let read_semaphores = Arc::new(read_semaphores); + let read = Arc::new(Semaphore::new(1_000)); - let (confirmation_sender, confirmation_receiver) = async_channel::unbounded(); - let (end_sender, end_receiver) = async_channel::unbounded(); + let (confirmation_sender, confirmation_receiver) = kanal::unbounded_async(); + let (end_sender, end_receiver) = kanal::unbounded_async(); tokio::spawn(split_receiver(rts_stream, confirmation_sender, end_sender)); - let confirmation_handle = tokio::spawn({ - let cache = cache.clone(); - let queue = queue.clone(); - let read_semaphores = read_semaphores.clone(); - - receive_confirmations( - confirmation_receiver, - cache, - queue, - stats.confirmed_data.clone(), - read_semaphores, - ) - }); + let confirmation_handle = tokio::spawn(receive_confirmations( + confirmation_receiver, + cache.clone(), + queue.clone(), + stats.confirmed_data.clone(), + read.clone(), + )); + let rate = options.pps(); let semaphore = send.clone(); - tokio::spawn(add_permits_at_rate(semaphore, options.rate)); + tokio::spawn(add_permits_at_rate(semaphore, rate)); let controller_handle = tokio::spawn(controller( str_stream, manifest, queue.clone(), - read_semaphores, + read, stats.confirmed_data, options.source.file_path, end_receiver, @@ -198,10 +188,10 @@ async fn controller( mut control_stream: TcpStream, mut files: Manifest, job_queue: JobQueue, - read_semaphores: Arc>>, + read: Arc, confirmed_data: Arc, file_path: PathBuf, - end_receiver: Receiver, + end_receiver: AsyncReceiver, ) -> Result<()> { let mut id = 0; let mut active = 0; @@ -211,8 +201,6 @@ async fn controller( match files.files.remove(&id) { None => break, Some(file_details) => { - let read = read_semaphores.get(&id).unwrap().clone(); - let message = Message { message: Some(message::Message::Start(Start { id })), }; @@ -221,12 +209,12 @@ async fn controller( let file_path = file_path.join(&file_details.file_path); let start_index: StartIndex = read_message(&mut control_stream).await?; - confirmed_data.store(start_index.index as usize, Relaxed); + confirmed_data.fetch_add(start_index.index as usize, Relaxed); tokio::spawn(reader( file_path, job_queue.clone(), - read, + read.clone(), start_index.index, id, )); @@ -237,7 +225,7 @@ async fn controller( } } - debug!("started max files, waiting for end message"); + debug!("waiting for a file to end"); let end = end_receiver.recv().await.unwrap(); debug!("received end message: {:?} | active {}", end, active); active -= 1; @@ -258,11 +246,11 @@ async fn controller( } async fn receive_confirmations( - confirmation_receiver: Receiver, + confirmation_receiver: AsyncReceiver, cache: JobCache, queue: JobQueue, confirmed_data: Arc, - read_semaphores: Arc>>, + read: Arc, ) -> Result<()> { // this solves a problem where a confirmation is received after a job has already been requeued let lost_confirmations: Arc>> = Default::default(); @@ -272,7 +260,7 @@ async fn receive_confirmations( let cache = cache.clone(); let lost_confirmations = lost_confirmations.clone(); let confirmed_data = confirmed_data.clone(); - let read_semaphores = read_semaphores.clone(); + let read = read.clone(); let mut interval = interval(Duration::from_millis(100)); @@ -301,7 +289,6 @@ async fn receive_confirmations( // the job is not requeued because it was confirmed while outside the cache lost_confirmations.remove(&key); - let read = read_semaphores.get(&key.0).unwrap(); read.add_permits(1); confirmed_data.fetch_add(TRANSFER_BUFFER_SIZE, Relaxed); } else { @@ -319,8 +306,6 @@ async fn receive_confirmations( let mut cache = cache.write().await; for (id, indexes) in confirmations.indexes { - let read = read_semaphores.get(&id).unwrap(); - // process the array of indexes for index in indexes.inner { if cache.remove(&(id, index)).is_none() { @@ -339,6 +324,8 @@ async fn receive_confirmations( /// adds leases to the semaphore at a given rate async fn add_permits_at_rate(semaphore: Arc, rate: u64) { + debug!("adding permits at rate {}", rate); + let mut interval = interval(Duration::from_nanos(1_000_000_000 / rate)); loop { @@ -379,8 +366,8 @@ fn files_and_dirs( /// split the message stream into `Confirmation` and `End` messages async fn split_receiver( mut reader: R, - confirmation_sender: Sender, - end_sender: Sender, + confirmation_sender: AsyncSender, + end_sender: AsyncSender, ) -> Result<()> { loop { let message: Message = read_message(&mut reader).await?;