diff --git a/Cargo.toml b/Cargo.toml index a5d0cd1..ec776a5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,18 +15,20 @@ exclude = [ "fuzz/*" ] +[features] +default = ["avx2", "sse41"] +bench = [] +avx2 = [] +sse41 = [] [dev-dependencies] -criterion = "0.2" +criterion = "0.3" rustc-hex = "1.0" -hex = "0.3.2" +hex = "0.4" proptest = "0.8" +rand = "0.7.3" [[bench]] name = "hex" harness = false - - -[[bench]] -name = "check" -harness = false +required-features = ["bench", "avx2", "sse41"] diff --git a/benches/check.rs b/benches/check.rs deleted file mode 100644 index 60052b9..0000000 --- a/benches/check.rs +++ /dev/null @@ -1,38 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use faster_hex::{hex_check_fallback, hex_check_sse}; - -fn bench(c: &mut Criterion) { - let s1 = "Bf9E2d38aceDeeCbbAfccc4B4B7AE"; - let s2 = "ed136fFDdCcC1DbaFE8CB6Df1AdDBAea44aCcC17b0DbC2741F9CeEeaFbE7A51D"; - let s3 = " \u{0} ๐€€G\u{0}๐€€ GG\u{0}๐€€G\u{0}Gเ €\u{0} ๐€€ \u{0}:\u{0}\u{0}gเ €G G::GG::g๐€€G๐€€\u{0}\u{0}ยก๐€€เ €\u{0}:GGG Gg๐€€ :\u{0}:gG ยก"; - let s4 = "ed136fFDdCcC1DbaFE8CB6Df1AdDBAea44aCcC17b0DbC2741F9CeEeaFbE7A51D\u{0} ๐€€G\u{0}๐€€ GG\u{0}๐€€G\u{0}Gเ €\u{0} ๐€€ \u{0}:\u{0}\u{0}gเ €G G::GG::g๐€€G๐€€\u{0}\u{0}ยก๐€€เ €\u{0}:GGG Gg๐€€ :\u{0}:gG ยก"; - - c.bench_function("bench_check_fallback", move |b| { - b.iter(|| { - let ret1 = hex_check_fallback(s1.as_bytes()); - black_box(ret1); - let ret2 = hex_check_fallback(s2.as_bytes()); - black_box(ret2); - let ret3 = hex_check_fallback(s3.as_bytes()); - black_box(ret3); - let ret4 = hex_check_fallback(s4.as_bytes()); - black_box(ret4); - }) - }); - - c.bench_function("bench_check_sse", move |b| { - b.iter(|| { - let ret1 = unsafe { hex_check_sse(s1.as_bytes()) }; - black_box(ret1); - let ret2 = unsafe { hex_check_sse(s2.as_bytes()) }; - black_box(ret2); - let ret3 = unsafe { hex_check_sse(s3.as_bytes()) }; - black_box(ret3); - let ret4 = unsafe { hex_check_sse(s4.as_bytes()) }; - black_box(ret4); - }) - }); -} - -criterion_group!(benches, bench); -criterion_main!(benches); diff --git a/benches/hex.rs b/benches/hex.rs index 018db82..458a581 100644 --- a/benches/hex.rs +++ b/benches/hex.rs @@ -1,91 +1,145 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; use faster_hex::{ - hex_decode, hex_decode_fallback, hex_decode_unchecked, hex_encode_fallback, hex_string, + decode_fallback, decode_to_slice, decode_to_slice_unchecked, decode_unchecked_fallback, encode, + encode_fallback, }; use rustc_hex::{FromHex, ToHex}; +use std::time::Duration; -fn bench(c: &mut Criterion) { - let s = "Day before yesterday I saw a rabbit, and yesterday a deer, and today, you."; - - c.bench_function("bench_rustc_hex_encode", move |b| { - b.iter(|| { - let ret = s.as_bytes().to_hex(); - black_box(ret); - }) - }); - - c.bench_function("bench_hex_encode", move |b| { - b.iter(|| { - let ret = hex::encode(s); - black_box(ret); - }) - }); - - c.bench_function("bench_faster_hex_encode", move |b| { - b.iter(|| { - let ret = hex_string(s.as_bytes()).unwrap(); - black_box(ret); - }) - }); - - c.bench_function("bench_faster_hex_encode_fallback", move |b| { - b.iter(|| { - let bytes = s.as_bytes(); - let mut buffer = vec![0; bytes.len() * 2]; - let ret = hex_encode_fallback(bytes, &mut buffer); - black_box(ret); - }) - }); +const BYTE_SIZES: [usize; 5] = [2, 16, 32, 128, 4096]; - c.bench_function("bench_rustc_hex_decode", move |b| { - let hex = s.as_bytes().to_hex(); - b.iter(|| { - let ret: Vec = hex.from_hex().unwrap(); - black_box(ret); - }) - }); +fn rand_slice(size: usize) -> Vec { + use rand::Rng; + let mut input: Vec = vec![0; size]; + rand::thread_rng().fill(input.as_mut_slice()); + input +} - c.bench_function("bench_hex_decode", move |b| { - let hex = s.as_bytes().to_hex(); - b.iter(|| { - let ret: Vec = hex::decode(&hex).unwrap(); - black_box(ret); - }) - }); +fn rand_hex_encoded(size: usize) -> String { + use rand::seq::SliceRandom; + String::from_utf8( + std::iter::repeat(()) + .map(|_| *b"0123456789abcdef".choose(&mut rand::thread_rng()).unwrap()) + .take(size) + .collect(), + ) + .unwrap() +} - c.bench_function("bench_faster_hex_decode", move |b| { - let hex = hex_string(s.as_bytes()).unwrap(); - let len = s.as_bytes().len(); - b.iter(|| { - let mut dst = Vec::with_capacity(len); - dst.resize(len, 0); - let ret = hex_decode(hex.as_bytes(), &mut dst); - black_box(ret); - }) - }); +fn bench(c: &mut Criterion) { + let mut encode_group = c.benchmark_group("encode"); + for size in &BYTE_SIZES[..] { + encode_group.throughput(Throughput::Bytes(*size as u64)); + encode_group.bench_with_input(BenchmarkId::new("rustc", size), size, |b, &size| { + let input = rand_slice(size); + b.iter(|| { + let ret = input.to_hex(); + black_box(ret); + }) + }); + encode_group.bench_with_input(BenchmarkId::new("hex", size), size, |b, &size| { + let input = rand_slice(size); + b.iter(|| { + let ret = hex::encode(&input); + black_box(ret); + }) + }); + encode_group.bench_with_input(BenchmarkId::new("faster_hex", size), size, |b, &size| { + let input = rand_slice(size); + b.iter(|| { + let ret = encode(&input); + black_box(ret); + }) + }); + encode_group.bench_with_input( + BenchmarkId::new("faster_hex_fallback", size), + size, + |b, &size| { + let input = rand_slice(size); + let mut buffer = vec![0; input.len() * 2]; + b.iter(|| { + let ret = encode_fallback(&input, buffer.as_mut_slice()); + black_box(ret); + }) + }, + ); + } + encode_group.finish(); - c.bench_function("bench_faster_hex_decode_unchecked", move |b| { - let hex = hex_string(s.as_bytes()).unwrap(); - let len = s.as_bytes().len(); - b.iter(|| { - let mut dst = Vec::with_capacity(len); - dst.resize(len, 0); - let ret = hex_decode_unchecked(hex.as_bytes(), &mut dst); - black_box(ret); - }) - }); + let mut decode_group = c.benchmark_group("decode"); + for size in &BYTE_SIZES[..] { + decode_group.throughput(Throughput::Bytes(*size as u64)); + decode_group.bench_with_input(BenchmarkId::new("rustc", size), size, |b, &size| { + let hex_input = rand_hex_encoded(size); + b.iter(|| { + let ret: Vec = hex_input.from_hex().unwrap(); + black_box(ret); + }) + }); + decode_group.bench_with_input(BenchmarkId::new("hex", size), size, |b, &size| { + let hex_input = rand_hex_encoded(size); + b.iter(|| { + let ret: Vec = hex::decode(&hex_input).unwrap(); + black_box(ret); + }) + }); + decode_group.bench_with_input(BenchmarkId::new("faster_hex", size), size, |b, &size| { + let hex_input = rand_hex_encoded(size); + let mut dst = vec![0; size / 2]; + b.iter(|| { + let ret = decode_to_slice(hex_input.as_bytes(), &mut dst).unwrap(); + black_box(ret); + }) + }); + decode_group.bench_with_input( + BenchmarkId::new("faster_hex_unchecked", size), + size, + |b, &size| { + let hex_input = rand_hex_encoded(size); + let mut dst = vec![0; size / 2]; + b.iter(|| { + let ret = decode_to_slice_unchecked(hex_input.as_bytes(), &mut dst); + black_box(ret); + }) + }, + ); + decode_group.bench_with_input( + BenchmarkId::new("faster_hex_fallback", size), + size, + |b, &size| { + let hex_input = rand_hex_encoded(size); + let mut dst = vec![0; size / 2]; + b.iter(|| { + let ret = decode_fallback(hex_input.as_bytes(), &mut dst).unwrap(); + black_box(ret); + }) + }, + ); + decode_group.bench_with_input( + BenchmarkId::new("faster_hex_unchecked_fallback", size), + size, + |b, &size| { + let hex_input = rand_hex_encoded(size); + let mut dst = vec![0; size / 2]; + b.iter(|| { + let ret = decode_unchecked_fallback(hex_input.as_bytes(), &mut dst); + black_box(ret); + }) + }, + ); + } + decode_group.finish(); +} - c.bench_function("bench_faster_hex_decode_fallback", move |b| { - let hex = hex_string(s.as_bytes()).unwrap(); - let len = s.as_bytes().len(); - b.iter(|| { - let mut dst = Vec::with_capacity(len); - dst.resize(len, 0); - let ret = hex_decode_fallback(hex.as_bytes(), &mut dst); - black_box(ret); - }) - }); +fn quicker() -> Criterion { + Criterion::default() + .warm_up_time(Duration::from_millis(500)) + .measurement_time(Duration::from_secs(1)) } -criterion_group!(benches, bench); +criterion_group! { + name = benches; + config = quicker(); + targets = bench +} criterion_main!(benches); diff --git a/src/decode.rs b/src/decode.rs index 80aec6a..95b1752 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,289 +1,454 @@ -#[cfg(target_arch = "x86")] -use std::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use std::arch::x86_64::*; - use crate::error::Error; -const NIL: u8 = u8::max_value(); -const T_MASK: i32 = 65535; - -// ASCII -> hex -pub(crate) static UNHEX: [u8; 256] = [ - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, 10, 11, 12, 13, 14, 15, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 10, 11, 12, 13, - 14, 15, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, -]; - -// ASCII -> hex << 4 -pub(crate) static UNHEX4: [u8; 256] = [ - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, 0, 16, 32, 48, 64, 80, 96, 112, 128, 144, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, 160, 176, 192, 208, 224, 240, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, 160, 176, 192, 208, 224, 240, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, - NIL, NIL, NIL, NIL, NIL, NIL, NIL, NIL, -]; - -const _0213: i32 = 0b11011000; - -// lower nibble -#[inline] -fn unhex_b(x: usize) -> u8 { - UNHEX[x] -} - -// upper nibble, logically equivalent to unhex_b(x) << 4 -#[inline] -fn unhex_a(x: usize) -> u8 { - UNHEX4[x] -} - -#[inline] -#[target_feature(enable = "avx2")] -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -unsafe fn unhex_avx2(value: __m256i) -> __m256i { - let sr6 = _mm256_srai_epi16(value, 6); - let and15 = _mm256_and_si256(value, _mm256_set1_epi16(0xf)); - let mul = _mm256_maddubs_epi16(sr6, _mm256_set1_epi16(9)); - _mm256_add_epi16(mul, and15) -} - -// (a << 4) | b; -#[inline] -#[target_feature(enable = "avx2")] -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -unsafe fn nib2byte_avx2(a1: __m256i, b1: __m256i, a2: __m256i, b2: __m256i) -> __m256i { - let a4_1 = _mm256_slli_epi16(a1, 4); - let a4_2 = _mm256_slli_epi16(a2, 4); - let a4orb_1 = _mm256_or_si256(a4_1, b1); - let a4orb_2 = _mm256_or_si256(a4_2, b2); - let pck1 = _mm256_packus_epi16(a4orb_1, a4orb_2); - _mm256_permute4x64_epi64(pck1, _0213) +pub fn decode(src: &I) -> Result, Error> +where + I: AsRef<[u8]> + ?Sized, +{ + let src = src.as_ref(); + let mut output = vec![0u8; src.len() / 2]; + decode_to_slice(src, &mut output)?; + Ok(output) } -pub fn hex_check(src: &[u8]) -> bool { +pub fn decode_to_slice(src: &I, dst: &mut [u8]) -> Result<(), Error> +where + I: AsRef<[u8]> + ?Sized, +{ + let src = src.as_ref(); + validate_buffer_length(src, dst)?; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - if is_x86_feature_detected!("sse4.1") { - return unsafe { hex_check_sse(src) }; + #[cfg(feature = "avx2")] + { + if is_x86_feature_detected!("avx2") && src.len() >= 64 { + return unsafe { arch::avx2::decode(src, dst) }; + } + } + #[cfg(feature = "sse41")] + { + if is_x86_feature_detected!("sse4.1") && src.len() >= 32 { + return unsafe { arch::sse41::decode(src, dst) }; + } } } - - hex_check_fallback(src) + arch::fallback::decode(src, dst) } -pub fn hex_check_fallback(src: &[u8]) -> bool { - for byte in src { - match byte { - b'A'...b'F' | b'a'...b'f' | b'0'...b'9' => continue, - _ => { - return false; +pub fn decode_to_slice_unchecked(src: &I, dst: &mut [u8]) +where + I: AsRef<[u8]> + ?Sized, +{ + let src = src.as_ref(); + validate_buffer_length(src, dst).unwrap(); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + #[cfg(feature = "avx2")] + { + if is_x86_feature_detected!("avx2") && src.len() >= 64 { + return unsafe { + arch::avx2::decode_unchecked(src, dst); + }; + } + } + #[cfg(feature = "sse41")] + { + if is_x86_feature_detected!("sse4.1") && src.len() >= 32 { + return unsafe { + arch::sse41::decode_unchecked(src, dst); + }; } } } - true + arch::fallback::decode_unchecked(src, dst) } -#[target_feature(enable = "sse4.1")] -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -pub unsafe fn hex_check_sse(mut src: &[u8]) -> bool { - let ascii_zero = _mm_set1_epi8((b'0' - 1) as i8); - let ascii_nine = _mm_set1_epi8((b'9' + 1) as i8); - let ascii_ua = _mm_set1_epi8((b'A' - 1) as i8); - let ascii_uf = _mm_set1_epi8((b'F' + 1) as i8); - let ascii_la = _mm_set1_epi8((b'a' - 1) as i8); - let ascii_lf = _mm_set1_epi8((b'f' + 1) as i8); +#[inline] +fn validate_buffer_length(src: &[u8], dst: &[u8]) -> Result<(), Error> { + let decoded_len = src.len().checked_div(2).unwrap(); + if dst.len() < decoded_len || ((src.len() & 1) != 0) { + return Err(Error::InvalidLength(src.len())); + } + Ok(()) +} - while src.len() >= 16 { - let unchecked = _mm_loadu_si128(src.as_ptr() as *const _); +struct Checked; +struct Unchecked; - let gt0 = _mm_cmpgt_epi8(unchecked, ascii_zero); - let lt9 = _mm_cmplt_epi8(unchecked, ascii_nine); - let outside1 = _mm_and_si128(gt0, lt9); +pub mod arch { + #[cfg(all(feature = "avx2", any(target_arch = "x86", target_arch = "x86_64")))] + pub mod avx2 { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use std::arch::x86_64::*; - let gtua = _mm_cmpgt_epi8(unchecked, ascii_ua); - let ltuf = _mm_cmplt_epi8(unchecked, ascii_uf); - let outside2 = _mm_and_si128(gtua, ltuf); + use crate::decode::{Checked, Error, Unchecked}; - let gtla = _mm_cmpgt_epi8(unchecked, ascii_la); - let ltlf = _mm_cmplt_epi8(unchecked, ascii_lf); - let outside3 = _mm_and_si128(gtla, ltlf); + #[target_feature(enable = "avx2")] + pub unsafe fn decode(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { + _decode::(src, dst).map_err(|_| Error::InvalidChar) + } - let tmp = _mm_or_si128(outside1, outside2); - let ret = _mm_movemask_epi8(_mm_or_si128(tmp, outside3)); + #[target_feature(enable = "avx2")] + pub unsafe fn decode_unchecked(src: &[u8], dst: &mut [u8]) { + let _ = _decode::(src, dst); + } - if ret != T_MASK { - return false; + #[inline] + #[target_feature(enable = "avx2")] + pub unsafe fn _decode(mut src: &[u8], mut dst: &mut [u8]) -> Result<(), ()> { + while src.len() >= 64 { + let av1 = _mm256_loadu_si256(src.as_ptr() as *const _); + let av2 = _mm256_loadu_si256(src[32..].as_ptr() as *const _); + let av1 = decode_chunk::(av1)?; + let av1 = + _mm256_permutevar8x32_epi32(av1, _mm256_setr_epi32(0, 1, 4, 5, -1, -1, -1, -1)); + let av2 = decode_chunk::(av2)?; + let av2 = + _mm256_permutevar8x32_epi32(av2, _mm256_setr_epi32(-1, -1, -1, -1, 0, 1, 4, 5)); + let decoded = _mm256_or_si256(av1, av2); + _mm256_storeu_si256(dst.as_mut_ptr() as *mut _, decoded); + dst = &mut dst[32..]; + src = &src[64..]; + } + crate::decode::arch::fallback::_decode::(&src, &mut dst) } - src = &src[16..]; - } - hex_check_fallback(src) -} + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn decode_chunk(input: __m256i) -> Result<__m256i, ()> { + #[allow(overflowing_literals)] + let hi_nibbles = + _mm256_and_si256(_mm256_srli_epi32(input, 4), _mm256_set1_epi8(0b00001111)); + let low_nibbles = _mm256_and_si256(input, _mm256_set1_epi8(0b00001111)); -pub fn hex_decode(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { - if src.is_empty() { - return Err(Error::InvalidLength(0)); - } - let len = dst.len().checked_mul(2).unwrap(); - if src.len() < len || ((src.len() & 1) != 0) { - return Err(Error::InvalidLength(len)); - } - if !hex_check(src) { - return Err(Error::InvalidChar); - } - hex_decode_unchecked(src, dst); - Ok(()) -} + if !::is_valid(hi_nibbles, low_nibbles) { + return Err(()); + } -pub fn hex_decode_unchecked(src: &[u8], dst: &mut [u8]) { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - { - if is_x86_feature_detected!("avx2") { - return unsafe { hex_decode_avx2(src, dst) }; + let shift_lut = _mm256_setr_epi8( + 0, 0, 0, -48, -55, 0, -87, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -48, -55, 0, -87, 0, + 0, 0, 0, 0, 0, 0, 0, 0, + ); + + let sh = _mm256_shuffle_epi8(shift_lut, hi_nibbles); + let input = _mm256_add_epi8(input, sh); + #[allow(overflowing_literals)] + let input = _mm256_maddubs_epi16( + input, + _mm256_set1_epi32(0b00000001_00010000_00000001_00010000), + ); + let input = _mm256_shuffle_epi8( + input, + _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, 0, 2, 4, 6, 8, 10, + 12, 14, -1, -1, -1, -1, -1, -1, -1, -1, + ), + ); + Ok(input) } - } - hex_decode_fallback(src, dst); -} + pub trait IsValid: crate::decode::arch::fallback::IsValid { + unsafe fn is_valid(hi_nibbles: __m256i, low_nibbles: __m256i) -> bool; + } -#[target_feature(enable = "avx2")] -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -unsafe fn hex_decode_avx2(mut src: &[u8], mut dst: &mut [u8]) { - // 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1, - // 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1 - let mask_a = _mm256_setr_epi8( - 0, -1, 2, -1, 4, -1, 6, -1, 8, -1, 10, -1, 12, -1, 14, -1, 0, -1, 2, -1, 4, -1, 6, -1, 8, - -1, 10, -1, 12, -1, 14, -1, - ); - - // 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1, - // 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1 - let mask_b = _mm256_setr_epi8( - 1, -1, 3, -1, 5, -1, 7, -1, 9, -1, 11, -1, 13, -1, 15, -1, 1, -1, 3, -1, 5, -1, 7, -1, 9, - -1, 11, -1, 13, -1, 15, -1, - ); - - while dst.len() >= 32 { - let av1 = _mm256_loadu_si256(src.as_ptr() as *const _); - let av2 = _mm256_loadu_si256(src[32..].as_ptr() as *const _); - - let mut a1 = _mm256_shuffle_epi8(av1, mask_a); - let mut b1 = _mm256_shuffle_epi8(av1, mask_b); - let mut a2 = _mm256_shuffle_epi8(av2, mask_a); - let mut b2 = _mm256_shuffle_epi8(av2, mask_b); - - a1 = unhex_avx2(a1); - a2 = unhex_avx2(a2); - b1 = unhex_avx2(b1); - b2 = unhex_avx2(b2); - - let bytes = nib2byte_avx2(a1, b1, a2, b2); - - //dst does not need to be aligned on any particular boundary - _mm256_storeu_si256(dst.as_mut_ptr() as *mut _, bytes); - dst = &mut dst[32..]; - src = &src[64..]; - } - hex_decode_fallback(&src, &mut dst) -} + impl IsValid for Checked { + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn is_valid(hi_nibbles: __m256i, low_nibbles: __m256i) -> bool { + let mask_lut = _mm256_setr_epi8( + 0b0000_1000, // 0 + 0b0101_1000, // 1 .. 6 + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0000_1000, // 7 .. 9 + 0b0000_1000, // + 0b0000_1000, // + 0b0000_0000, // 10 .. 15 + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + // + 0b0000_1000, // 0 + 0b0101_1000, // 1 .. 6 + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0000_1000, // 7 .. 9 + 0b0000_1000, // + 0b0000_1000, // + 0b0000_0000, // 10 .. 15 + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + ); + + #[allow(overflowing_literals)] + let bit_pos_lut = _mm256_setr_epi8( + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ); + + let m = _mm256_shuffle_epi8(mask_lut, low_nibbles); + let bit = _mm256_shuffle_epi8(bit_pos_lut, hi_nibbles); + let non_match = _mm256_cmpeq_epi8(_mm256_and_si256(m, bit), _mm256_setzero_si256()); + _mm256_movemask_epi8(non_match) == 0 + } + } -pub fn hex_decode_fallback(src: &[u8], dst: &mut [u8]) { - for (slot, bytes) in dst.iter_mut().zip(src.chunks(2)) { - let a = unhex_a(bytes[0] as usize); - let b = unhex_b(bytes[1] as usize); - *slot = a | b; + impl IsValid for Unchecked { + #[inline] + #[target_feature(enable = "avx2")] + unsafe fn is_valid(_: __m256i, _: __m256i) -> bool { + true + } + } } -} -#[cfg(test)] -mod tests { - use crate::decode::hex_check_fallback; - use crate::decode::hex_check_sse; - use crate::decode::hex_decode_fallback; - use crate::encode::hex_string; - use proptest::{proptest, proptest_helper}; + #[cfg(all(feature = "sse41", any(target_arch = "x86", target_arch = "x86_64")))] + pub mod sse41 { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use std::arch::x86_64::*; - fn _test_decode_fallback(s: &String) { - let len = s.as_bytes().len(); - let mut dst = Vec::with_capacity(len); - dst.resize(len, 0); + use crate::decode::{Checked, Error, Unchecked}; - let hex_string = hex_string(s.as_bytes()).unwrap(); + #[target_feature(enable = "sse4.1")] + pub unsafe fn decode(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { + _decode::(src, dst).map_err(|_| Error::InvalidChar) + } - hex_decode_fallback(hex_string.as_bytes(), &mut dst); + #[target_feature(enable = "sse4.1")] + pub unsafe fn decode_unchecked(src: &[u8], dst: &mut [u8]) { + let _ = _decode::(src, dst); + } - assert_eq!(&dst[..], s.as_bytes()); - } + #[inline] + #[target_feature(enable = "sse4.1")] + pub unsafe fn _decode(mut src: &[u8], mut dst: &mut [u8]) -> Result<(), ()> { + while src.len() >= 32 { + let av1 = _mm_loadu_si128(src.as_ptr() as *const _); + let av2 = _mm_loadu_si128(src[16..].as_ptr() as *const _); + let av1 = decode_chunk::(av1)?; + let av1 = _mm_shuffle_epi8( + av1, + _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, -1, -1, -1, -1, -1, -1, -1, -1), + ); + let av2 = decode_chunk::(av2)?; + let av2 = _mm_shuffle_epi8( + av2, + _mm_setr_epi8(-1, -1, -1, -1, -1, -1, -1, -1, 0, 2, 4, 6, 8, 10, 12, 14), + ); + let decoded = _mm_or_si128(av1, av2); + _mm_storeu_si128(dst.as_mut_ptr() as *mut _, decoded); + dst = &mut dst[16..]; + src = &src[32..]; + } + crate::decode::arch::fallback::_decode::(&src, &mut dst) + } + + #[inline] + #[target_feature(enable = "sse4.1")] + unsafe fn decode_chunk(input: __m128i) -> Result<__m128i, ()> { + #[allow(overflowing_literals)] + let hi_nibbles = _mm_and_si128(_mm_srli_epi32(input, 4), _mm_set1_epi8(0b00001111)); + let low_nibbles = _mm_and_si128(input, _mm_set1_epi8(0b00001111)); + + if !::is_valid(hi_nibbles, low_nibbles) { + return Err(()); + } + + let shift_lut = _mm_setr_epi8(0, 0, 0, -48, -55, 0, -87, 0, 0, 0, 0, 0, 0, 0, 0, 0); - proptest! { - #[test] - fn test_decode_fallback(ref s in ".+") { - _test_decode_fallback(s); + let sh = _mm_shuffle_epi8(shift_lut, hi_nibbles); + let input = _mm_add_epi8(input, sh); + #[allow(overflowing_literals)] + let input = + _mm_maddubs_epi16(input, _mm_set1_epi32(0b00000001_00010000_00000001_00010000)); + Ok(input) } - } - fn _test_check_fallback_true(s: &String) { - assert!(hex_check_fallback(s.as_bytes())); - } + pub trait IsValid: crate::decode::arch::fallback::IsValid { + unsafe fn is_valid(hi_nibbles: __m128i, low_nibbles: __m128i) -> bool; + } - proptest! { - #[test] - fn test_check_fallback_true(ref s in "[0-9a-fA-F]+") { - _test_check_fallback_true(s); + impl IsValid for Checked { + #[inline] + #[target_feature(enable = "sse4.1")] + unsafe fn is_valid(hi_nibbles: __m128i, low_nibbles: __m128i) -> bool { + let mask_lut = _mm_setr_epi8( + 0b0000_1000, // 0 + 0b0101_1000, // 1 .. 6 + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0101_1000, // + 0b0000_1000, // 7 .. 9 + 0b0000_1000, // + 0b0000_1000, // + 0b0000_0000, // 10 .. 15 + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + 0b0000_0000, // + ); + + #[allow(overflowing_literals)] + let bit_pos_lut = _mm_setr_epi8( + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, + ); + + let m = _mm_shuffle_epi8(mask_lut, low_nibbles); + let bit = _mm_shuffle_epi8(bit_pos_lut, hi_nibbles); + let non_match = _mm_cmpeq_epi8(_mm_and_si128(m, bit), _mm_setzero_si128()); + _mm_movemask_epi8(non_match) == 0 + } } - } - fn _test_check_fallback_false(s: &String) { - assert!(!hex_check_fallback(s.as_bytes())); + impl IsValid for Unchecked { + #[inline] + #[target_feature(enable = "sse4.1")] + unsafe fn is_valid(_: __m128i, _: __m128i) -> bool { + true + } + } } - proptest! { - #[test] - fn test_check_fallback_false(ref s in ".{16}[^0-9a-fA-F]+") { - _test_check_fallback_false(s); + pub mod fallback { + use crate::decode::{Checked, Error, Unchecked}; + + #[inline] + pub fn decode(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { + _decode::(src, dst).map_err(|_| Error::InvalidChar) } - } - fn _test_check_sse_true(s: &String) { - assert!(unsafe { hex_check_sse(s.as_bytes()) }); - } + #[inline] + pub fn decode_unchecked(src: &[u8], dst: &mut [u8]) { + let _ = _decode::(src, dst); + } - proptest! { - #[test] - fn test_check_sse_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") { - _test_check_sse_true(s); + #[inline] + pub fn _decode(src: &[u8], dst: &mut [u8]) -> Result<(), ()> { + for (slot, bytes) in dst.iter_mut().zip(src.chunks(2)) { + if !V::is_valid(bytes[0], bytes[1]) { + return Err(()); + } + let a = unhex_a(bytes[0]); + let b = unhex_b(bytes[1]); + *slot = a | b; + } + Ok(()) } - } - fn _test_check_sse_false(s: &String) { - assert!(!unsafe { hex_check_sse(s.as_bytes()) }); - } + pub trait IsValid { + fn is_valid(a: u8, b: u8) -> bool; + } - proptest! { - #[test] - fn test_check_sse_false(ref s in ".{16}[^0-9a-fA-F]+") { - _test_check_sse_false(s); + impl IsValid for Checked { + #[inline] + fn is_valid(a: u8, b: u8) -> bool { + (unhex_a(a) | unhex_a(b)) != 0xff + } + } + + impl IsValid for Unchecked { + #[inline] + fn is_valid(_: u8, _: u8) -> bool { + return true; + } + } + + // lower nibble + #[inline] + fn unhex_b(x: u8) -> u8 { + // ASCII -> hex + static UNHEX: [u8; 256] = [ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 10, 11, 12, 13, 14, 15, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ]; + UNHEX[x as usize] + } + + // upper nibble, logically equivalent to unhex_b(x) << 4 + #[inline] + fn unhex_a(x: u8) -> u8 { + // ASCII -> hex << 4 + static UNHEX4: [u8; 256] = [ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 160, 176, 192, 208, 224, 240, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 160, 176, 192, 208, 224, 240, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + ]; + UNHEX4[x as usize] + } + + #[cfg(test)] + mod tests { + use super::*; + use proptest::{proptest, proptest_helper}; + + fn _test_decode(s: &String) { + let len = s.as_bytes().len(); + let mut dst = Vec::with_capacity(len); + dst.resize(len, 0); + + let hex_string = crate::encode(s.as_bytes()); + + decode(hex_string.as_bytes(), &mut dst).unwrap(); + + assert_eq!(&dst[..], s.as_bytes()); + } + + proptest! { + #[test] + fn test_decode(ref s in ".+") { + _test_decode(s); + } + } } } } diff --git a/src/encode.rs b/src/encode.rs index 345b59a..d240277 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -1,20 +1,25 @@ #![allow(clippy::cast_ptr_alignment)] -#[cfg(target_arch = "x86")] -use std::arch::x86::*; -#[cfg(target_arch = "x86_64")] -use std::arch::x86_64::*; - use crate::error::Error; static TABLE: &[u8] = b"0123456789abcdef"; -pub fn hex_string(src: &[u8]) -> Result { +pub fn encode(src: &I) -> String +where + I: AsRef<[u8]> + ?Sized, +{ + let src = src.as_ref(); let mut buffer = vec![0; src.len() * 2]; - hex_encode(src, &mut buffer).map(|_| unsafe { String::from_utf8_unchecked(buffer) }) + // should never panic because the destination buffer is large enough. + encode_to_slice(src, &mut buffer).unwrap(); + unsafe { String::from_utf8_unchecked(buffer) } } -pub fn hex_encode(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { +pub fn encode_to_slice(src: &I, dst: &mut [u8]) -> Result<(), Error> +where + I: AsRef<[u8]> + ?Sized, +{ + let src = src.as_ref(); let len = src.len().checked_mul(2).unwrap(); if dst.len() < len { return Err(Error::InvalidLength(len)); @@ -22,105 +27,123 @@ pub fn hex_encode(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - if is_x86_feature_detected!("avx2") { - unsafe { hex_encode_avx2(src, dst) }; - return Ok(()); + #[cfg(feature = "avx2")] + { + if is_x86_feature_detected!("avx2") && src.len() >= 16 { + unsafe { avx2::encode(src, dst) }; + return Ok(()); + } } - if is_x86_feature_detected!("sse4.1") { - unsafe { hex_encode_sse41(src, dst) }; - return Ok(()); + #[cfg(feature = "sse41")] + { + if is_x86_feature_detected!("sse4.1") && src.len() >= 16 { + unsafe { sse41::encode(src, dst) }; + return Ok(()); + } } } - hex_encode_fallback(src, dst); + encode_fallback(src, dst); Ok(()) } -#[deprecated(since = "0.3.0", note = "please use `hex_encode` instead")] -pub fn hex_to(src: &[u8], dst: &mut [u8]) -> Result<(), Error> { - hex_encode(src, dst) -} - -#[target_feature(enable = "avx2")] -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -unsafe fn hex_encode_avx2(mut src: &[u8], dst: &mut [u8]) { - let ascii_zero = _mm256_set1_epi8(b'0' as i8); - let nines = _mm256_set1_epi8(9); - let ascii_a = _mm256_set1_epi8((b'a' - 9 - 1) as i8); - let and4bits = _mm256_set1_epi8(0xf); - - let mut i = 0_isize; - while src.len() >= 32 { - // https://stackoverflow.com/questions/47425851/whats-the-difference-between-mm256-lddqu-si256-and-mm256-loadu-si256 - let invec = _mm256_loadu_si256(src.as_ptr() as *const _); - - let masked1 = _mm256_and_si256(invec, and4bits); - let masked2 = _mm256_and_si256(_mm256_srli_epi64(invec, 4), and4bits); - - // return 0xff corresponding to the elements > 9, or 0x00 otherwise - let cmpmask1 = _mm256_cmpgt_epi8(masked1, nines); - let cmpmask2 = _mm256_cmpgt_epi8(masked2, nines); - - // add '0' or the offset depending on the masks - let masked1 = _mm256_add_epi8(masked1, _mm256_blendv_epi8(ascii_zero, ascii_a, cmpmask1)); - let masked2 = _mm256_add_epi8(masked2, _mm256_blendv_epi8(ascii_zero, ascii_a, cmpmask2)); - - // interleave masked1 and masked2 bytes - let res1 = _mm256_unpacklo_epi8(masked2, masked1); - let res2 = _mm256_unpackhi_epi8(masked2, masked1); - - // Store everything into the right destination now - let base = dst.as_mut_ptr().offset(i * 2); - let base1 = base.offset(0) as *mut _; - let base2 = base.offset(16) as *mut _; - let base3 = base.offset(32) as *mut _; - let base4 = base.offset(48) as *mut _; - _mm256_storeu2_m128i(base3, base1, res1); - _mm256_storeu2_m128i(base4, base2, res2); - src = &src[32..]; - i += 32; +#[cfg(all(feature = "avx2", any(target_arch = "x86", target_arch = "x86_64")))] +mod avx2 { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use std::arch::x86_64::*; + + #[target_feature(enable = "avx2")] + pub(super) unsafe fn encode(mut src: &[u8], mut dst: &mut [u8]) { + while src.len() >= 32 { + let input = _mm256_loadu_si256(src.as_ptr() as *const _); + _mm256_storeu_si256( + dst.as_mut_ptr() as *mut _, + encode_chunk(_mm256_castsi256_si128(input)), + ); + _mm256_storeu_si256( + dst.as_mut_ptr().offset(32) as *mut _, + encode_chunk(_mm256_extracti128_si256(input, 1)), + ); + src = &src[32..]; + dst = &mut dst[64..]; + } + if src.len() >= 16 { + let chunk = _mm_loadu_si128(src.as_ptr() as *const _); + _mm256_storeu_si256(dst.as_mut_ptr() as *mut _, encode_chunk(chunk)); + src = &src[16..]; + dst = &mut dst[32..]; + } + super::encode_fallback(src, dst); } - let i = i as usize; - hex_encode_sse41(src, &mut dst[i * 2..]); + #[target_feature(enable = "avx2")] + unsafe fn encode_chunk(input: __m128i) -> __m256i { + let hi = _mm_shuffle_epi8( + input, + _mm_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7), + ); + let lo = _mm_shuffle_epi8( + input, + _mm_setr_epi8(8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15), + ); + let joined = _mm256_set_m128i(lo, hi); + let shifted = _mm256_srlv_epi64(joined, _mm256_setr_epi64x(4, 0, 4, 0)); + let masked = _mm256_and_si256(shifted, _mm256_set1_epi8(0xf)); + let shuffled = _mm256_shuffle_epi8( + masked, + _mm256_setr_epi8( + 0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15, 0, 8, 1, 9, 2, 10, 3, 11, 4, + 12, 5, 13, 6, 14, 7, 15, + ), + ); + let offset_lut = _mm256_setr_epi8( + 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 87, 87, 87, 87, 87, 87, 48, 48, 48, 48, 48, 48, + 48, 48, 48, 48, 87, 87, 87, 87, 87, 87, + ); + let offsets = _mm256_shuffle_epi8(offset_lut, shuffled); + _mm256_add_epi8(shuffled, offsets) + } } -// copied from https://github.com/Matherunner/bin2hex-sse/blob/master/base16_sse4.cpp -#[target_feature(enable = "sse4.1")] -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -unsafe fn hex_encode_sse41(mut src: &[u8], dst: &mut [u8]) { - let ascii_zero = _mm_set1_epi8(b'0' as i8); - let nines = _mm_set1_epi8(9); - let ascii_a = _mm_set1_epi8((b'a' - 9 - 1) as i8); - let and4bits = _mm_set1_epi8(0xf); - - let mut i = 0_isize; - while src.len() >= 16 { - let invec = _mm_loadu_si128(src.as_ptr() as *const _); - - let masked1 = _mm_and_si128(invec, and4bits); - let masked2 = _mm_and_si128(_mm_srli_epi64(invec, 4), and4bits); - - // return 0xff corresponding to the elements > 9, or 0x00 otherwise - let cmpmask1 = _mm_cmpgt_epi8(masked1, nines); - let cmpmask2 = _mm_cmpgt_epi8(masked2, nines); - - // add '0' or the offset depending on the masks - let masked1 = _mm_add_epi8(masked1, _mm_blendv_epi8(ascii_zero, ascii_a, cmpmask1)); - let masked2 = _mm_add_epi8(masked2, _mm_blendv_epi8(ascii_zero, ascii_a, cmpmask2)); - - // interleave masked1 and masked2 bytes - let res1 = _mm_unpacklo_epi8(masked2, masked1); - let res2 = _mm_unpackhi_epi8(masked2, masked1); - - _mm_storeu_si128(dst.as_mut_ptr().offset(i * 2) as *mut _, res1); - _mm_storeu_si128(dst.as_mut_ptr().offset(i * 2 + 16) as *mut _, res2); - src = &src[16..]; - i += 16; - } +#[cfg(all(feature = "sse41", any(target_arch = "x86", target_arch = "x86_64")))] +mod sse41 { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use std::arch::x86_64::*; + + #[target_feature(enable = "sse4.1")] + pub(super) unsafe fn encode(mut src: &[u8], mut dst: &mut [u8]) { + let and4bits = _mm_set1_epi8(0xf); - let i = i as usize; - hex_encode_fallback(src, &mut dst[i * 2..]); + while src.len() >= 16 { + let invec = _mm_loadu_si128(src.as_ptr() as *const _); + + let masked1 = _mm_and_si128(invec, and4bits); + let masked2 = _mm_and_si128(_mm_srli_epi64(invec, 4), and4bits); + + let offset_lut = _mm_setr_epi8( + 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 87, 87, 87, 87, 87, 87, + ); + let offsets1 = _mm_shuffle_epi8(offset_lut, masked1); + let offsets2 = _mm_shuffle_epi8(offset_lut, masked2); + + let masked1 = _mm_add_epi8(masked1, offsets1); + let masked2 = _mm_add_epi8(masked2, offsets2); + + // interleave masked1 and masked2 bytes + let res1 = _mm_unpacklo_epi8(masked2, masked1); + let res2 = _mm_unpackhi_epi8(masked2, masked1); + + _mm_storeu_si128(dst.as_mut_ptr() as *mut _, res1); + _mm_storeu_si128(dst.as_mut_ptr().offset(16) as *mut _, res2); + src = &src[16..]; + dst = &mut dst[32..]; + } + super::encode_fallback(src, dst); + } } #[inline] @@ -128,7 +151,7 @@ fn hex(byte: u8) -> u8 { TABLE[byte as usize] } -pub fn hex_encode_fallback(src: &[u8], dst: &mut [u8]) { +pub fn encode_fallback(src: &[u8], dst: &mut [u8]) { for (byte, slots) in src.iter().zip(dst.chunks_mut(2)) { slots[0] = hex((*byte >> 4) & 0xf); slots[1] = hex(*byte & 0xf); @@ -137,13 +160,13 @@ pub fn hex_encode_fallback(src: &[u8], dst: &mut [u8]) { #[cfg(test)] mod tests { - use crate::encode::hex_encode_fallback; + use crate::encode::encode_fallback; use proptest::{proptest, proptest_helper}; use std::str; fn _test_encode_fallback(s: &String) { let mut buffer = vec![0; s.as_bytes().len() * 2]; - hex_encode_fallback(s.as_bytes(), &mut buffer); + encode_fallback(s.as_bytes(), &mut buffer); let encode = unsafe { str::from_utf8_unchecked(&buffer[..s.as_bytes().len() * 2]) }; assert_eq!(encode, hex::encode(s)); } diff --git a/src/lib.rs b/src/lib.rs index 3c4e825..89659c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,77 +1,81 @@ mod decode; mod encode; mod error; -pub use crate::decode::{ - hex_check_fallback, hex_decode, hex_decode_fallback, hex_decode_unchecked, -}; -pub use crate::encode::{hex_encode, hex_encode_fallback, hex_string, hex_to}; +pub use crate::decode::{decode, decode_to_slice, decode_to_slice_unchecked}; +pub use crate::encode::{encode, encode_to_slice}; pub use crate::error::Error; -#[cfg(any(target_arch = "x86", target_arch = "x86_64", feature = "sse4.1"))] -pub use crate::decode::hex_check_sse; +#[cfg(feature = "bench")] +pub use crate::{ + decode::arch::fallback::{ + decode as decode_fallback, decode_unchecked as decode_unchecked_fallback, + }, + encode::encode_fallback, +}; #[cfg(test)] mod tests { - use crate::decode::hex_decode; - use crate::encode::{hex_encode, hex_string}; + use crate::decode::decode; + use crate::encode::{encode, encode_to_slice}; use proptest::{proptest, proptest_helper}; use std::str; - fn _test_hex_encode(s: &String) { + fn _test_encode(s: &String) { let mut buffer = vec![0; s.as_bytes().len() * 2]; - hex_encode(s.as_bytes(), &mut buffer).unwrap(); - let encode = unsafe { str::from_utf8_unchecked(&buffer[..s.as_bytes().len() * 2]) }; + encode_to_slice(s.as_bytes(), &mut buffer).unwrap(); + let encoded = unsafe { str::from_utf8_unchecked(&buffer[..s.as_bytes().len() * 2]) }; - let hex_string = hex_string(s.as_bytes()).unwrap(); + let hex_string = encode(s); - assert_eq!(encode, hex::encode(s)); + assert_eq!(encoded, hex::encode(s)); assert_eq!(hex_string, hex::encode(s)); } proptest! { #[test] - fn test_hex_encode(ref s in ".*") { - _test_hex_encode(s); + fn test_encode(ref s in ".*") { + _test_encode(s); } } - fn _test_hex_decode(s: &String) { - let len = s.as_bytes().len(); - let mut dst = Vec::with_capacity(len); - dst.resize(len, 0); - - let hex_string = hex_string(s.as_bytes()).unwrap(); - - hex_decode(hex_string.as_bytes(), &mut dst).unwrap(); - - assert_eq!(&dst[..], s.as_bytes()); + fn _test_decode_check(s: &String, ok: bool) { + assert!(decode(s).is_ok() == ok); } proptest! { #[test] - fn test_hex_decode(ref s in ".+") { - _test_hex_decode(s); + fn test_decode_check(ref s in "([0-9a-fA-F][0-9a-fA-F])+") { + _test_decode_check(s, true); } } - fn _test_hex_decode_check(s: &String, ok: bool) { - let len = s.as_bytes().len(); - let mut dst = Vec::with_capacity(len / 2); - dst.resize(len / 2, 0); - assert!(hex_decode(s.as_bytes(), &mut dst).is_ok() == ok); - } - proptest! { #[test] - fn test_hex_decode_check(ref s in "([0-9a-fA-F][0-9a-fA-F])+") { - _test_hex_decode_check(s, true); + fn test_decode_check_odd(ref s in "[0-9a-fA-F]{11}") { + _test_decode_check(s, false); } } proptest! { #[test] - fn test_hex_decode_check_odd(ref s in "[0-9a-fA-F]{11}") { - _test_hex_decode_check(s, false); + fn test_roundtrip(input: Vec) { + let encoded = encode(&input); + let decoded = decode(&encoded).unwrap(); + assert_eq!(&decoded, &input); + } + + #[test] + fn test_encode_matches(input: Vec) { + let encoded = encode(&input); + let expected = hex::encode(&input); + assert_eq!(encoded, expected); + } + + #[test] + fn test_decode_matches(input: Vec) { + let decoded = decode(&input).map_err(|_| ()); + let expected = hex::decode(&input).map_err(|_| ()); + assert_eq!(decoded, expected); } } }