Skip to content

Commit

Permalink
code cleanup, dependency upgrades
Browse files Browse the repository at this point in the history
  • Loading branch information
chanderlud committed Jan 12, 2024
1 parent fa360fd commit 5a44ca4
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 74 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
[package]
name = "cccp"
version = "0.9.0"
version = "0.10.0"
edition = "2021"
build = "build.rs"
repository = "https://github.com/chanderlud/cccp"
authors = ["Chander Luderman <me@chanchan.dev>"]
authors = ["Chander Luderman Miller <me@chanchan.dev>"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
clap = { version = "4.4", features = ["derive"] }
tokio = { version = "1.35", default-features = false, features = ["macros", "fs", "io-util", "signal"] }
tokio = { version = "1.35", default-features = false, features = ["macros", "fs", "io-util", "signal", "io-std"] }
futures = "0.3"
log = { version = "0.4", features = ["std"] }
async-ssh2-tokio = { git = "https://github.com/chanderlud/async-ssh2-tokio" }
async-ssh2-tokio = "0.8.5"
russh = "0.40"
simple-logging = "2.0"
regex = "1.10"
Expand Down
11 changes: 6 additions & 5 deletions src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use prost::Message;
use rand::rngs::{OsRng, StdRng};
use rand::{RngCore, SeedableRng};
use tokio::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;

use crate::items::{Cipher, Crypto};
use crate::Result;
Expand All @@ -31,13 +32,13 @@ where
}
}

