Skip to content

Commit

Permalink
Merge pull request #56 from Lynnesbian/neon
Browse files Browse the repository at this point in the history
Add NEON encode and check
  • Loading branch information
quake authored Sep 27, 2024
2 parents b528490 + b67c41f commit 4acf38e
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 9 deletions.
171 changes: 164 additions & 7 deletions src/decode.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// avx2 decode modified from https://github.com/zbjornson/fast-hex/blob/master/src/hex.cc

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
Expand Down Expand Up @@ -113,7 +115,15 @@ pub fn hex_check_with_case(src: &[u8], check_case: CheckCase) -> bool {
}
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg(target_arch = "aarch64")]
{
match crate::vectorization_support() {
crate::Vectorization::Neon => unsafe { hex_check_neon_with_case(src, check_case) },
crate::Vectorization::None => hex_check_fallback_with_case(src, check_case),
}
}

#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
hex_check_fallback_with_case(src, check_case)
}

Expand Down Expand Up @@ -210,6 +220,72 @@ pub unsafe fn hex_check_sse_with_case(mut src: &[u8], check_case: CheckCase) ->
hex_check_fallback_with_case(src, check_case)
}

#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn hex_check_neon(src: &[u8]) -> bool {
hex_check_neon_with_case(src, CheckCase::None)
}

#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
pub unsafe fn hex_check_neon_with_case(mut src: &[u8], check_case: CheckCase) -> bool {
let ascii_zero = vdupq_n_u8(b'0' - 1);
let ascii_nine = vdupq_n_u8(b'9' + 1);
let ascii_ua = vdupq_n_u8(b'A' - 1);
let ascii_uf = vdupq_n_u8(b'F' + 1);
let ascii_la = vdupq_n_u8(b'a' - 1);
let ascii_lf = vdupq_n_u8(b'f' + 1);

while src.len() >= 16 {
let unchecked = vld1q_u8(src.as_ptr() as *const _);

let gt0 = vcgtq_u8(unchecked, ascii_zero);
let lt9 = vcltq_u8(unchecked, ascii_nine);
let valid_digit = vandq_u8(gt0, lt9);

let (valid_la_lf, valid_ua_uf) = match check_case {
CheckCase::None => {
let gtua = vcgtq_u8(unchecked, ascii_ua);
let ltuf = vcltq_u8(unchecked, ascii_uf);

let gtla = vcgtq_u8(unchecked, ascii_la);
let ltlf = vcltq_u8(unchecked, ascii_lf);

(Some(vandq_u8(gtla, ltlf)), Some(vandq_u8(gtua, ltuf)))
}
CheckCase::Lower => {
let gtla = vcgtq_u8(unchecked, ascii_la);
let ltlf = vcltq_u8(unchecked, ascii_lf);

(Some(vandq_u8(gtla, ltlf)), None)
}
CheckCase::Upper => {
let gtua = vcgtq_u8(unchecked, ascii_ua);
let ltuf = vcltq_u8(unchecked, ascii_uf);

(None, Some(vandq_u8(gtua, ltuf)))
}
};

let valid_letter = match (valid_la_lf, valid_ua_uf) {
(Some(valid_lower), Some(valid_upper)) => vorrq_u8(valid_lower, valid_upper),
(Some(valid_lower), None) => valid_lower,
(None, Some(valid_upper)) => valid_upper,
_ => unreachable!(),
};

let ret = vminvq_u8(vorrq_u8(valid_digit, valid_letter));

if ret == 0 {
return false;
}

src = &src[16..];
}

hex_check_fallback_with_case(src, check_case)
}

/// Hex decode src into dst.
/// The length of src must be even, and it's allowed to decode a zero length src.
/// The length of dst must be at least src.len() / 2.
Expand Down Expand Up @@ -454,15 +530,25 @@ mod tests {
}
}

