Skip to content

Commit

Permalink
refactoring + error handling + aes128 support
Browse files Browse the repository at this point in the history
  • Loading branch information
chanderlud committed Dec 26, 2023
1 parent f04a051 commit c301df5
Show file tree
Hide file tree
Showing 11 changed files with 317 additions and 347 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ chacha20 = "0.9"
base64 = "0.21"
ctr = "0.9"
aes = "0.8"
whoami = "1.4"
cipher = "0.4"

[target.'cfg(unix)'.dependencies]
nix = { version = "0.27", features = ["fs"] }
Expand Down
101 changes: 101 additions & 0 deletions src/cipher.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use aes::{Aes128, Aes256};
use chacha20::{ChaCha20, ChaCha8};
use cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use ctr::Ctr128BE;
use prost::Message;
use tokio::io;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

use crate::items::{Cipher, Crypto};

pub(crate) trait StreamCipherWrapper: Send + Sync {
fn seek(&mut self, index: u64);
fn apply_keystream(&mut self, data: &mut [u8]);
}

impl<T> StreamCipherWrapper for T
where
T: StreamCipherSeek + StreamCipher + Send + Sync,
{
fn seek(&mut self, index: u64) {
StreamCipherSeek::seek(self, index);
}

fn apply_keystream(&mut self, buf: &mut [u8]) {
StreamCipher::apply_keystream(self, buf);
}
}

pub(crate) struct CipherStream<S: AsyncWrite + AsyncRead + Unpin> {
stream: S,
cipher: Box<dyn StreamCipherWrapper>,
}

impl<S: AsyncWrite + AsyncRead + Unpin> CipherStream<S> {
pub(crate) fn new(stream: S, crypto: &Crypto) -> crate::Result<Self> {
Ok(Self {
stream,
cipher: make_cipher(crypto)?,
})
}

/// write a `Message` to the stream
pub(crate) async fn write_message<M: Message>(&mut self, message: &M) -> crate::Result<()> {
let len = message.encoded_len(); // get the length of the message
self.write_u32(len as u32).await?; // write the length of the message

let mut buffer = Vec::with_capacity(len); // create a buffer to write the message into
message.encode(&mut buffer).unwrap(); // encode the message into the buffer (infallible)

self.write_all(&mut buffer).await?; // write the message to the writer

Ok(())
}

/// read a `Message` from the stream
pub(crate) async fn read_message<M: Message + Default>(&mut self) -> crate::Result<M> {
let len = self.read_u32().await? as usize; // read the length of the message

let mut buffer = vec![0; len]; // create a buffer to read the message into
self.read_exact(&mut buffer).await?; // read the message into the buffer

let message = M::decode(&buffer[..])?; // decode the message

Ok(message)
}

async fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
AsyncReadExt::read_exact(&mut self.stream, buf).await?;
self.cipher.apply_keystream(buf);
Ok(())
}

async fn read_u32(&mut self) -> io::Result<u32> {
let mut buf = [0; 4];
self.read_exact(&mut buf).await?;
Ok(u32::from_be_bytes(buf))
}

async fn write_all(&mut self, buf: &mut [u8]) -> io::Result<()> {
self.cipher.apply_keystream(buf);
AsyncWriteExt::write_all(&mut self.stream, buf).await
}

async fn write_u32(&mut self, value: u32) -> io::Result<()> {
let mut buf = value.to_be_bytes();
self.write_all(&mut buf).await
}
}

pub(crate) fn make_cipher(crypto: &Crypto) -> crate::Result<Box<dyn StreamCipherWrapper>> {
let cipher: Cipher = crypto.cipher.try_into()?;
let key = &crypto.key[..cipher.key_length()];
let iv = &crypto.iv[..cipher.iv_length()];

Ok(match cipher {
Cipher::Aes128 => Box::new(Ctr128BE::<Aes128>::new(key.into(), iv.into())),
Cipher::Aes256 => Box::new(Ctr128BE::<Aes256>::new(key.into(), iv.into())),
Cipher::Chacha8 => Box::new(ChaCha8::new(key.into(), iv.into())),
Cipher::Chacha20 => Box::new(ChaCha20::new(key.into(), iv.into())),
})
}
30 changes: 11 additions & 19 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::array::TryFromSliceError;
use std::fmt::Formatter;
use std::process::{ExitCode, Termination};

