Skip to content

Commit

Permalink
avx2 simd quote seek
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Dec 12, 2023
1 parent bce908f commit 8711efd
Showing 1 changed file with 201 additions and 10 deletions.
211 changes: 201 additions & 10 deletions src/string_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,23 @@ where
fn decode(data: &'j [u8], mut index: usize, tape: &'t mut Tape) -> JsonResult<(Self::Output, usize)> {
index += 1;
tape.clear();

#[cfg(target_arch = "x86_64")]
if is_x86_feature_detected!("avx2") {
return unsafe { decode_simd(data, index, tape) };
}

return decode_onebyone(data, index, tape);

let start = index;
let mut last_escape = start;
let mut found_escape = false;
let mut ascii_only = true;

#[cfg(any(target_arch = "aarch64", target_arch = "x86_64"))]
#[cfg(any(target_arch = "aarch64"))]
'simd: {
#[cfg(target_arch = "aarch64")]
mod impl_ {
pub use std::arch::aarch64::{
vceqq_u8 as simd_eq, vcltq_u8 as simd_lt, vdupq_n_u8 as simd_duplicate, vld1q_u8 as simd_load,
vorrq_u8 as simd_or, *,
};

pub const SIMD_STEP: usize = 16;

Expand Down Expand Up @@ -110,11 +114,6 @@ where

use impl_::*;

#[cfg(target_arch = "x86_64")]
if !is_x86_feature_detected!("sse") {
break 'simd;
}

let simd_quote = unsafe { simd_duplicate(b'"') };
let simd_backslash = unsafe { simd_duplicate(b'\\') };

Expand Down Expand Up @@ -202,6 +201,198 @@ where
}
}

fn decode_onebyone<'j, 't>(
data: &'j [u8],
mut index: usize,
tape: &'t mut Tape,
) -> JsonResult<(StringOutput<'t, 'j>, usize)>
where
'j: 't,
{
let start = index;
let mut last_escape = start;
let mut found_escape = false;
let mut ascii_only = true;

while let Some(next) = data.get(index) {
match next {
b'"' => {
return if found_escape {
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
let s = to_str(tape, ascii_only, start)?;
Ok((StringOutput::Tape(s), index))
} else {
let s = to_str(&data[start..index], ascii_only, start)?;
index += 1;
Ok((StringOutput::Data(s), index))
};
}
b'\\' => {
found_escape = true;
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
if let Some(next_inner) = data.get(index) {
match next_inner {
b'"' | b'\\' | b'/' => tape.push(*next_inner),
b'b' => tape.push(b'\x08'),
b'f' => tape.push(b'\x0C'),
b'n' => tape.push(b'\n'),
b'r' => tape.push(b'\r'),
b't' => tape.push(b'\t'),
b'u' => {
let (c, new_index) = parse_escape(data, index)?;
index = new_index;
tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes());
}
_ => return json_err!(InvalidEscape, index),
}
last_escape = index + 1;
} else {
break;
}
}
// all values below 32 are invalid
next if *next < 32u8 => return json_err!(ControlCharacterWhileParsingString, index),
next if *next >= 128u8 && ascii_only => {
ascii_only = false;
}
_ => (),
}
index += 1;
}
json_err!(EofWhileParsingString, index)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn decode_simd<'j, 't>(
data: &'j [u8],
mut index: usize,
tape: &'t mut Tape,
) -> JsonResult<(StringOutput<'t, 'j>, usize)>
where
'j: 't,
{
pub use std::arch::x86_64::{
_mm256_cmpeq_epi8 as simd_eq, _mm256_cmpgt_epi8 as simd_gt, _mm256_loadu_si256 as simd_load,
_mm256_movemask_epi8 as simd_movemask, _mm256_or_si256 as simd_or, _mm256_set1_epi8 as simd_duplicate, *,
};

pub const SIMD_STEP: usize = 32;

index += 1;

let start = index;
let mut last_escape = start;
let mut found_escape = false;
let mut ascii_only = true;

let simd_quote = unsafe { simd_duplicate(b'"' as i8) };
let simd_backslash = unsafe { simd_duplicate(b'\\' as i8) };
let simd_minus1 = unsafe { simd_duplicate(-1) };
let simd_31 = unsafe { simd_duplicate(31) };

while index < data.len() {
// Safety: on the last chunk this will read slightly past the end of the buffer, but we
// don's care because index += offset will never advance past the end of the buffer.
let remaining_chunk_v = unsafe { simd_load(data.as_ptr().add(index).cast()) };

let chunk_size = std::cmp::min(SIMD_STEP, data.len() - index);
let mut offset = chunk_size;

let backslash_or_quote = unsafe {
simd_or(
simd_eq(remaining_chunk_v, simd_backslash),
simd_eq(remaining_chunk_v, simd_quote),
)
};
let backslash_or_quote_mask = unsafe { simd_movemask(backslash_or_quote) };

if backslash_or_quote_mask != 0 {
let backslash_or_quote_offset = backslash_or_quote_mask.trailing_zeros() as usize;
if backslash_or_quote_offset < offset {
offset = backslash_or_quote_offset;
}
}

// signed comparison means that single check is of >31 hits the range
// we desire; signed >31 is equivalent to unsigned >31,<128
let is_gt_31_or_lt_128 = unsafe { simd_gt(remaining_chunk_v, simd_31) };
let in_range_char_mask = unsafe { simd_movemask(is_gt_31_or_lt_128) };

// Compare the remaining chunk with the special characters

if in_range_char_mask != -1 {
let ge_0_char_mask = unsafe { simd_movemask(simd_gt(remaining_chunk_v, simd_minus1)) };
let control_char_mask = !in_range_char_mask & ge_0_char_mask;
if control_char_mask != 0 {
let control_char_offset = control_char_mask.trailing_zeros() as usize;
if control_char_offset < offset {
offset = control_char_offset;
}
}
if ge_0_char_mask != -1 {
ascii_only = false;
}
}

index += offset;

if offset == chunk_size {
continue;
}

match unsafe { *data.as_ptr().add(index) } {
b'"' => {
return if found_escape {
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
let s = to_str(tape, ascii_only, start)?;
Ok((StringOutput::Tape(s), index))
} else {
let s = to_str(&data[start..index], ascii_only, start)?;
index += 1;
Ok((StringOutput::Data(s), index))
}
}
b'\\' => {
if !found_escape {
tape.clear();
found_escape = true;
}
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
if let Some(next_inner) = data.get(index) {
match next_inner {
b'"' | b'\\' | b'/' => tape.push(*next_inner),
b'b' => tape.push(b'\x08'),
b'f' => tape.push(b'\x0C'),
b'n' => tape.push(b'\n'),
b'r' => tape.push(b'\r'),
b's' => tape.push(b'\t'),
b'u' => {
let (c, new_index) = parse_escape(data, index)?;
index = new_index;
tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes());
}
_ => return json_err!(InvalidEscape, index),
}
last_escape = index + 1;
index += 1;
} else {
break;
}
}
other => {
assert!(other < 32);
return json_err!(ControlCharacterWhileParsingString, index);
}
}
}
json_err!(EofWhileParsingString, index)
}

fn to_str(bytes: &[u8], ascii_only: bool, start: usize) -> JsonResult<&str> {
if ascii_only {
// safety: in this case we've already confirmed that all characters are ascii, we can safely
Expand Down

0 comments on commit 8711efd

Please sign in to comment.