#[cfg(all(test, any(target_arch = "x86", target_arch = "x86_64")))]
mod test_sse {
#[cfg(all(
test,
any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")
))]
mod test_simd {
use crate::decode::{
hex_check, hex_check_fallback, hex_check_fallback_with_case, hex_check_sse,
hex_check_sse_with_case, hex_check_with_case, hex_decode, hex_decode_unchecked,
hex_decode_with_case, CheckCase,
hex_check, hex_check_fallback, hex_check_fallback_with_case, hex_check_with_case,
hex_decode, hex_decode_unchecked, hex_decode_with_case, CheckCase,
};
#[cfg(target_arch = "aarch64")]
use crate::decode::{hex_check_neon, hex_check_neon_with_case};
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
use crate::decode::{hex_check_sse, hex_check_sse_with_case};
#[cfg(target_arch = "aarch64")]
use std::arch::is_aarch64_feature_detected;

use proptest::proptest;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn _test_check_sse_with_case(s: &String, check_case: CheckCase, expect_result: bool) {
if is_x86_feature_detected!("sse4.1") {
assert_eq!(
Expand All @@ -472,12 +558,14 @@ mod test_sse {
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn _test_check_sse_true(s: &String) {
if is_x86_feature_detected!("sse4.1") {
assert!(unsafe { hex_check_sse(s.as_bytes()) });
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
proptest! {
#[test]
fn test_check_sse_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") {
Expand All @@ -504,12 +592,13 @@ mod test_sse {
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn _test_check_sse_false(s: &String) {
if is_x86_feature_detected!("sse4.1") {
assert!(!unsafe { hex_check_sse(s.as_bytes()) });
}
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
proptest! {
#[test]
fn test_check_sse_false(ref s in ".{16}[^0-9a-fA-F]+") {
Expand All @@ -520,6 +609,67 @@ mod test_sse {
}
}

#[cfg(target_arch = "aarch64")]
fn _test_check_neon_with_case(s: &String, check_case: CheckCase, expect_result: bool) {
if is_aarch64_feature_detected!("neon") {
assert_eq!(
unsafe { hex_check_neon_with_case(s.as_bytes(), check_case) },
expect_result
)
}
}

#[cfg(target_arch = "aarch64")]
fn _test_check_neon_true(s: &String) {
if is_aarch64_feature_detected!("neon") {
assert!(unsafe { hex_check_neon(s.as_bytes()) });
}
}

#[cfg(target_arch = "aarch64")]
proptest! {
#[test]
fn test_check_neon_true(ref s in "([0-9a-fA-F][0-9a-fA-F])+") {
_test_check_neon_true(s);
_test_check_neon_with_case(s, CheckCase::None, true);
match (s.contains(char::is_lowercase), s.contains(char::is_uppercase)){
(true, true) => {
_test_check_neon_with_case(s, CheckCase::Lower, false);
_test_check_neon_with_case(s, CheckCase::Upper, false);
},
(true, false) => {
_test_check_neon_with_case(s, CheckCase::Lower, true);
_test_check_neon_with_case(s, CheckCase::Upper, false);
},
(false, true) => {
_test_check_neon_with_case(s, CheckCase::Lower, false);
_test_check_neon_with_case(s, CheckCase::Upper, true);
},
(false, false) => {
_test_check_neon_with_case(s, CheckCase::Lower, true);
_test_check_neon_with_case(s, CheckCase::Upper, true);
}
}
}
}

#[cfg(target_arch = "aarch64")]
fn _test_check_neon_false(s: &String) {
if is_aarch64_feature_detected!("neon") {
assert!(!unsafe { hex_check_neon(s.as_bytes()) });
}
}
#[cfg(target_arch = "aarch64")]
proptest! {
#[test]
fn test_check_neon_false(ref s in ".{16}[^0-9a-fA-F]+") {
_test_check_neon_false(s);
_test_check_neon_with_case(s, CheckCase::None, false);
_test_check_neon_with_case(s, CheckCase::Lower, false);
_test_check_neon_with_case(s, CheckCase::Upper, false);
}
}

#[test]
fn test_decode_zero_length_src_should_not_be_ok() {
let src = b"";
Expand All @@ -535,11 +685,18 @@ mod test_sse {
assert!(hex_check_fallback(src));
assert!(hex_check_fallback_with_case(src, CheckCase::None));

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
if is_x86_feature_detected!("sse4.1") {
assert!(unsafe { hex_check_sse_with_case(src, CheckCase::None) });
assert!(unsafe { hex_check_sse(src) });
}

#[cfg(target_arch = "aarch64")]
if is_aarch64_feature_detected!("neon") {
assert!(unsafe { hex_check_neon_with_case(src, CheckCase::None) });
assert!(unsafe { hex_check_neon(src) });
}

// this function have no return value, so we just execute it and expect no panic
hex_decode_unchecked(src, &mut dst);
}
Expand Down
57 changes: 56 additions & 1 deletion src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*;

#[cfg(feature = "alloc")]
use alloc::{string::String, vec};

Expand Down Expand Up @@ -102,7 +105,16 @@ pub fn hex_encode_custom<'a>(
// Safety: We just wrote valid utf8 hex string into the dst
return Ok(unsafe { mut_str(dst) });
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg(target_arch = "aarch64")]
{
match crate::vectorization_support() {
crate::Vectorization::Neon => unsafe { hex_encode_neon(src, dst, upper_case) },
crate::Vectorization::None => hex_encode_custom_case_fallback(src, dst, upper_case),
}
// Safety: We just wrote valid utf8 hex string into the dst
return Ok(unsafe { mut_str(dst) });
}
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64")))]
{
hex_encode_custom_case_fallback(src, dst, upper_case);
// Saftey: We just wrote valid utf8 hex string into the dst
Expand Down Expand Up @@ -215,6 +227,49 @@ unsafe fn hex_encode_sse41(mut src: &[u8], dst: &mut [u8], upper_case: bool) {
hex_encode_custom_case_fallback(src, &mut dst[i * 2..], upper_case);
}

#[target_feature(enable = "neon")]
#[cfg(target_arch = "aarch64")]
unsafe fn hex_encode_neon(mut src: &[u8], dst: &mut [u8], upper_case: bool) {
let ascii_zero = vdupq_n_u8(b'0');
let nines = vdupq_n_u8(9);
let ascii_a = if upper_case {
vdupq_n_u8(b'A' - 9 - 1)
} else {
vdupq_n_u8(b'a' - 9 - 1)
};
let and4bits = vdupq_n_u8(0xf);

let mut i = 0_isize;

while src.len() >= 16 {
let invec = vld1q_u8(src.as_ptr() as *const _);

let masked1 = vandq_u8(invec, and4bits);
let masked2 = vandq_u8(vshrq_n_u8::<4>(invec), and4bits);

// return 0xff corresponding to the elements > 9, or 0x00 otherwise
let cmpmask1 = vcgtq_u8(masked1, nines);
let cmpmask2 = vcgtq_u8(masked2, nines);

// add '0' or the offset depending on the masks
let masked1 = vaddq_u8(masked1, vbslq_u8(cmpmask1, ascii_a, ascii_zero));
let masked2 = vaddq_u8(masked2, vbslq_u8(cmpmask2, ascii_a, ascii_zero));

// interleave masked1 and masked2 bytes
let res1 = vzip1q_u8(masked2, masked1);
let res2 = vzip2q_u8(masked2, masked1);

vst1q_u8(dst.as_mut_ptr().offset(i * 2) as *mut _, res1);
vst1q_u8(dst.as_mut_ptr().offset(i * 2 + 16) as *mut _, res2);

src = &src[16..];
i += 16;
}

let i = i as usize;
hex_encode_custom_case_fallback(src, &mut dst[i * 2..], upper_case);
}

#[inline]
fn hex_lower(byte: u8) -> u8 {
TABLE_LOWER[byte as usize]
Expand Down
Loading

0 comments on commit 4acf38e

Please sign in to comment.