Skip to content

Commit

Permalink
feat: consolidate AES-GCM and XSalsa20-Poly1305 into AEAD
Browse files Browse the repository at this point in the history
  • Loading branch information
enmand committed Dec 8, 2024
1 parent a7275c3 commit 53ee4cf
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 105 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
use aes::cipher::generic_array::GenericArray;
use aes::cipher::{generic_array::GenericArray, ArrayLength};
use aes_gcm::{
aead::{AeadMutInPlace, Buffer},
Aes256Gcm, KeyInit,
};
use bytes::{Bytes, BytesMut};
use crypto_secretbox::XSalsa20Poly1305 as XSalsa20Poly1305Cipher;
use thiserror::Error;

use super::{Encryption, IVEncryption};

pub(super) struct AESBuffer<'a>(pub(crate) &'a mut BytesMut);
pub struct AEAD<C: AeadMutInPlace> {
cipher: C,
iv: Option<GenericArray<u8, C::NonceSize>>,
}

pub(super) struct AEADBufferBytesMut<'a>(&'a mut BytesMut);

impl<'a> Buffer for AESBuffer<'a> {
impl<'a> Buffer for AEADBufferBytesMut<'a> {
fn extend_from_slice(&mut self, other: &[u8]) -> aes_gcm::aead::Result<()> {
self.0.extend_from_slice(other);

Expand All @@ -22,23 +28,18 @@ impl<'a> Buffer for AESBuffer<'a> {
}
}

impl<'a> AsRef<[u8]> for AESBuffer<'a> {
impl<'a> AsRef<[u8]> for AEADBufferBytesMut<'a> {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}

impl<'a> AsMut<[u8]> for AESBuffer<'a> {
impl<'a> AsMut<[u8]> for AEADBufferBytesMut<'a> {
fn as_mut(&mut self) -> &mut [u8] {
self.0.as_mut()
}
}

pub struct AES256GCM {
cipher: Aes256Gcm,
iv: Option<GenericArray<u8, typenum::consts::U12>>,
}

#[derive(Debug, Error)]
pub enum Error {
#[error("AES-256-GCM encryption/decryption error: {0}")]
Expand All @@ -47,14 +48,19 @@ pub enum Error {
NoIVError,
}

impl Encryption for AES256GCM {
fn new(key: &[u8; 32]) -> Result<Self, super::Error> {
let cipher = Aes256Gcm::new(key.into());
impl<C: AeadMutInPlace + KeyInit> Encryption for AEAD<C>
where
C::NonceSize: ArrayLength<u8>,
{
type KeySize = C::KeySize;

fn new(key: GenericArray<u8, Self::KeySize>) -> Result<Self, super::Error> {
let cipher = C::new(&key);
Ok(Self { cipher, iv: None })
}

fn encrypt(&mut self, data: &mut BytesMut) -> Result<Bytes, super::Error> {
let mut data = AESBuffer(data);
let mut data = AEADBufferBytesMut(data);
if let Some(iv) = &self.iv {
self.cipher
.encrypt_in_place(iv, b"", &mut data)
Expand All @@ -66,7 +72,7 @@ impl Encryption for AES256GCM {
}

fn decrypt(&mut self, data: &mut BytesMut) -> Result<Bytes, super::Error> {
let mut data = AESBuffer(data);
let mut data = AEADBufferBytesMut(data);
if let Some(iv) = &self.iv {
self.cipher
.decrypt_in_place(iv, b"", &mut data)
Expand All @@ -78,8 +84,8 @@ impl Encryption for AES256GCM {
}
}

impl IVEncryption for AES256GCM {
type NonceSize = typenum::consts::U12;
impl<C: AeadMutInPlace + KeyInit + Clone> IVEncryption for AEAD<C> {
type NonceSize = C::NonceSize;

fn with_iv(&mut self, iv: GenericArray<u8, Self::NonceSize>) -> Result<Self, super::Error> {
Ok(Self {
Expand All @@ -89,8 +95,13 @@ impl IVEncryption for AES256GCM {
}
}

pub type AES256GCM = AEAD<Aes256Gcm>;
pub type XSalsa20Poly1305 = AEAD<XSalsa20Poly1305Cipher>;

#[cfg(test)]
mod test {
use aes_gcm::Aes256Gcm;

use super::*;

const KEY: [u8; 32] = [
Expand All @@ -101,10 +112,14 @@ mod test {
const IV: [u8; 12] = [
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb,
];
const SALSA_IV: [u8; 24] = [
0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe,
0xff, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
];

#[test]
fn test_aes256gcm() {
let mut enc = AES256GCM::new(&KEY)
let mut enc = AEAD::<Aes256Gcm>::new(KEY.into())
.unwrap()
.with_iv(IV.into())
.expect("IV error");
Expand All @@ -120,7 +135,36 @@ mod test {

#[test]
fn test_aes256gcm_no_iv() {
let mut enc = AES256GCM::new(&KEY).unwrap();
let mut enc = AEAD::<Aes256Gcm>::new(KEY.into()).unwrap();

let data = BytesMut::from("Hello, world!");

let enc_data = enc.encrypt(&mut data.clone());
let dec_data = enc.decrypt(&mut data.clone());

assert!(enc_data.is_err());
assert!(dec_data.is_err());
}

#[test]
fn test_xsalsa20poly1305() {
let mut enc = XSalsa20Poly1305::new(KEY.into())
.unwrap()
.with_iv(SALSA_IV.into())
.expect("IV error");

let data = BytesMut::from("Hello, world!");

let enc_data = enc.encrypt(&mut data.clone()).unwrap();
let dec_data = enc.decrypt(&mut enc_data.clone().into()).unwrap();

assert_ne!(data, enc_data);
assert_eq!(data, dec_data);
}

#[test]
fn test_xsalsa20poly1305_no_iv() {
let mut enc = XSalsa20Poly1305::new(KEY.into()).unwrap();

let data = BytesMut::from("Hello, world!");

Expand Down
14 changes: 8 additions & 6 deletions crates/dwn-rs-core/src/encryption/symmetric/aes_ctr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ pub enum Error {
}

impl Encryption for AES256CTR {
fn new(key: &[u8; 32]) -> Result<Self, super::Error> {
type KeySize = typenum::consts::U32;

fn new(key: GenericArray<u8, Self::KeySize>) -> Result<Self, super::Error> {
Ok(Self {
key: *key,
key: key.into(),
enc: None,
dec: None,
})
Expand Down Expand Up @@ -88,7 +90,7 @@ mod test {

#[test]
fn test_aes256ctr() {
let mut enc = AES256CTR::new(&KEY)
let mut enc = AES256CTR::new(KEY.into())
.expect("Failed to create AES256CTR")
.with_iv(IV.into())
.expect("Failed to set IV");
Expand All @@ -105,19 +107,19 @@ mod test {

#[test]
fn test_aes256ctr_no_iv() {
let mut enc = AES256CTR::new(&KEY).expect("Failed to create AES256CTR");
let mut enc = AES256CTR::new(KEY.into()).expect("Failed to create AES256CTR");

let data = Bytes::from_static(b"hello world! this is my plaintext.");
let enc_data = enc.encrypt(&mut data.clone().into());
assert_eq!(
enc_data.unwrap_err().to_string(),
"AES-256-CBC encryption error: AES-256-CTR IV error"
"AES-256-CTR encryption error: AES-256-CTR IV error"
);

let dec_data = enc.decrypt(&mut data.clone().into());
assert_eq!(
dec_data.unwrap_err().to_string(),
"AES-256-CBC encryption error: AES-256-CTR IV error"
"AES-256-CTR encryption error: AES-256-CTR IV error"
);
}
}
38 changes: 20 additions & 18 deletions crates/dwn-rs-core/src/encryption/symmetric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,42 +6,43 @@ use futures_util::{ready, Stream};
use pin_project_lite::pin_project;
use thiserror::Error;

pub mod aead;
pub mod aes_ctr;
pub mod aes_gcm;
pub mod xsalsa20_poly1305;

#[derive(Debug, Error)]
pub enum Error {
#[error("AES-256-CTR encryption error: {0}")]
AES256CTR(#[from] aes_ctr::Error),
#[error("AES-256-GCM encryption error: {0}")]
AES256GCM(#[from] aes_gcm::Error),
#[error("XSalsa20Poly1305 encryption error: {0}")]
XSalsa20Poly1305(#[from] xsalsa20_poly1305::Error),
#[error("AEAD encryption error: {0}")]
AEAD(#[from] aead::Error),
}

impl<T: ?Sized> StreamEncryptionExt for T where T: Stream {}

pub trait StreamEncryptionExt: Stream {
fn encrypt<E>(self, key: &[u8; 32]) -> Result<Encrypt<Self, E>, Error>
fn encrypt<E>(self, key: GenericArray<u8, E::KeySize>) -> Result<Encrypt<Self, E>, Error>
where
E: Encryption,
E::KeySize: ArrayLength<u8>,
Self: Sized,
{
Encrypt::new(self, key)
}

fn decrypt<E>(self, key: &[u8; 32]) -> Result<Decrypt<Self, E>, Error>
fn decrypt<E>(self, key: GenericArray<u8, E::KeySize>) -> Result<Decrypt<Self, E>, Error>
where
E: Encryption,
E::KeySize: ArrayLength<u8>,
Self: Sized,
{
Decrypt::new(self, key)
}
}

pub trait Encryption {
fn new(key: &[u8; 32]) -> Result<Self, Error>
type KeySize: ArrayLength<u8>;

fn new(key: GenericArray<u8, Self::KeySize>) -> Result<Self, Error>
where
Self: Sized;
fn encrypt(&mut self, data: &mut BytesMut) -> Result<Bytes, Error>;
Expand All @@ -68,8 +69,9 @@ pin_project! {
impl<D, E> Encrypt<D, E>
where
E: Encryption,
E::KeySize: ArrayLength<u8>,
{
pub fn new(stream: D, key: &[u8; 32]) -> Result<Self, Error> {
pub fn new(stream: D, key: GenericArray<u8, E::KeySize>) -> Result<Self, Error> {
Ok(Self {
stream,
encryption: E::new(key)?,
Expand Down Expand Up @@ -122,7 +124,7 @@ impl<D, E> Decrypt<D, E>
where
E: Encryption,
{
pub fn new(stream: D, key: &[u8; 32]) -> Result<Self, Error> {
pub fn new(stream: D, key: GenericArray<u8, E::KeySize>) -> Result<Self, Error> {
Ok(Self {
stream,
encryption: E::new(key)?,
Expand Down Expand Up @@ -190,13 +192,13 @@ mod test {
Ok::<Bytes, Error>(msg_part_1.clone()),
Ok(msg_part_2.clone()),
])
.encrypt::<aes_ctr::AES256CTR>(&KEY)
.encrypt::<aes_ctr::AES256CTR>(KEY.into())
.expect("unable to generate encryption")
.with_iv(IV.into())
.expect("unable to set IV");

// Static encryption
let mut enc = aes_ctr::AES256CTR::new(&KEY)
let mut enc = aes_ctr::AES256CTR::new(KEY.into())
.expect("Failed to create AES256CTR")
.with_iv(IV.into())
.expect("Failed to set IV");
Expand Down Expand Up @@ -228,17 +230,17 @@ mod test {
Ok::<Bytes, Error>(msg_part_1.clone()),
Ok(msg_part_2.clone()),
])
.encrypt::<aes_ctr::AES256CTR>(&KEY)
.encrypt::<aes_ctr::AES256CTR>(KEY.into())
.expect("unable to generate encryption")
.with_iv(IV.into())
.expect("Unable to set IV")
.decrypt::<aes_ctr::AES256CTR>(&KEY)
.decrypt::<aes_ctr::AES256CTR>(KEY.into())
.expect("unable to generate decryption")
.with_iv(IV.into())
.expect("Unable to set IV");

// Static encryption
let mut enc = aes_ctr::AES256CTR::new(&KEY)
let mut enc = aes_ctr::AES256CTR::new(KEY.into())
.expect("Failed to create AES256CTR")
.with_iv(IV.into())
.expect("Unable to set IV");
Expand Down Expand Up @@ -273,7 +275,7 @@ mod test {
aes::cipher::InvalidLength,
))),
])
.encrypt::<aes_ctr::AES256CTR>(&KEY)
.encrypt::<aes_ctr::AES256CTR>(KEY.into())
.expect("unable to generate encryption")
.with_iv(IV.into())
.expect("Unable to set IV");
Expand All @@ -285,7 +287,7 @@ mod test {
Err(e) => {
assert_eq!(
e.to_string(),
"AES-256-CBC encryption error: Invalid key length: Invalid Length"
"AES-256-CTR encryption error: Invalid key length: Invalid Length"
);
}
}
Expand Down
Loading

0 comments on commit 53ee4cf

Please sign in to comment.