diff --git a/src/string_decoder.rs b/src/string_decoder.rs index 03284bf..4b219c4 100644 --- a/src/string_decoder.rs +++ b/src/string_decoder.rs @@ -51,63 +51,340 @@ where { type Output = StringOutput<'t, 'j>; - fn decode(data: &'j [u8], mut index: usize, tape: &'t mut Tape) -> JsonResult<(Self::Output, usize)> { - index += 1; - tape.clear(); - let start = index; - let mut last_escape = start; - let mut found_escape = false; - let mut ascii_only = true; + fn decode(data: &'j [u8], index: usize, tape: &'t mut Tape) -> JsonResult<(Self::Output, usize)> { + #[cfg(target_arch = "x86_64")] + if is_x86_feature_detected!("avx2") { + return unsafe { decode_simd(data, index, tape) }; + } - while let Some(next) = data.get(index) { - match next { - b'"' => { - return if found_escape { - tape.extend_from_slice(&data[last_escape..index]); - index += 1; - let s = to_str(tape, ascii_only, start)?; - Ok((StringOutput::Tape(s), index)) - } else { - let s = to_str(&data[start..index], ascii_only, start)?; - index += 1; - Ok((StringOutput::Data(s), index)) - }; - } - b'\\' => { + #[cfg(target_arch = "aarch64")] + { + return decode_simd(data, index, tape); + } + + #[cfg(not(target_arch = "aarch64"))] + { + return decode_onebyone(data, index, tape); + } + } +} + +#[cfg(not(target_arch = "aarch64"))] +fn decode_onebyone<'j, 't>( + data: &'j [u8], + mut index: usize, + tape: &'t mut Tape, +) -> JsonResult<(StringOutput<'t, 'j>, usize)> +where + 'j: 't, +{ + index += 1; + + let start = index; + let mut last_escape = start; + let mut found_escape = false; + let mut ascii_only = true; + + while let Some(next) = data.get(index) { + match next { + b'"' => { + return if found_escape { + tape.extend_from_slice(&data[last_escape..index]); + index += 1; + let s = to_str(tape, ascii_only, start)?; + Ok((StringOutput::Tape(s), index)) + } else { + let s = to_str(&data[start..index], ascii_only, start)?; + index += 1; + Ok((StringOutput::Data(s), index)) + }; + } + b'\\' => { + if !found_escape { + tape.clear(); found_escape = true; + } + tape.extend_from_slice(&data[last_escape..index]); + index += 1; + if let Some(next_inner) = data.get(index) { + match next_inner { + b'"' | b'\\' | b'/' => tape.push(*next_inner), + b'b' => tape.push(b'\x08'), + b'f' => tape.push(b'\x0C'), + b'n' => tape.push(b'\n'), + b'r' => tape.push(b'\r'), + b't' => tape.push(b'\t'), + b'u' => { + let (c, new_index) = parse_escape(data, index)?; + index = new_index; + tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes()); + } + _ => return json_err!(InvalidEscape, index), + } + last_escape = index + 1; + } else { + break; + } + } + // all values below 32 are invalid + next if *next < 32u8 => return json_err!(ControlCharacterWhileParsingString, index), + next if *next >= 128u8 && ascii_only => { + ascii_only = false; + } + _ => (), + } + index += 1; + } + json_err!(EofWhileParsingString, index) +} + +#[cfg(target_arch = "aarch64")] +fn decode_simd<'j, 't>( + data: &'j [u8], + mut index: usize, + tape: &'t mut Tape, +) -> JsonResult<(StringOutput<'t, 'j>, usize)> +where + 'j: 't, +{ + index += 1; + + let start = index; + let mut last_escape = start; + let mut found_escape = false; + let mut ascii_only = true; + + 'simd: { + use std::arch::aarch64::{ + vceqq_u8 as simd_eq, vdupq_n_u8 as simd_duplicate, vld1q_u8 as simd_load, vorrq_u8 as simd_or, *, + }; + + const SIMD_STEP: usize = 16; + + fn is_vector_nonzero(vec: uint8x16_t) -> bool { + unsafe { vmaxvq_u8(vec) != 0 } + } + + unsafe fn simd_is_ascii_non_control(vec: uint8x16_t) -> uint8x16_t { + simd_or(vcltq_u8(vec, vdupq_n_u8(32)), vcgeq_u8(vec, vdupq_n_u8(128))) + } + + let simd_quote = unsafe { simd_duplicate(b'"') }; + let simd_backslash = unsafe { simd_duplicate(b'\\') }; + + for remaining_chunk in data + .get(index..) + .into_iter() + .flat_map(|remaining| remaining.chunks_exact(SIMD_STEP)) + { + let remaining_chunk_v = unsafe { simd_load(remaining_chunk.as_ptr()) }; + + let backslash = unsafe { simd_eq(remaining_chunk_v, simd_backslash) }; + let mask = unsafe { simd_is_ascii_non_control(remaining_chunk_v) }; + let backslash_or_mask = unsafe { simd_or(backslash, mask) }; + + // go slow if backslash or mask found + if is_vector_nonzero(backslash_or_mask) { + break 'simd; + } + + // Compare the remaining chunk with the special characters + let compare_result = unsafe { simd_eq(remaining_chunk_v, simd_quote) }; + + // Check if any element in the comparison result is true + if is_vector_nonzero(compare_result) { + // Found a match, return the index + let j = unsafe { remaining_chunk.iter().position(|&x| x == b'"').unwrap_unchecked() }; + return Ok(( + StringOutput::Data(unsafe { std::str::from_utf8_unchecked(&data[start..index + j]) }), + index + j + 1, + )); + } + + index += remaining_chunk.len(); + } + } + + while let Some(next) = data.get(index) { + match next { + b'"' => { + return if found_escape { tape.extend_from_slice(&data[last_escape..index]); index += 1; - if let Some(next_inner) = data.get(index) { - match next_inner { - b'"' | b'\\' | b'/' => tape.push(*next_inner), - b'b' => tape.push(b'\x08'), - b'f' => tape.push(b'\x0C'), - b'n' => tape.push(b'\n'), - b'r' => tape.push(b'\r'), - b't' => tape.push(b'\t'), - b'u' => { - let (c, new_index) = parse_escape(data, index)?; - index = new_index; - tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes()); - } - _ => return json_err!(InvalidEscape, index), + let s = to_str(tape, ascii_only, start)?; + Ok((StringOutput::Tape(s), index)) + } else { + let s = to_str(&data[start..index], ascii_only, start)?; + index += 1; + Ok((StringOutput::Data(s), index)) + }; + } + b'\\' => { + if !found_escape { + tape.clear(); + found_escape = true; + } + tape.extend_from_slice(&data[last_escape..index]); + index += 1; + if let Some(next_inner) = data.get(index) { + match next_inner { + b'"' | b'\\' | b'/' => tape.push(*next_inner), + b'b' => tape.push(b'\x08'), + b'f' => tape.push(b'\x0C'), + b'n' => tape.push(b'\n'), + b'r' => tape.push(b'\r'), + b't' => tape.push(b'\t'), + b'u' => { + let (c, new_index) = parse_escape(data, index)?; + index = new_index; + tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes()); } - last_escape = index + 1; - } else { - break; + _ => return json_err!(InvalidEscape, index), } + last_escape = index + 1; + } else { + break; } - // all values below 32 are invalid - next if *next < 32u8 => return json_err!(ControlCharacterWhileParsingString, index), - next if *next >= 128u8 && ascii_only => { - ascii_only = false; + } + // all values below 32 are invalid + next if *next < 32u8 => return json_err!(ControlCharacterWhileParsingString, index), + next if *next >= 128u8 && ascii_only => { + ascii_only = false; + } + _ => (), + } + index += 1; + } + json_err!(EofWhileParsingString, index) +} + +#[cfg(target_arch = "x86_64")] +#[target_feature(enable = "avx2")] +unsafe fn decode_simd<'j, 't>( + data: &'j [u8], + mut index: usize, + tape: &'t mut Tape, +) -> JsonResult<(StringOutput<'t, 'j>, usize)> +where + 'j: 't, +{ + pub use std::arch::x86_64::{ + _mm256_cmpeq_epi8 as simd_eq, _mm256_cmpgt_epi8 as simd_gt, _mm256_loadu_si256 as simd_load, + _mm256_movemask_epi8 as simd_movemask, _mm256_or_si256 as simd_or, _mm256_set1_epi8 as simd_duplicate, *, + }; + + pub const SIMD_STEP: usize = 32; + + index += 1; + + let start = index; + let mut last_escape = start; + let mut found_escape = false; + let mut ascii_only = true; + + let simd_quote = unsafe { simd_duplicate(b'"' as i8) }; + let simd_backslash = unsafe { simd_duplicate(b'\\' as i8) }; + let simd_minus1 = unsafe { simd_duplicate(-1) }; + let simd_31 = unsafe { simd_duplicate(31) }; + + while index < data.len() { + // Safety: on the last chunk this will read slightly past the end of the buffer, but we + // don's care because index += offset will never advance past the end of the buffer. + let remaining_chunk_v = unsafe { simd_load(data.as_ptr().add(index).cast()) }; + + let chunk_size = std::cmp::min(SIMD_STEP, data.len() - index); + let mut offset = chunk_size; + + let backslash_or_quote = unsafe { + simd_or( + simd_eq(remaining_chunk_v, simd_backslash), + simd_eq(remaining_chunk_v, simd_quote), + ) + }; + let backslash_or_quote_mask = unsafe { simd_movemask(backslash_or_quote) }; + + if backslash_or_quote_mask != 0 { + let backslash_or_quote_offset = backslash_or_quote_mask.trailing_zeros() as usize; + if backslash_or_quote_offset < offset { + offset = backslash_or_quote_offset; + } + } + + // signed comparison means that single check is of >31 hits the range + // we desire; signed >31 is equivalent to unsigned >31,<128 + let is_gt_31_or_lt_128 = unsafe { simd_gt(remaining_chunk_v, simd_31) }; + let in_range_char_mask = unsafe { simd_movemask(is_gt_31_or_lt_128) }; + + // Compare the remaining chunk with the special characters + + if in_range_char_mask != -1 { + let ge_0_char_mask = unsafe { simd_movemask(simd_gt(remaining_chunk_v, simd_minus1)) }; + let control_char_mask = !in_range_char_mask & ge_0_char_mask; + if control_char_mask != 0 { + let control_char_offset = control_char_mask.trailing_zeros() as usize; + if control_char_offset < offset { + offset = control_char_offset; } - _ => (), } - index += 1; + if ge_0_char_mask != -1 { + ascii_only = false; + } + } + + index += offset; + + if offset == chunk_size { + continue; + } + + match unsafe { *data.as_ptr().add(index) } { + b'"' => { + return if found_escape { + tape.extend_from_slice(&data[last_escape..index]); + index += 1; + let s = to_str(tape, ascii_only, start)?; + Ok((StringOutput::Tape(s), index)) + } else { + let s = to_str(&data[start..index], ascii_only, start)?; + index += 1; + Ok((StringOutput::Data(s), index)) + } + } + b'\\' => { + if !found_escape { + tape.clear(); + found_escape = true; + } + tape.extend_from_slice(&data[last_escape..index]); + index += 1; + if let Some(next_inner) = data.get(index) { + match next_inner { + b'"' | b'\\' | b'/' => tape.push(*next_inner), + b'b' => tape.push(b'\x08'), + b'f' => tape.push(b'\x0C'), + b'n' => tape.push(b'\n'), + b'r' => tape.push(b'\r'), + b't' => tape.push(b'\t'), + b'u' => { + let (c, new_index) = parse_escape(data, index)?; + index = new_index; + tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes()); + } + _ => return json_err!(InvalidEscape, index), + } + last_escape = index + 1; + index += 1; + } else { + break; + } + } + other => { + assert!(other < 32); + return json_err!(ControlCharacterWhileParsingString, index); + } } - json_err!(EofWhileParsingString, index) } + json_err!(EofWhileParsingString, index) } fn to_str(bytes: &[u8], ascii_only: bool, start: usize) -> JsonResult<&str> {