diff --git a/src/string_decoder.rs b/src/string_decoder.rs
index 03284bf..ea7bb44 100644
--- a/src/string_decoder.rs
+++ b/src/string_decoder.rs
@@ -59,6 +59,131 @@ where
     let mut found_escape = false;
     let mut ascii_only = true;
 
+    // Fast path: scan `SIMD_STEP` bytes at a time for as long as the input is plain
+    // printable ASCII. A backslash, control character or non-ASCII byte leaves the block
+    // and lets the scalar loop below carry on from `index`.
+    #[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
+    'simd: {
+        #[cfg(target_arch = "aarch64")]
+        mod impl_ {
+            pub use std::arch::aarch64::{
+                vceqq_u8 as simd_eq, vdupq_n_u8 as simd_duplicate, vld1q_u8 as simd_load,
+                vorrq_u8 as simd_or,
+            };
+            use std::arch::aarch64::{uint8x16_t, vcgtq_u8, vcltq_u8, vmaxvq_u8};
+
+            /// Bytes consumed per vector load.
+            pub const SIMD_STEP: usize = 16;
+
+            /// Returns `true` if any lane of a comparison mask is set.
+            pub fn any_lane_set(mask: uint8x16_t) -> bool {
+                // SAFETY: only needs NEON, which this code assumes is available on aarch64.
+                unsafe { vmaxvq_u8(mask) != 0 }
+            }
+
+            /// Flags every lane holding a byte outside `0x20..=0x7f`.
+            ///
+            /// # Safety
+            /// Requires NEON.
+            pub unsafe fn simd_not_printable_ascii(chunk: uint8x16_t) -> uint8x16_t {
+                // The comparisons are unsigned, so bytes >= 0x80 need their own test.
+                unsafe {
+                    simd_or(
+                        vcltq_u8(chunk, simd_duplicate(0x20)),
+                        vcgtq_u8(chunk, simd_duplicate(0x7f)),
+                    )
+                }
+            }
+        }
+
+        #[cfg(target_arch = "x86_64")]
+        mod impl_ {
+            pub use std::arch::x86_64::{_mm_cmpeq_epi8 as simd_eq, _mm_or_si128 as simd_or};
+            use std::arch::x86_64::{
+                __m128i, _mm_cmplt_epi8, _mm_loadu_si128, _mm_movemask_epi8, _mm_set1_epi8,
+            };
+
+            /// Bytes consumed per vector load.
+            pub const SIMD_STEP: usize = 16;
+
+            /// Returns `true` if any lane of a comparison mask is set.
+            pub fn any_lane_set(mask: __m128i) -> bool {
+                // SAFETY: SSE2 only; callers run after the `sse2` feature check below.
+                // `_mm_movemask_epi8` replaces the SSE4.1-only `_mm_testz_si128`.
+                unsafe { _mm_movemask_epi8(mask) != 0 }
+            }
+
+            /// Broadcasts `val` into every lane.
+            ///
+            /// # Safety
+            /// Requires SSE2.
+            pub unsafe fn simd_duplicate(val: u8) -> __m128i {
+                unsafe { _mm_set1_epi8(i8::from_ne_bytes([val])) }
+            }
+
+            /// Loads `SIMD_STEP` bytes from `ptr`; no alignment is required.
+            ///
+            /// # Safety
+            /// Requires SSE2, and `ptr` must be valid for reading `SIMD_STEP` bytes.
+            pub unsafe fn simd_load(ptr: *const u8) -> __m128i {
+                unsafe { _mm_loadu_si128(ptr.cast::<__m128i>()) }
+            }
+
+            /// Flags every lane holding a byte outside `0x20..=0x7f`.
+            ///
+            /// # Safety
+            /// Requires SSE2.
+            pub unsafe fn simd_not_printable_ascii(chunk: __m128i) -> __m128i {
+                // Signed "less than": bytes >= 0x80 are negative, so one compare flags
+                // both control characters and non-ASCII bytes.
+                unsafe { _mm_cmplt_epi8(chunk, _mm_set1_epi8(0x20)) }
+            }
+        }
+
+        use impl_::*;
+
+        // Every intrinsic used above is SSE2, so that is the feature to test for.
+        #[cfg(target_arch = "x86_64")]
+        if !is_x86_feature_detected!("sse2") {
+            break 'simd;
+        }
+
+        // SAFETY: on x86_64 SSE2 was checked above; on aarch64 this code assumes NEON.
+        let simd_quote = unsafe { simd_duplicate(b'"') };
+        let simd_backslash = unsafe { simd_duplicate(b'\\') };
+
+        for chunk in data
+            .get(index..)
+            .into_iter()
+            .flat_map(|rest| rest.chunks_exact(SIMD_STEP))
+        {
+            // SAFETY: `chunks_exact` yields slices of exactly `SIMD_STEP` readable bytes.
+            let chunk_v = unsafe { simd_load(chunk.as_ptr()) };
+
+            // SAFETY: register-only operations; same CPU feature requirement as above.
+            let needs_slow_path = unsafe {
+                simd_or(simd_eq(chunk_v, simd_backslash), simd_not_printable_ascii(chunk_v))
+            };
+            if any_lane_set(needs_slow_path) {
+                break 'simd;
+            }
+
+            // SAFETY: as above.
+            let quotes = unsafe { simd_eq(chunk_v, simd_quote) };
+            if any_lane_set(quotes) {
+                if let Some(offset) = chunk.iter().position(|&byte| byte == b'"') {
+                    let end = index + offset;
+                    // SAFETY: every byte this block has scanned is printable ASCII.
+                    // NOTE(review): assumes `index == start` on entry to the block -- confirm.
+                    let s = unsafe { std::str::from_utf8_unchecked(&data[start..end]) };
+                    return Ok((StringOutput::Data(s), end + 1));
+                }
+            }
+
+            index += chunk.len();
+        }
+    }
+
     while let Some(next) = data.get(index) {
         match next {
             b'"' => {