pub(crate) struct CipherStream<S: AsyncWrite + AsyncRead + Unpin> {
stream: S,
pub(crate) struct CipherStream {
stream: TcpStream,
cipher: Box<dyn StreamCipherWrapper>,
}

impl<S: AsyncWrite + AsyncRead + Unpin> CipherStream<S> {
pub(crate) fn new(stream: S, crypto: &Crypto) -> Result<Self> {
impl CipherStream {
pub(crate) fn new(stream: TcpStream, crypto: &Crypto) -> Result<Self> {
Ok(Self {
stream,
cipher: crypto.make_cipher()?,
Expand Down
93 changes: 45 additions & 48 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![feature(int_roundings)]

use std::io::BufRead;
use std::io::{stdin, BufRead};
use std::net::{IpAddr, SocketAddr};
use std::ops::Not;
use std::path::Path;
Expand All @@ -10,9 +10,6 @@ use std::sync::Arc;
use std::time::Duration;

use async_ssh2_tokio::{AuthMethod, Client, ServerCheckMethod};
use base64::engine::general_purpose::STANDARD_NO_PAD;
use base64::prelude::BASE64_STANDARD_NO_PAD;
use base64::Engine;
use blake3::{Hash, Hasher};
use clap::{CommandFactory, Parser};
use futures::stream::iter;
Expand All @@ -24,7 +21,7 @@ use rpassword::prompt_password;
use russh::ChannelMsg;
use simple_logging::{log_to_file, log_to_stderr};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::io::{stdout, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::signal::ctrl_c;
use tokio::sync::Notify;
Expand Down Expand Up @@ -84,14 +81,15 @@ impl TransferStats {

fn packet_loss(&self) -> f64 {
let sent = self.sent_packets.load(Relaxed);
let confirmed = self.confirmed_packets.load(Relaxed);

if sent == 0 || sent < confirmed {
if sent == 0 {
return 0_f64;
}

let lost = sent - confirmed;
lost as f64 / sent as f64
let confirmed = self.confirmed_packets.load(Relaxed);
let lost = sent.saturating_sub(confirmed) as f64;

lost / sent as f64
}

fn total(&self) -> usize {
Expand All @@ -114,13 +112,15 @@ async fn main() -> Result<()> {

let signal = Arc::new(Notify::new());

// ctrl-c handler for attempting to gracefully exit
tokio::spawn({
let cancel_signal = signal.clone();

async move {
ctrl_c().await.expect("failed to listen for ctrl-c");
debug!("ctrl-c received");
cancel_signal.notify_waiters();
sleep(Duration::from_secs(1)).await;
ctrl_c().await.expect("failed to listen for ctrl-c");
error!("ctrl-c received again. exiting");
std::process::exit(1);
Expand Down Expand Up @@ -191,7 +191,7 @@ async fn main() -> Result<()> {
)
.exit();
} else if !source_local && !destination_local {
debug!("switching ton controller mode");
debug!("switching to controller mode");
options.mode = Mode::Controller;
}
}
Expand Down Expand Up @@ -260,7 +260,7 @@ async fn main() -> Result<()> {
// sender -> receiver stream
let str = connect_stream(remote_ip, &mut options).await?;

run_main(sender, options, stats.clone(), rts, str, remote_ip, signal).await
run_main(sender, options, &stats, rts, str, remote_ip, signal).await
};

select! {
Expand All @@ -281,13 +281,15 @@ async fn main() -> Result<()> {
}
});

let (rts, str, remote_addr) = match options.stream_setup_mode {
let stats_handle = tokio::spawn(remote_stats_printer(stats.clone()));

match options.stream_setup_mode {
// remote clients usually are in listen mode
SetupMode::Listen => {
let (rts, addr) = listen_stream(&mut options).await?;
let (str, _) = listen_stream(&mut options).await?;

(rts, str, addr)
run_main(sender, options, &stats, rts, str, addr, signal).await?;
}
// remote clients only use connect mode for remote -> remote transfers where the source is always in connect mode
SetupMode::Connect => {
Expand All @@ -297,22 +299,9 @@ async fn main() -> Result<()> {
let rts = connect_stream(addr, &mut options).await?;
let str = connect_stream(addr, &mut options).await?;

(rts, str, addr)
run_main(sender, options, &stats, rts, str, addr, signal).await?;
}
};

let stats_handle = tokio::spawn(remote_stats_printer(stats.clone()));

run_main(
sender,
options,
stats.clone(),
rts,
str,
remote_addr,
signal,
)
.await?;
}

stats.complete.store(true, Relaxed);
stats_handle.await?;
Expand Down Expand Up @@ -374,9 +363,9 @@ async fn main() -> Result<()> {
async fn run_main(
sender: bool,
options: Options,
stats: TransferStats,
rts: CipherStream<TcpStream>,
str: CipherStream<TcpStream>,
stats: &TransferStats,
rts: CipherStream,
str: CipherStream,
remote_addr: IpAddr,
signal: Arc<Notify>,
) -> Result<()> {
Expand Down Expand Up @@ -505,19 +494,25 @@ async fn local_stats_printer(stats: TransferStats, mut interval: Interval) {
bar.finish_with_message("complete");
}

/// prints a base64 encoded stats message to stdout
/// writes the Stats message into stdout
async fn remote_stats_printer(stats: TransferStats) {
let mut interval = interval(Duration::from_secs(1));
let mut stdout = stdout();

while !stats.is_complete() {
interval.tick().await;

let stats = Stats::from(&stats); // create a Stats message
// allocate a buffer for the message
let mut buf = Vec::with_capacity(stats.encoded_len());
// convert the stats struct into a protobuf message
let stats = Stats::from(&stats);

// allocate a buffer for the message + newline
let mut buf = Vec::with_capacity(stats.encoded_len() + 1);
stats.encode(&mut buf).unwrap(); // infallible
let encoded = BASE64_STANDARD_NO_PAD.encode(&buf); // base64 encode the message
println!("{}", encoded); // print the encoded message
buf.push(b'\n'); // add a newline

if let Err(error) = stdout.write(&buf).await {
error!("failed to write stats to stdout: {}", error);
}
}
}

Expand All @@ -530,7 +525,7 @@ async fn command_runner(
total_data: Option<Arc<AtomicUsize>>,
cancel_signal: Arc<Notify>,
) -> Result<u32> {
debug!("executing command: {}", command);
debug!("command runner starting for {:?}", command);

let mut channel = client.get_channel().await?;
channel.exec(true, command).await?;
Expand All @@ -547,9 +542,9 @@ async fn command_runner(
if let Some(message) = message {
match message {
ChannelMsg::Data { ref data } => {
let message = String::from_utf8_lossy(data).replace('\n', "");
let buffer = STANDARD_NO_PAD.decode(message)?;
let stats = Stats::decode(&buffer[..])?;
// the remote client sends stats messages to stdout
// the last byte is a newline
let stats = Stats::decode(&data[..data.len() - 1])?;

if let Some(sent_packets) = &sent_packets {
sent_packets.store(stats.sent_packets as usize, Relaxed);
Expand All @@ -565,7 +560,12 @@ async fn command_runner(
}
ChannelMsg::ExtendedData { ref data, ext: 1 } => {
let error = String::from_utf8_lossy(data).replace('\n', "");
error!("remote client stderr: {}", error);

if error.contains("not recognized as an internal or external command") {
break Err(Error::command_not_found());
} else {
error!("remote client stderr: {}", error);
}
}
ChannelMsg::ExitStatus { exit_status: 127 } => break Err(Error::command_not_found()),
ChannelMsg::ExitStatus { exit_status } => break Ok(exit_status),
Expand All @@ -580,10 +580,7 @@ async fn command_runner(
}

/// connects to a listening remote client
async fn connect_stream(
remote_addr: IpAddr,
options: &mut Options,
) -> Result<CipherStream<TcpStream>> {
async fn connect_stream(remote_addr: IpAddr, options: &mut Options) -> Result<CipherStream> {
let tcp_stream = loop {
if let Ok(stream) = TcpStream::connect((remote_addr, options.start_port)).await {
break stream;
Expand All @@ -600,7 +597,7 @@ async fn connect_stream(
}

/// listens for a remote client to connect
async fn listen_stream(options: &mut Options) -> Result<(CipherStream<TcpStream>, IpAddr)> {
async fn listen_stream(options: &mut Options) -> Result<(CipherStream, IpAddr)> {
let listener = TcpListener::bind(("0.0.0.0", options.start_port)).await?;
let (tcp_stream, remote_addr) = listener.accept().await?;

Expand Down Expand Up @@ -633,7 +630,7 @@ async fn hash_file<P: AsRef<Path>>(path: P) -> io::Result<Hash> {

/// watches for stdin to receive a STOP message
fn wait_for_stop(signal: Arc<Notify>) {
let stdin = std::io::stdin();
let stdin = stdin();
let reader = std::io::BufReader::new(stdin);
let lines = reader.lines();

Expand Down
18 changes: 9 additions & 9 deletions src/receiver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use futures::{StreamExt, TryStreamExt};
use kanal::{AsyncReceiver, AsyncSender};
use log::{debug, error, info, warn};
use tokio::fs::{create_dir, metadata};
use tokio::net::{TcpStream, UdpSocket};
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::{Mutex, Notify};
use tokio::task::JoinHandle;
Expand All @@ -39,9 +39,9 @@ struct Job {

pub(crate) async fn main(
options: Options,
stats: TransferStats,
rts_stream: CipherStream<TcpStream>,
mut str_stream: CipherStream<TcpStream>,
stats: &TransferStats,
rts_stream: CipherStream,
mut str_stream: CipherStream,
remote_addr: IpAddr,
cancel_signal: Arc<Notify>,
) -> Result<()> {
Expand Down Expand Up @@ -283,7 +283,7 @@ async fn receiver(
}

async fn controller(
mut control_stream: CipherStream<TcpStream>,
mut control_stream: CipherStream,
mut files: HashMap<u32, FileDetails>,
writer_queue: WriterQueue,
confirmation_sender: AsyncSender<(u32, u64)>,
Expand Down Expand Up @@ -408,7 +408,7 @@ async fn send_confirmations(

/// send messages from a channel to a cipher stream
async fn send_messages<M: prost::Message>(
mut stream: CipherStream<TcpStream>,
mut stream: CipherStream,
receiver: AsyncReceiver<M>,
) -> Result<()> {
while let Ok(message) = receiver.recv().await {
Expand All @@ -423,7 +423,7 @@ async fn send_messages<M: prost::Message>(
fn free_space(path: &Path) -> Result<u64> {
use nix::sys::statvfs::statvfs;

let path = format_path(path)?;
let path = parent_path(path)?;
debug!("getting free space for {:?}", path);
let stat = statvfs(&path)?;

Expand All @@ -436,7 +436,7 @@ fn free_space(path: &Path) -> Result<u64> {
use widestring::U16CString;
use windows_sys::Win32::Storage::FileSystem;

let path = format_path(path)?;
let path = parent_path(path)?;
let path = U16CString::from_os_str(path)?;

let mut free_bytes = 0_u64;
Expand All @@ -459,7 +459,7 @@ fn free_space(path: &Path) -> Result<u64> {
}

/// returns the absolute path of the first existing parent directory
fn format_path(path: &Path) -> Result<PathBuf> {
fn parent_path(path: &Path) -> Result<PathBuf> {
let mut path = path.to_path_buf();

if !path.is_absolute() {
Expand Down
Loading

0 comments on commit 5a44ca4

Please sign in to comment.