diff --git a/src/decode.rs b/src/decode.rs index 0e2af54..86c5bd8 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,5 +1,7 @@ // avx2 decode modified from https://github.com/zbjornson/fast-hex/blob/master/src/hex.cc +#[cfg(target_arch = "aarch64")] +use core::arch::aarch64::*; #[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86_64")] @@ -113,7 +115,15 @@ pub fn hex_check_with_case(src: &[u8], check_case: CheckCase) -> bool { } } - #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + #[cfg(target_arch = "aarch64")] + { + match crate::vectorization_support() { + crate::Vectorization::Neon => unsafe { hex_check_neon_with_case(src, check_case) }, + crate::Vectorization::None => hex_check_fallback_with_case(src, check_case), + } + } + + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))] hex_check_fallback_with_case(src, check_case) } @@ -210,6 +220,72 @@ pub unsafe fn hex_check_sse_with_case(mut src: &[u8], check_case: CheckCase) -> hex_check_fallback_with_case(src, check_case) } +#[target_feature(enable = "neon")] +#[cfg(target_arch = "aarch64")] +pub unsafe fn hex_check_neon(src: &[u8]) -> bool { + hex_check_neon_with_case(src, CheckCase::None) +} + +#[target_feature(enable = "neon")] +#[cfg(target_arch = "aarch64")] +pub unsafe fn hex_check_neon_with_case(mut src: &[u8], check_case: CheckCase) -> bool { + let ascii_zero = vdupq_n_u8(b'0' - 1); + let ascii_nine = vdupq_n_u8(b'9' + 1); + let ascii_ua = vdupq_n_u8(b'A' - 1); + let ascii_uf = vdupq_n_u8(b'F' + 1); + let ascii_la = vdupq_n_u8(b'a' - 1); + let ascii_lf = vdupq_n_u8(b'f' + 1); + + while src.len() >= 16 { + let unchecked = vld1q_u8(src.as_ptr() as *const _); + + let gt0 = vcgtq_u8(unchecked, ascii_zero); + let lt9 = vcltq_u8(unchecked, ascii_nine); + let valid_digit = vandq_u8(gt0, lt9); + + let (valid_la_lf, valid_ua_uf) = match check_case { + CheckCase::None => { + let gtua = vcgtq_u8(unchecked, ascii_ua); + let ltuf = vcltq_u8(unchecked, ascii_uf); + + let gtla = vcgtq_u8(unchecked, ascii_la); + let ltlf = vcltq_u8(unchecked, ascii_lf); + + (Some(vandq_u8(gtla, ltlf)), Some(vandq_u8(gtua, ltuf))) + } + CheckCase::Lower => { + let gtla = vcgtq_u8(unchecked, ascii_la); + let ltlf = vcltq_u8(unchecked, ascii_lf); + + (Some(vandq_u8(gtla, ltlf)), None) + } + CheckCase::Upper => { + let gtua = vcgtq_u8(unchecked, ascii_ua); + let ltuf = vcltq_u8(unchecked, ascii_uf); + + (None, Some(vandq_u8(gtua, ltuf))) + } + }; + + let valid_letter = match (valid_la_lf, valid_ua_uf) { + (Some(valid_lower), Some(valid_upper)) => vorrq_u8(valid_lower, valid_upper), + (Some(valid_lower), None) => valid_lower, + (None, Some(valid_upper)) => valid_upper, + _ => unreachable!(), + }; + + let ret = vminvq_u8(vorrq_u8(valid_digit, valid_letter)); + + if ret == 0 { + return false; + } + + src = &src[16..]; + } + + hex_check_fallback_with_case(src, check_case) +} + /// Hex decode src into dst. /// The length of src must be even, and it's allowed to decode a zero length src. /// The length of dst must be at least src.len() / 2. @@ -454,15 +530,25 @@ mod tests { } } -#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))] -mod test_sse { +#[cfg(all( + test, + any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64") +))] +mod test_simd { use crate::decode::{ - hex_check, hex_check_fallback, hex_check_fallback_with_case, hex_check_sse, - hex_check_sse_with_case, hex_check_with_case, hex_decode, hex_decode_unchecked, - hex_decode_with_case, CheckCase, + hex_check, hex_check_fallback, hex_check_fallback_with_case, hex_check_with_case, + hex_decode, hex_decode_unchecked, hex_decode_with_case, CheckCase, }; + #[cfg(target_arch = "aarch64")] + use crate::decode::{hex_check_neon, hex_check_neon_with_case}; + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + use crate::decode::{hex_check_sse, hex_check_sse_with_case}; + #[cfg(target_arch = "aarch64")] + use std::arch::is_aarch64_feature_detected; + use proptest::proptest; + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] fn _test_check_sse_with_case(s: &String, check_case: CheckCase, expect_result: bool) { if is_x86_feature_detected!("sse4.1") { assert_eq!( @@ -472,12 +558,14 @@ mod test_sse { } } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] fn _test_check_sse_true(s: &String) { if is_x86_feature_detected!("sse4.1") { assert!(unsafe { hex_check_sse(s.as_bytes()) }); } } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] proptest! { #[test] fn test_check_sse_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") { @@ -504,12 +592,13 @@ mod test_sse { } } + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] fn _test_check_sse_false(s: &String) { if is_x86_feature_detected!("sse4.1") { assert!(!unsafe { hex_check_sse(s.as_bytes()) }); } } - + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] proptest! { #[test] fn test_check_sse_false(ref s in ".{16}[^0-9a-fA-F]+") { @@ -520,6 +609,67 @@ mod test_sse { } } + #[cfg(target_arch = "aarch64")] + fn _test_check_neon_with_case(s: &String, check_case: CheckCase, expect_result: bool) { + if is_aarch64_feature_detected!("neon") { + assert_eq!( + unsafe { hex_check_neon_with_case(s.as_bytes(), check_case) }, + expect_result + ) + } + } + + #[cfg(target_arch = "aarch64")] + fn _test_check_neon_true(s: &String) { + if is_aarch64_feature_detected!("neon") { + assert!(unsafe { hex_check_neon(s.as_bytes()) }); + } + } + + #[cfg(target_arch = "aarch64")] + proptest! { + #[test] + fn test_check_neon_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") { + _test_check_neon_true(s); + _test_check_neon_with_case(s, CheckCase::None, true); + match (s.contains(char::is_lowercase), s.contains(char::is_uppercase)){ + (true, true) => { + _test_check_neon_with_case(s, CheckCase::Lower, false); + _test_check_neon_with_case(s, CheckCase::Upper, false); + }, + (true, false) => { + _test_check_neon_with_case(s, CheckCase::Lower, true); + _test_check_neon_with_case(s, CheckCase::Upper, false); + }, + (false, true) => { + _test_check_neon_with_case(s, CheckCase::Lower, false); + _test_check_neon_with_case(s, CheckCase::Upper, true); + }, + (false, false) => { + _test_check_neon_with_case(s, CheckCase::Lower, true); + _test_check_neon_with_case(s, CheckCase::Upper, true); + } + } + } + } + + #[cfg(target_arch = "aarch64")] + fn _test_check_neon_false(s: &String) { + if is_aarch64_feature_detected!("neon") { + assert!(!unsafe { hex_check_neon(s.as_bytes()) }); + } + } + #[cfg(target_arch = "aarch64")] + proptest! { + #[test] + fn test_check_neon_false(ref s in ".{16}[^0-9a-fA-F]+") { + _test_check_neon_false(s); + _test_check_neon_with_case(s, CheckCase::None, false); + _test_check_neon_with_case(s, CheckCase::Lower, false); + _test_check_neon_with_case(s, CheckCase::Upper, false); + } + } + #[test] fn test_decode_zero_length_src_should_not_be_ok() { let src = b""; @@ -535,11 +685,18 @@ mod test_sse { assert!(hex_check_fallback(src)); assert!(hex_check_fallback_with_case(src, CheckCase::None)); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] if is_x86_feature_detected!("sse4.1") { assert!(unsafe { hex_check_sse_with_case(src, CheckCase::None) }); assert!(unsafe { hex_check_sse(src) }); } + #[cfg(target_arch = "aarch64")] + if is_aarch64_feature_detected!("neon") { + assert!(unsafe { hex_check_neon_with_case(src, CheckCase::None) }); + assert!(unsafe { hex_check_neon(src) }); + } + // this function have no return value, so we just execute it and expect no panic hex_decode_unchecked(src, &mut dst); } diff --git a/src/encode.rs b/src/encode.rs index 5095778..a5234a2 100644 --- a/src/encode.rs +++ b/src/encode.rs @@ -3,6 +3,9 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; +#[cfg(target_arch = "aarch64")] +use core::arch::aarch64::*; + #[cfg(feature = "alloc")] use alloc::{string::String, vec}; @@ -102,7 +105,16 @@ pub fn hex_encode_custom<'a>( // Safety: We just wrote valid utf8 hex string into the dst return Ok(unsafe { mut_str(dst) }); } - #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + #[cfg(target_arch = "aarch64")] + { + match crate::vectorization_support() { + crate::Vectorization::Neon => unsafe { hex_encode_neon(src, dst, upper_case) }, + crate::Vectorization::None => hex_encode_custom_case_fallback(src, dst, upper_case), + } + // Safety: We just wrote valid utf8 hex string into the dst + return Ok(unsafe { mut_str(dst) }); + } + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))] { hex_encode_custom_case_fallback(src, dst, upper_case); // Saftey: We just wrote valid utf8 hex string into the dst @@ -215,6 +227,49 @@ unsafe fn hex_encode_sse41(mut src: &[u8], dst: &mut [u8], upper_case: bool) { hex_encode_custom_case_fallback(src, &mut dst[i * 2..], upper_case); } +#[target_feature(enable = "neon")] +#[cfg(target_arch = "aarch64")] +unsafe fn hex_encode_neon(mut src: &[u8], dst: &mut [u8], upper_case: bool) { + let ascii_zero = vdupq_n_u8(b'0'); + let nines = vdupq_n_u8(9); + let ascii_a = if upper_case { + vdupq_n_u8(b'A' - 9 - 1) + } else { + vdupq_n_u8(b'a' - 9 - 1) + }; + let and4bits = vdupq_n_u8(0xf); + + let mut i = 0_isize; + + while src.len() >= 16 { + let invec = vld1q_u8(src.as_ptr() as *const _); + + let masked1 = vandq_u8(invec, and4bits); + let masked2 = vandq_u8(vshrq_n_u8::<4>(invec), and4bits); + + // return 0xff corresponding to the elements > 9, or 0x00 otherwise + let cmpmask1 = vcgtq_u8(masked1, nines); + let cmpmask2 = vcgtq_u8(masked2, nines); + + // add '0' or the offset depending on the masks + let masked1 = vaddq_u8(masked1, vbslq_u8(cmpmask1, ascii_a, ascii_zero)); + let masked2 = vaddq_u8(masked2, vbslq_u8(cmpmask2, ascii_a, ascii_zero)); + + // interleave masked1 and masked2 bytes + let res1 = vzip1q_u8(masked2, masked1); + let res2 = vzip2q_u8(masked2, masked1); + + vst1q_u8(dst.as_mut_ptr().offset(i * 2) as *mut _, res1); + vst1q_u8(dst.as_mut_ptr().offset(i * 2 + 16) as *mut _, res2); + + src = &src[16..]; + i += 16; + } + + let i = i as usize; + hex_encode_custom_case_fallback(src, &mut dst[i * 2..], upper_case); +} + #[inline] fn hex_lower(byte: u8) -> u8 { TABLE_LOWER[byte as usize] diff --git a/src/lib.rs b/src/lib.rs index 0dfbb83..badfd0e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -33,12 +33,19 @@ pub use crate::encode::hex_to; #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] pub use crate::decode::{hex_check_sse, hex_check_sse_with_case}; +#[cfg(target_arch = "aarch64")] +pub use crate::decode::{hex_check_neon, hex_check_neon_with_case}; + #[derive(Copy, Clone, PartialEq, Eq, Debug)] #[cfg_attr(feature = "defmt-03", derive(defmt::Format))] pub(crate) enum Vectorization { None = 0, + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] SSE41 = 1, + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] AVX2 = 2, + #[cfg(target_arch = "aarch64")] + Neon = 3, } #[inline(always)] @@ -68,6 +75,27 @@ pub(crate) fn vectorization_support() -> Vectorization { FLAGS.store(val as u8, Ordering::Relaxed); return val; } + + #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] + { + // reuse flag code from x86 impl + use core::sync::atomic::{AtomicU8, Ordering}; + static FLAGS: AtomicU8 = AtomicU8::new(u8::MAX); + + let current_flags = FLAGS.load(Ordering::Relaxed); + if current_flags != u8::MAX { + return match current_flags { + 0 => Vectorization::None, + 3 => Vectorization::Neon, + _ => unreachable!(), + }; + } + + let val = vectorization_support_no_cache_arm(); + FLAGS.store(val as u8, Ordering::Relaxed); + return val; + } + #[allow(unreachable_code)] Vectorization::None } @@ -128,6 +156,20 @@ unsafe fn avx2_support_no_cache_x86() -> bool { false } +#[cfg(target_arch = "aarch64")] +#[cold] +fn vectorization_support_no_cache_arm() -> Vectorization { + #[cfg(feature = "std")] + if std::arch::is_aarch64_feature_detected!("neon") { + return Vectorization::Neon; + } + #[cfg(target_feature = "neon")] + return Vectorization::Neon; + + #[allow(unreachable_code)] + Vectorization::None +} + #[cfg(test)] mod tests { use crate::decode::{hex_decode, hex_decode_with_case, CheckCase}; @@ -152,7 +194,16 @@ mod tests { ), } } - #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + + #[cfg(target_arch = "aarch64")] + match vector_support { + Vectorization::Neon => assert!(std::arch::is_aarch64_feature_detected!("neon")), + Vectorization::None => assert!( + !cfg!(target_feature = "neon") || !std::arch::is_aarch64_feature_detected!("neon") + ), + } + + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))] assert_eq!(vector_support, Vectorization::None); }