Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add simd string quote seeking #52

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
369 changes: 323 additions & 46 deletions src/string_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,63 +51,340 @@
{
type Output = StringOutput<'t, 'j>;

// NOTE(review): this span is a rendered diff without +/- markers. The signature directly
// below (taking `mut index`) and the byte-by-byte `while let` loop further down appear to be
// lines of the previous implementation shown next to their replacement. Confirm against the
// real file before relying on either version.
fn decode(data: &'j [u8], mut index: usize, tape: &'t mut Tape) -> JsonResult<(Self::Output, usize)> {
index += 1;
tape.clear();
let start = index;
let mut last_escape = start;
let mut found_escape = false;
let mut ascii_only = true;
// Dispatching version of `decode`: it does no scanning itself and hands `data`, `index` and
// `tape` to a target-specific helper. On x86_64 it calls `decode_simd` when AVX2 is detected
// at runtime; on aarch64 it calls `decode_simd` unconditionally; on every other path that is
// not aarch64 it calls `decode_onebyone`.
fn decode(data: &'j [u8], index: usize, tape: &'t mut Tape) -> JsonResult<(Self::Output, usize)> {
#[cfg(target_arch = "x86_64")]
if is_x86_feature_detected!("avx2") {
return unsafe { decode_simd(data, index, tape) };
}

Check warning on line 58 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L58

Added line #L58 was not covered by tests

while let Some(next) = data.get(index) {
match next {
b'"' => {
return if found_escape {
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
let s = to_str(tape, ascii_only, start)?;
Ok((StringOutput::Tape(s), index))
} else {
let s = to_str(&data[start..index], ascii_only, start)?;
index += 1;
Ok((StringOutput::Data(s), index))
};
}
b'\\' => {
#[cfg(target_arch = "aarch64")]
{
return decode_simd(data, index, tape);
}

#[cfg(not(target_arch = "aarch64"))]
{
return decode_onebyone(data, index, tape);

Check warning on line 67 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L60-L67

Added lines #L60 - L67 were not covered by tests
}
}
}

#[cfg(not(target_arch = "aarch64"))]
fn decode_onebyone<'j, 't>(
data: &'j [u8],
mut index: usize,
tape: &'t mut Tape,
) -> JsonResult<(StringOutput<'t, 'j>, usize)>
where
'j: 't,
{
index += 1;

let start = index;
let mut last_escape = start;
let mut found_escape = false;
let mut ascii_only = true;

Check warning on line 86 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L73-L86

Added lines #L73 - L86 were not covered by tests

while let Some(next) = data.get(index) {
match next {

Check warning on line 89 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L88-L89

Added lines #L88 - L89 were not covered by tests
b'"' => {
return if found_escape {
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
let s = to_str(tape, ascii_only, start)?;
Ok((StringOutput::Tape(s), index))

Check warning on line 95 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L91-L95

Added lines #L91 - L95 were not covered by tests
} else {
let s = to_str(&data[start..index], ascii_only, start)?;
index += 1;
Ok((StringOutput::Data(s), index))

Check warning on line 99 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L97-L99

Added lines #L97 - L99 were not covered by tests
};
}
b'\\' => {
if !found_escape {
tape.clear();

Check warning on line 104 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L103-L104

Added lines #L103 - L104 were not covered by tests
found_escape = true;
}
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
if let Some(next_inner) = data.get(index) {
match next_inner {
b'"' | b'\\' | b'/' => tape.push(*next_inner),
b'b' => tape.push(b'\x08'),
b'f' => tape.push(b'\x0C'),
b'n' => tape.push(b'\n'),
b'r' => tape.push(b'\r'),
b't' => tape.push(b'\t'),
b'u' => {
let (c, new_index) = parse_escape(data, index)?;
index = new_index;
tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes());

Check warning on line 120 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L106-L120

Added lines #L106 - L120 were not covered by tests
}
_ => return json_err!(InvalidEscape, index),

Check warning on line 122 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L122

Added line #L122 was not covered by tests
}
last_escape = index + 1;

Check warning on line 124 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L124

Added line #L124 was not covered by tests
} else {
break;

Check warning on line 126 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L126

Added line #L126 was not covered by tests
}
}
// all values below 32 are invalid
next if *next < 32u8 => return json_err!(ControlCharacterWhileParsingString, index),
next if *next >= 128u8 && ascii_only => {
ascii_only = false;
}
_ => (),

Check warning on line 134 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L130-L134

Added lines #L130 - L134 were not covered by tests
}
index += 1;

Check warning on line 136 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L136

Added line #L136 was not covered by tests
}
json_err!(EofWhileParsingString, index)
}

Check warning on line 139 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L138-L139

Added lines #L138 - L139 were not covered by tests

/// NEON-accelerated string decoder for aarch64.
///
/// Skips the byte at `index`, then scans `data` in 16-byte chunks. While every byte of a chunk
/// is in 32..=127 and is not a backslash, the chunk is either skipped whole or, if it contains
/// a quote, the string is returned directly as a borrow of `data`. The first chunk containing a
/// backslash, a byte below 32, or a byte of 128 and above leaves the `simd` block, and the
/// byte-by-byte loop at the bottom takes over from the first byte of that chunk. The same loop
/// handles a tail shorter than 16 bytes, because `chunks_exact` only yields full chunks.
///
/// Returns the decoded string and the index just past the closing quote, or an error
/// (`InvalidEscape`, `ControlCharacterWhileParsingString`, `EofWhileParsingString`, or one
/// propagated from `parse_escape` / `to_str`).
///
/// The two unchecked calls in the fast path rely on the checks made just before them:
/// `unwrap_unchecked` runs only after the quote comparison reported a hit in the current chunk,
/// and `from_utf8_unchecked` only sees bytes from chunks that passed the 32..=127 test.
// NOTE(review): the scalar loop below carries residue from the rendered diff. Inside the quote
// arm, the lines from `if let Some(next_inner)` through the `InvalidEscape` arm look like
// removed lines; in the backslash arm the first `last_escape = index + 1; } else { break;`
// group and the second copy of the control-character / non-ASCII arms look duplicated, and the
// closing brace of the backslash arm is absent. Confirm against the real file.
#[cfg(target_arch = "aarch64")]
fn decode_simd<'j, 't>(
data: &'j [u8],
mut index: usize,
tape: &'t mut Tape,
) -> JsonResult<(StringOutput<'t, 'j>, usize)>
where
'j: 't,
{
index += 1;

let start = index;
let mut last_escape = start;
let mut found_escape = false;
let mut ascii_only = true;

'simd: {
use std::arch::aarch64::{
vceqq_u8 as simd_eq, vdupq_n_u8 as simd_duplicate, vld1q_u8 as simd_load, vorrq_u8 as simd_or, *,
};

const SIMD_STEP: usize = 16;

// True when at least one lane is non-zero: the horizontal maximum of the vector is
// compared against 0.
fn is_vector_nonzero(vec: uint8x16_t) -> bool {
unsafe { vmaxvq_u8(vec) != 0 }
}

// NOTE(review): despite the name, the returned mask flags the lanes that are NOT printable
// ASCII: it ORs the lanes below 32 with the lanes at or above 128.
unsafe fn simd_is_ascii_non_control(vec: uint8x16_t) -> uint8x16_t {
simd_or(vcltq_u8(vec, vdupq_n_u8(32)), vcgeq_u8(vec, vdupq_n_u8(128)))
}

let simd_quote = unsafe { simd_duplicate(b'"') };
let simd_backslash = unsafe { simd_duplicate(b'\\') };

for remaining_chunk in data
.get(index..)
.into_iter()
.flat_map(|remaining| remaining.chunks_exact(SIMD_STEP))
{
let remaining_chunk_v = unsafe { simd_load(remaining_chunk.as_ptr()) };

let backslash = unsafe { simd_eq(remaining_chunk_v, simd_backslash) };
let mask = unsafe { simd_is_ascii_non_control(remaining_chunk_v) };
let backslash_or_mask = unsafe { simd_or(backslash, mask) };

// go slow if backslash or mask found
if is_vector_nonzero(backslash_or_mask) {
break 'simd;
}

// Compare the remaining chunk with the special characters
let compare_result = unsafe { simd_eq(remaining_chunk_v, simd_quote) };

// Check if any element in the comparison result is true
if is_vector_nonzero(compare_result) {
// Found a match, return the index
let j = unsafe { remaining_chunk.iter().position(|&x| x == b'"').unwrap_unchecked() };
return Ok((
StringOutput::Data(unsafe { std::str::from_utf8_unchecked(&data[start..index + j]) }),
index + j + 1,
));
}

index += remaining_chunk.len();
}
}

while let Some(next) = data.get(index) {
match next {
b'"' => {
return if found_escape {
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
if let Some(next_inner) = data.get(index) {
match next_inner {
b'"' | b'\\' | b'/' => tape.push(*next_inner),
b'b' => tape.push(b'\x08'),
b'f' => tape.push(b'\x0C'),
b'n' => tape.push(b'\n'),
b'r' => tape.push(b'\r'),
b't' => tape.push(b'\t'),
b'u' => {
let (c, new_index) = parse_escape(data, index)?;
index = new_index;
tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes());
}
_ => return json_err!(InvalidEscape, index),
let s = to_str(tape, ascii_only, start)?;
Ok((StringOutput::Tape(s), index))
} else {
let s = to_str(&data[start..index], ascii_only, start)?;
index += 1;
Ok((StringOutput::Data(s), index))
};
}
b'\\' => {
if !found_escape {
tape.clear();
found_escape = true;
}
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
if let Some(next_inner) = data.get(index) {
match next_inner {
b'"' | b'\\' | b'/' => tape.push(*next_inner),
b'b' => tape.push(b'\x08'),
b'f' => tape.push(b'\x0C'),
b'n' => tape.push(b'\n'),
b'r' => tape.push(b'\r'),
b't' => tape.push(b'\t'),
b'u' => {
let (c, new_index) = parse_escape(data, index)?;
index = new_index;
tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes());
}
last_escape = index + 1;
} else {
break;
_ => return json_err!(InvalidEscape, index),
}
last_escape = index + 1;
} else {
break;
}
// all values below 32 are invalid
next if *next < 32u8 => return json_err!(ControlCharacterWhileParsingString, index),
next if *next >= 128u8 && ascii_only => {
ascii_only = false;
}
// all values below 32 are invalid
next if *next < 32u8 => return json_err!(ControlCharacterWhileParsingString, index),
next if *next >= 128u8 && ascii_only => {
ascii_only = false;
}
_ => (),
}
index += 1;
}
json_err!(EofWhileParsingString, index)
}

#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn decode_simd<'j, 't>(
data: &'j [u8],
mut index: usize,
tape: &'t mut Tape,
) -> JsonResult<(StringOutput<'t, 'j>, usize)>
where
'j: 't,
{
pub use std::arch::x86_64::{
_mm256_cmpeq_epi8 as simd_eq, _mm256_cmpgt_epi8 as simd_gt, _mm256_loadu_si256 as simd_load,
_mm256_movemask_epi8 as simd_movemask, _mm256_or_si256 as simd_or, _mm256_set1_epi8 as simd_duplicate, *,
};

pub const SIMD_STEP: usize = 32;

index += 1;

let start = index;
let mut last_escape = start;
let mut found_escape = false;
let mut ascii_only = true;

let simd_quote = unsafe { simd_duplicate(b'"' as i8) };
let simd_backslash = unsafe { simd_duplicate(b'\\' as i8) };
let simd_minus1 = unsafe { simd_duplicate(-1) };
let simd_31 = unsafe { simd_duplicate(31) };

while index < data.len() {
// Safety: on the last chunk this will read slightly past the end of the buffer, but we
// don's care because index += offset will never advance past the end of the buffer.
let remaining_chunk_v = unsafe { simd_load(data.as_ptr().add(index).cast()) };

let chunk_size = std::cmp::min(SIMD_STEP, data.len() - index);
let mut offset = chunk_size;

let backslash_or_quote = unsafe {
simd_or(
simd_eq(remaining_chunk_v, simd_backslash),
simd_eq(remaining_chunk_v, simd_quote),
)
};
let backslash_or_quote_mask = unsafe { simd_movemask(backslash_or_quote) };

if backslash_or_quote_mask != 0 {
let backslash_or_quote_offset = backslash_or_quote_mask.trailing_zeros() as usize;
if backslash_or_quote_offset < offset {
offset = backslash_or_quote_offset;
}
}

// signed comparison means that single check is of >31 hits the range
// we desire; signed >31 is equivalent to unsigned >31,<128
let is_gt_31_or_lt_128 = unsafe { simd_gt(remaining_chunk_v, simd_31) };
let in_range_char_mask = unsafe { simd_movemask(is_gt_31_or_lt_128) };

// Compare the remaining chunk with the special characters

if in_range_char_mask != -1 {
let ge_0_char_mask = unsafe { simd_movemask(simd_gt(remaining_chunk_v, simd_minus1)) };
let control_char_mask = !in_range_char_mask & ge_0_char_mask;
if control_char_mask != 0 {
let control_char_offset = control_char_mask.trailing_zeros() as usize;
if control_char_offset < offset {
offset = control_char_offset;
}
_ => (),
}
index += 1;
if ge_0_char_mask != -1 {
ascii_only = false;
}
}

index += offset;

if offset == chunk_size {
continue;
}

match unsafe { *data.as_ptr().add(index) } {
b'"' => {
return if found_escape {
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
let s = to_str(tape, ascii_only, start)?;
Ok((StringOutput::Tape(s), index))
} else {
let s = to_str(&data[start..index], ascii_only, start)?;
index += 1;
Ok((StringOutput::Data(s), index))
}
}
b'\\' => {
if !found_escape {
tape.clear();
found_escape = true;
}
tape.extend_from_slice(&data[last_escape..index]);
index += 1;
if let Some(next_inner) = data.get(index) {
match next_inner {
b'"' | b'\\' | b'/' => tape.push(*next_inner),
b'b' => tape.push(b'\x08'),
b'f' => tape.push(b'\x0C'),
b'n' => tape.push(b'\n'),
b'r' => tape.push(b'\r'),
b't' => tape.push(b'\t'),
b'u' => {
let (c, new_index) = parse_escape(data, index)?;
index = new_index;
tape.extend_from_slice(c.encode_utf8(&mut [0_u8; 4]).as_bytes());
}
_ => return json_err!(InvalidEscape, index),

Check warning on line 373 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L373

Added line #L373 was not covered by tests
}
last_escape = index + 1;
index += 1;
} else {
break;

Check warning on line 378 in src/string_decoder.rs

View check run for this annotation

Codecov / codecov/patch

src/string_decoder.rs#L378

Added line #L378 was not covered by tests
}
}
other => {
assert!(other < 32);
return json_err!(ControlCharacterWhileParsingString, index);
}
}
json_err!(EofWhileParsingString, index)
}
json_err!(EofWhileParsingString, index)
}

fn to_str(bytes: &[u8], ascii_only: bool, start: usize) -> JsonResult<&str> {
Expand Down
Loading