use kanal::{ReceiveError, SendError};
use prost::Message;
use tokio::io;
use tokio::sync::AcquireError;

Expand Down Expand Up @@ -34,6 +34,7 @@ pub(crate) enum ErrorKind {
Failure(u32),
EmptyPath,
InvalidExtension,
UnexpectedMessage(Box<dyn Message>),
}

impl From<io::Error> for Error {
Expand Down Expand Up @@ -134,24 +135,6 @@ impl From<async_ssh2_tokio::Error> for Error {
}
}

impl Termination for Error {
fn report(self) -> ExitCode {
ExitCode::from(match self.kind {
ErrorKind::Io(error) => match error.kind() {
io::ErrorKind::NotFound => 1,
_ => 2,
},
ErrorKind::AddrParse(_) => 3,
ErrorKind::Decode(_) => 4,
ErrorKind::Join(_) => 5,
ErrorKind::Send(_) => 6,
ErrorKind::Receive(_) => 7,
ErrorKind::Acquire(_) => 8,
_ => 9,
})
}
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self.kind {
Expand All @@ -176,6 +159,9 @@ impl std::fmt::Display for Error {
ErrorKind::Failure(ref reason) => write!(f, "Failure: {}", reason),
ErrorKind::EmptyPath => write!(f, "Empty path"),
ErrorKind::InvalidExtension => write!(f, "Invalid extension"),
ErrorKind::UnexpectedMessage(ref message) => {
write!(f, "Unexpected message {:?}", message)
}
}
}
}
Expand Down Expand Up @@ -211,6 +197,12 @@ impl Error {
}
}

pub(crate) fn unexpected_message(message: Box<dyn Message>) -> Self {
Self {
kind: ErrorKind::UnexpectedMessage(message),
}
}

#[cfg(windows)]
pub(crate) fn status_error() -> Self {
Self {
Expand Down
9 changes: 6 additions & 3 deletions src/items.proto
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ message Crypto {
}

enum Cipher {
AES = 0;
CHACHA8 = 1;
CHACHA8 = 0;
AES128 = 1;
CHACHA20 = 2;
AES256 = 3;
}

// the receiver already had these files
Expand Down Expand Up @@ -78,4 +79,6 @@ message Failure {
}

// signals the receiver that the sender won't start new transfers
message Done {}
message Done {
uint32 reason = 1;
}
24 changes: 17 additions & 7 deletions src/items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,34 +38,38 @@ impl Message {
}
}

pub(crate) fn done() -> Self {
pub(crate) fn done(reason: u32) -> Self {
Self {
message: Some(message::Message::Done(Done {})),
message: Some(message::Message::Done(Done { reason })),
}
}
}

impl Cipher {
/// the length of the key in bytes
pub(crate) fn key_length(&self) -> usize {
32
match self {
Self::Chacha20 | Self::Chacha8 | Self::Aes256 => 32,
Self::Aes128 => 16,
}
}

/// the length of the iv in bytes
pub(crate) fn iv_length(&self) -> usize {
match self {
Self::Chacha20 | Self::Chacha8 => 12,
Self::Aes => 16,
Self::Aes256 | Self::Aes128 => 16,
}
}
}

impl Display for Cipher {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let cipher = match self {
Self::Aes => "aes",
Self::Chacha8 => "chacha8",
Self::Chacha20 => "chacha20",
Self::Aes128 => "AES128",
Self::Aes256 => "AES256",
Self::Chacha8 => "CHACHA8",
Self::Chacha20 => "CHACHA20",
};

write!(f, "{}", cipher)
Expand All @@ -77,3 +81,9 @@ impl StartIndex {
Self { index }
}
}

impl Manifest {
pub(crate) fn is_empty(&self) -> bool {
self.files.is_empty() && self.directories.is_empty()
}
}
Loading

0 comments on commit c301df5

Please sign in to comment.