check duplicate object keys (#81)
Co-authored-by: David Hewitt <mail@davidhewitt.dev>
samuelcolvin and davidhewitt committed May 16, 2024
1 parent 75699eb commit 1df24cc
Showing 8 changed files with 197 additions and 51 deletions.
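In short: `from_json` gains `allow_partial` and `catch_duplicate_keys` keyword arguments, and the Rust-level `python_parse` grows two matching parameters. A minimal sketch of the user-visible behaviour (argument order per the updated `crates/jiter/benches/python.rs` below; the embedding setup and the last-value-wins default are assumptions, not part of this diff):

```rust
use jiter::{python_parse, StringCacheMode};
use pyo3::Python;

fn main() {
    pyo3::prepare_freethreaded_python();
    // Valid JSON, but the key "a" appears twice.
    let data = br#"{"a": 1, "a": 2}"#;
    Python::with_gil(|py| {
        // Default: duplicates are tolerated (the dict keeps a single "a").
        let ok = python_parse(py, data, false, StringCacheMode::All, false, false);
        assert!(ok.is_ok());
        // catch_duplicate_keys=true: the second "a" is an error.
        let err = python_parse(py, data, false, StringCacheMode::All, false, true);
        assert!(err.is_err());
    });
}
```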
48 changes: 31 additions & 17 deletions crates/jiter-python/bench.py
@@ -7,22 +7,31 @@
import json

cases = [
('medium_response', Path('../benches/medium_response.json').read_bytes()),
('massive_ints_array', Path('../benches/massive_ints_array.json').read_bytes()),
('array_short_strings', '[{}]'.format(', '.join('"123"' for _ in range(100_000)))),
('object_short_strings', '{%s}' % ', '.join(f'"{i}": "{i}x"' for i in range(100_000))),
('array_short_arrays', '[{}]'.format(', '.join('["a", "b", "c", "d"]' for _ in range(10_000)))),
('one_long_string', json.dumps('x' * 100)),
('one_short_string', b'"foobar"'),
('1m_strings', json.dumps([str(i) for i in range(1_000_000)])),
("medium_response", Path("../jiter/benches/medium_response.json").read_bytes()),
(
"massive_ints_array",
Path("../jiter/benches/massive_ints_array.json").read_bytes(),
),
("array_short_strings", "[{}]".format(", ".join('"123"' for _ in range(100_000)))),
(
"object_short_strings",
"{%s}" % ", ".join(f'"{i}": "{i}x"' for i in range(100_000)),
),
(
"array_short_arrays",
"[{}]".format(", ".join('["a", "b", "c", "d"]' for _ in range(10_000))),
),
("one_long_string", json.dumps("x" * 100)),
("one_short_string", b'"foobar"'),
("1m_strings", json.dumps([str(i) for i in range(1_000_000)])),
]


def run_bench(func, d):
if isinstance(d, str):
d = d.encode()
timer = timeit.Timer(
'func(json_data)', setup='', globals={'func': func, 'json_data': d}
"func(json_data)", setup="", globals={"func": func, "json_data": d}
)
n, t = timer.autorange()
iter_time = t / n
@@ -31,13 +40,18 @@ def run_bench(func, d):


for name, json_data in cases:
print(f'Case: {name}')
print(f"Case: {name}")
times = [
('orjson', run_bench(lambda d: orjson.loads(d), json_data)),
('jiter-cache', run_bench(lambda d: jiter_python.from_json(d), json_data)),
('jiter', run_bench(lambda d: jiter_python.from_json(d, cache_strings=False), json_data)),
('ujson', run_bench(lambda d: ujson.loads(d), json_data)),
('json', run_bench(lambda d: json.loads(d), json_data)),
("orjson", run_bench(lambda d: orjson.loads(d), json_data)),
("jiter-cache", run_bench(lambda d: jiter_python.from_json(d), json_data)),
(
"jiter",
run_bench(
lambda d: jiter_python.from_json(d, cache_strings=False), json_data
),
),
("ujson", run_bench(lambda d: ujson.loads(d), json_data)),
("json", run_bench(lambda d: json.loads(d), json_data)),
]

times.sort(key=lambda x: x[1])
@@ -46,5 +60,5 @@ def run_bench(func, d):
print(f'{"package":>12} | {"time µs":>10} | slowdown')
print(f'{"-" * 13}|{"-" * 12}|{"-" * 9}')
for name, time in times:
print(f'{name:>12} | {time * 1_000_000:10.2f} | {time / best:8.2f}')
print('')
print(f"{name:>12} | {time * 1_000_000:10.2f} | {time / best:8.2f}")
print("")
23 changes: 21 additions & 2 deletions crates/jiter-python/src/lib.rs
@@ -2,16 +2,35 @@ use pyo3::prelude::*;

use jiter::{map_json_error, python_parse};

#[pyfunction(signature = (data, *, allow_inf_nan=true, cache_strings=true))]
#[pyfunction(
signature = (
data,
*,
allow_inf_nan=true,
cache_strings=true,
allow_partial=false,
catch_duplicate_keys=false
)
)]
pub fn from_json<'py>(
py: Python<'py>,
data: &[u8],
allow_inf_nan: bool,
cache_strings: bool,
allow_partial: bool,
catch_duplicate_keys: bool,
) -> PyResult<Bound<'py, PyAny>> {
let cache_mode = cache_strings.into();
let json_bytes = data;
python_parse(py, json_bytes, allow_inf_nan, cache_mode, false).map_err(|e| map_json_error(json_bytes, &e))
python_parse(
py,
json_bytes,
allow_inf_nan,
cache_mode,
allow_partial,
catch_duplicate_keys,
)
.map_err(|e| map_json_error(json_bytes, &e))
}

#[pymodule]
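One detail worth noting in `from_json` above: `cache_strings.into()` converts the Python-level bool into the crate's three-state `StringCacheMode`. The conversion itself is not part of this diff; the sketch below is a hypothetical reconstruction, assuming `True` maps to caching everything and `False` to no caching (consistent with the bench's `cache_strings=False` case):

```rust
// Hypothetical reconstruction of the `From<bool>` conversion behind
// `cache_strings.into()`; the real impl lives elsewhere in the jiter crate.
pub enum StringCacheMode {
    All,  // cache keys and values
    Keys, // cache object keys only
    None, // no caching
}

impl From<bool> for StringCacheMode {
    fn from(cache_strings: bool) -> Self {
        if cache_strings {
            Self::All
        } else {
            Self::None
        }
    }
}
```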
4 changes: 3 additions & 1 deletion crates/jiter/benches/python.rs
@@ -17,6 +17,7 @@ fn python_parse_numeric(bench: &mut Bencher) {
false,
StringCacheMode::All,
false,
false,
)
.unwrap()
});
@@ -33,6 +34,7 @@ fn python_parse_other(bench: &mut Bencher) {
false,
StringCacheMode::All,
false,
false,
)
.unwrap()
});
@@ -47,7 +49,7 @@ fn _python_parse_file(path: &str, bench: &mut Bencher, cache_mode: StringCacheMode

Python::with_gil(|py| {
cache_clear(py);
bench.iter(|| python_parse(py, json_data, false, cache_mode, false).unwrap());
bench.iter(|| python_parse(py, json_data, false, cache_mode, false, false).unwrap());
})
}

6 changes: 5 additions & 1 deletion crates/jiter/src/errors.rs
@@ -4,11 +4,14 @@ use std::fmt;
///
/// Almost all of `JsonErrorType` is copied from [serde_json](https://github.com/serde-rs) so errors match
/// those expected from `serde_json`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum JsonErrorType {
/// float value was found where an int was expected
FloatExpectingInt,

/// duplicate keys in an object
DuplicateKey(String),

/// NOTE: all errors from here on are copied from serde_json
/// [src/error.rs](https://github.com/serde-rs/json/blob/v1.0.107/src/error.rs#L236)
/// with `Io` and `Message` removed
@@ -79,6 +82,7 @@ impl std::fmt::Display for JsonErrorType {
// Messages for enum members copied from serde_json are unchanged
match self {
Self::FloatExpectingInt => f.write_str("float value was found where an int was expected"),
Self::DuplicateKey(s) => write!(f, "Detected duplicate key {s:?}"),
Self::EofWhileParsingList => f.write_str("EOF while parsing a list"),
Self::EofWhileParsingObject => f.write_str("EOF while parsing an object"),
Self::EofWhileParsingString => f.write_str("EOF while parsing a string"),
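The new `DuplicateKey` variant carries the offending key, and the `Display` arm quotes it via `{s:?}`. A small sketch of the resulting message (assuming `JsonErrorType` is re-exported at the crate root, as the `use crate::JsonErrorType` in `python.rs` suggests):

```rust
use jiter::JsonErrorType;

fn main() {
    let err = JsonErrorType::DuplicateKey("a".to_string());
    // `{s:?}` formats the key with Debug, so it comes out quoted.
    assert_eq!(err.to_string(), r#"Detected duplicate key "a""#);
}
```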
4 changes: 2 additions & 2 deletions crates/jiter/src/lazy_index_map.rs
@@ -77,10 +77,10 @@
self.vec.is_empty()
}

pub fn get<Q: ?Sized>(&self, key: &Q) -> Option<&V>
pub fn get<Q>(&self, key: &Q) -> Option<&V>
where
K: Borrow<Q> + PartialEq<Q>,
Q: Hash + Eq,
Q: Hash + Eq + ?Sized,
{
let vec_len = self.vec.len();
// if the vec is longer than the threshold, we use the hashmap for lookups
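Moving `?Sized` into the `where` clause is a style modernization; the bound itself is what lets a map keyed by `String` be queried with a plain `&str`. A standalone sketch of the same lookup signature in miniature (not jiter's actual `LazyIndexMap`):

```rust
use std::borrow::Borrow;

// Look up by any borrowed form `Q` of the owned key `K`; `Q: ?Sized`
// admits unsized types such as `str` and `[u8]`.
fn get<'a, K, Q>(pairs: &'a [(K, u32)], key: &Q) -> Option<&'a u32>
where
    K: Borrow<Q>,
    Q: Eq + ?Sized,
{
    pairs.iter().find(|(k, _)| k.borrow() == key).map(|(_, v)| v)
}

fn main() {
    let pairs = vec![("a".to_string(), 1_u32), ("b".to_string(), 2)];
    // `"a"` is a `&str`; without `?Sized` the call would need a `&String`.
    assert_eq!(get(&pairs, "a"), Some(&1));
}
```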
66 changes: 59 additions & 7 deletions crates/jiter/src/python.rs
@@ -1,3 +1,4 @@
use ahash::AHashSet;
use std::marker::PhantomData;

use pyo3::exceptions::PyValueError;
@@ -21,6 +22,8 @@ use crate::JsonErrorType;
/// - `json_data`: The JSON data to parse.
/// - `allow_inf_nan`: Whether to allow `(-)Infinity` and `NaN` values.
/// - `cache_strings`: Whether to cache strings to avoid constructing new Python objects,
///   this should significantly improve performance but slightly increases memory use.
/// - `allow_partial`: Whether to allow partial JSON data.
/// - `catch_duplicate_keys`: Whether to catch duplicate keys in objects.
///
/// # Returns
@@ -32,11 +35,27 @@ pub fn python_parse<'py>(
allow_inf_nan: bool,
cache_mode: StringCacheMode,
allow_partial: bool,
catch_duplicate_keys: bool,
) -> JsonResult<Bound<'py, PyAny>> {
macro_rules! ppp {
($string_cache:ident, $key_check:ident) => {
PythonParser::<$string_cache, $key_check>::parse(py, json_data, allow_inf_nan, allow_partial)
};
}

match cache_mode {
StringCacheMode::All => PythonParser::<StringCacheAll>::parse(py, json_data, allow_inf_nan, allow_partial),
StringCacheMode::Keys => PythonParser::<StringCacheKeys>::parse(py, json_data, allow_inf_nan, allow_partial),
StringCacheMode::None => PythonParser::<StringNoCache>::parse(py, json_data, allow_inf_nan, allow_partial),
StringCacheMode::All => match catch_duplicate_keys {
true => ppp!(StringCacheAll, DuplicateKeyCheck),
false => ppp!(StringCacheAll, NoopKeyCheck),
},
StringCacheMode::Keys => match catch_duplicate_keys {
true => ppp!(StringCacheKeys, DuplicateKeyCheck),
false => ppp!(StringCacheKeys, NoopKeyCheck),
},
StringCacheMode::None => match catch_duplicate_keys {
true => ppp!(StringNoCache, DuplicateKeyCheck),
false => ppp!(StringNoCache, NoopKeyCheck),
},
}
}

@@ -45,16 +64,17 @@ pub fn map_json_error(json_data: &[u8], json_error: &JsonError) -> PyErr {
PyValueError::new_err(json_error.description(json_data))
}

struct PythonParser<'j, StringCache> {
struct PythonParser<'j, StringCache, KeyCheck> {
_string_cache: PhantomData<StringCache>,
_key_check: PhantomData<KeyCheck>,
parser: Parser<'j>,
tape: Tape,
recursion_limit: u8,
allow_inf_nan: bool,
allow_partial: bool,
}

impl<'j, StringCache: StringMaybeCache> PythonParser<'j, StringCache> {
impl<'j, StringCache: StringMaybeCache, KeyCheck: MaybeKeyCheck> PythonParser<'j, StringCache, KeyCheck> {
fn parse<'py>(
py: Python<'py>,
json_data: &[u8],
@@ -63,6 +83,7 @@ impl<'j, StringCache: StringMaybeCache> PythonParser<'j, StringCache> {
) -> JsonResult<Bound<'py, PyAny>> {
let mut slf = PythonParser {
_string_cache: PhantomData::<StringCache>,
_key_check: PhantomData::<KeyCheck>,
parser: Parser::new(json_data),
tape: Tape::default(),
recursion_limit: DEFAULT_RECURSION_LIMIT,
@@ -166,13 +187,18 @@ impl<'j, StringCache: StringMaybeCache> PythonParser<'j, StringCache> {
panic!("PyDict_SetItem failed")
}
};
let mut check_keys = KeyCheck::default();
if let Some(first_key) = self.parser.object_first::<StringDecoder>(&mut self.tape)? {
let first_key = StringCache::get_key(py, first_key.as_str(), first_key.ascii_only());
let first_key_s = first_key.as_str();
check_keys.check(first_key_s, self.parser.index)?;
let first_key = StringCache::get_key(py, first_key_s, first_key.ascii_only());
let peek = self.parser.peek()?;
let first_value = self._check_take_value(py, peek)?;
set_item(first_key, first_value);
while let Some(key) = self.parser.object_step::<StringDecoder>(&mut self.tape)? {
let key = StringCache::get_key(py, key.as_str(), key.ascii_only());
let key_s = key.as_str();
check_keys.check(key_s, self.parser.index)?;
let key = StringCache::get_key(py, key_s, key.ascii_only());
let peek = self.parser.peek()?;
let value = self._check_take_value(py, peek)?;
set_item(key, value);
@@ -209,3 +235,29 @@ impl<'j, StringCache: StringMaybeCache> PythonParser<'j, StringCache> {
r
}
}

trait MaybeKeyCheck: Default {
fn check(&mut self, key: &str, index: usize) -> JsonResult<()>;
}

#[derive(Default)]
struct NoopKeyCheck;

impl MaybeKeyCheck for NoopKeyCheck {
fn check(&mut self, _key: &str, _index: usize) -> JsonResult<()> {
Ok(())
}
}

#[derive(Default)]
struct DuplicateKeyCheck(AHashSet<String>);

impl MaybeKeyCheck for DuplicateKeyCheck {
fn check(&mut self, key: &str, index: usize) -> JsonResult<()> {
if self.0.insert(key.to_owned()) {
Ok(())
} else {
Err(JsonError::new(JsonErrorType::DuplicateKey(key.to_owned()), index))
}
}
}
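Because `DuplicateKey` stores the key and `check` receives `self.parser.index`, the mapped Python exception can report both what collided and where. A sketch of the end-to-end error path using the public `map_json_error` helper (the exact line/column in the comment is illustrative, not verified):

```rust
use jiter::{map_json_error, python_parse, StringCacheMode};
use pyo3::Python;

fn main() {
    pyo3::prepare_freethreaded_python();
    Python::with_gil(|py| {
        let data = br#"{"a": 1, "a": 2}"#;
        let py_err = python_parse(py, data, false, StringCacheMode::All, false, true)
            .map_err(|e| map_json_error(data, &e))
            .unwrap_err();
        // Something like: ValueError: Detected duplicate key "a" at line 1 column 11
        println!("{py_err}");
    });
}
```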
2 changes: 1 addition & 1 deletion crates/jiter/tests/main.rs
@@ -116,7 +116,7 @@ macro_rules! single_expect_ok_or_error {
let position = jiter.error_position(e.index);
// no wrong type errors, so unwrap the json error
let error_type = match e.error_type {
JiterErrorType::JsonError(e) => e,
JiterErrorType::JsonError(ref e) => e,
_ => panic!("unexpected error type: {:?}", e.error_type),
};
let actual_error = format!("{:?} @ {}", error_type, position.short());
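The added `ref` is a knock-on effect of `JsonErrorType` losing its `Copy` derive (a `DuplicateKey(String)` variant cannot be `Copy`): binding the inner error by value would move out of `e.error_type`, which the panic arm still reads. A minimal reproduction with hypothetical stand-in types:

```rust
#[allow(dead_code)]
#[derive(Debug)]
enum ErrorType {
    JsonError(String), // non-Copy payload, like DuplicateKey(String)
    Other,
}

struct Error {
    error_type: ErrorType,
}

fn main() {
    let e = Error {
        error_type: ErrorType::JsonError("dup".to_string()),
    };
    // With plain `JsonError(s)` this would not compile: the String would be
    // moved out of `e.error_type`, which the second arm still uses.
    let inner = match e.error_type {
        ErrorType::JsonError(ref s) => s,
        _ => panic!("unexpected error type: {:?}", e.error_type),
    };
    assert_eq!(inner, "dup");
}
```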