diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a11cc42..e5d8ecf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -165,9 +165,7 @@ jobs: env: RUST_BACKTRACE: 1 - - run: python crates/jiter-python/bench.py - env: - FAST: 1 + - run: python crates/jiter-python/bench.py --fast - run: coverage-prepare lcov $(python -c 'import jiter.jiter;print(jiter.jiter.__file__)') diff --git a/Makefile b/Makefile index 1e206d1..057c2ab 100644 --- a/Makefile +++ b/Makefile @@ -28,6 +28,14 @@ python-dev: python-test: python-dev pytest crates/jiter-python/tests +.PHONY: python-dev-release +python-dev-release: + maturin develop -m crates/jiter-python/Cargo.toml --release + +.PHONY: python-bench +python-bench: python-dev-release + python crates/jiter-python/bench.py + .PHONY: bench bench: cargo bench -p jiter -F python diff --git a/crates/jiter-python/README.md b/crates/jiter-python/README.md index e13bd14..af2ba3c 100644 --- a/crates/jiter-python/README.md +++ b/crates/jiter-python/README.md @@ -18,7 +18,7 @@ def from_json( cache_mode: Literal[True, False, "all", "keys", "none"] = "all", partial_mode: Literal[True, False, "off", "on", "trailing-strings"] = False, catch_duplicate_keys: bool = False, - lossless_floats: bool = False, + float_mode: Literal["float", "decimal", "lossless-float"] = False, ) -> Any: """ Parse input bytes into a JSON object. @@ -36,7 +36,7 @@ def from_json( - True / 'on' - allow incomplete JSON but discard the last string if it is incomplete - 'trailing-strings' - allow incomplete JSON, and include the last incomplete string in the output catch_duplicate_keys: if True, raise an exception if objects contain the same key multiple times - lossless_floats: if True, preserve full detail on floats using `LosslessFloat` + float_mode: How to return floats: as a `float`, `Decimal` or `LosslessFloat` Returns: Python object built from the JSON input. diff --git a/crates/jiter-python/bench.py b/crates/jiter-python/bench.py index 1dee499..7126db9 100644 --- a/crates/jiter-python/bench.py +++ b/crates/jiter-python/bench.py @@ -5,37 +5,28 @@ import json -FAST = bool(os.getenv('FAST')) -THIS_DIR = Path(__file__).parent - -cases = [ - ("medium_response", (THIS_DIR / "../jiter/benches/medium_response.json").read_bytes()), - ( - "massive_ints_array", - (THIS_DIR / "../jiter/benches/massive_ints_array.json").read_bytes(), - ), - ("array_short_strings", "[{}]".format(", ".join('"123"' for _ in range(100_000)))), - ( - "object_short_strings", - "{%s}" % ", ".join(f'"{i}": "{i}x"' for i in range(100_000)), - ), - ( - "array_short_arrays", - "[{}]".format(", ".join('["a", "b", "c", "d"]' for _ in range(10_000))), - ), - ("one_long_string", json.dumps("x" * 100)), - ("one_short_string", b'"foobar"'), - ("1m_strings", json.dumps([str(i) for i in range(1_000_000)])), -] - - -def run_bench(func, d): +CASES = { + "array_short_strings": "[{}]".format(", ".join('"123"' for _ in range(100_000))), + "object_short_strings": "{%s}" % ", ".join(f'"{i}": "{i}x"' for i in range(100_000)), + "array_short_arrays": "[{}]".format(", ".join('["a", "b", "c", "d"]' for _ in range(10_000))), + "one_long_string": json.dumps("x" * 100), + "one_short_string": b'"foobar"', + "1m_strings": json.dumps([str(i) for i in range(1_000_000)]), +} + +BENCHES_DIR = Path(__file__).parent.parent / "jiter/benches/" + +for p in BENCHES_DIR.glob('*.json'): + CASES[p.stem] = p.read_bytes() + + +def run_bench(func, d, fast: bool): if isinstance(d, str): d = d.encode() timer = timeit.Timer( "func(json_data)", setup="", globals={"func": func, "json_data": d} ) - if FAST: + if fast: return timer.timeit(1) else: n, t = timer.autorange() @@ -85,20 +76,21 @@ def setup_json(): def main(): parser = argparse.ArgumentParser() + parser.add_argument("--case", default="all", choices=[*CASES.keys(), "all"]) + parser.add_argument("--fast", action="store_true", default=False) parser.add_argument( "parsers", nargs="*", default="all", choices=[*PARSERS.keys(), "all"] ) args = parser.parse_args() - if "all" in args.parsers: - args.parsers = [*PARSERS.keys()] + parsers = [*PARSERS.keys()] if "all" in args.parsers else args.parsers + cases = [*CASES.keys()] if args.case == "all" else [args.case] - for name, json_data in cases: + for name in cases: print(f"Case: {name}") - times = [ - (parser, run_bench(PARSERS[parser](), json_data)) for parser in args.parsers - ] + json_data = CASES[name] + times = [(parser, run_bench(PARSERS[parser](), json_data, args.fast)) for parser in parsers] times.sort(key=lambda x: x[1]) best = times[0][1] diff --git a/crates/jiter-python/jiter.pyi b/crates/jiter-python/jiter.pyi index 1fe3d9b..928fa75 100644 --- a/crates/jiter-python/jiter.pyi +++ b/crates/jiter-python/jiter.pyi @@ -9,7 +9,7 @@ def from_json( cache_mode: Literal[True, False, "all", "keys", "none"] = "all", partial_mode: Literal[True, False, "off", "on", "trailing-strings"] = False, catch_duplicate_keys: bool = False, - lossless_floats: bool = False, + float_mode: Literal["float", "decimal", "lossless-float"] = False, ) -> Any: """ Parse input bytes into a JSON object. @@ -27,7 +27,7 @@ def from_json( - True / 'on' - allow incomplete JSON but discard the last string if it is incomplete - 'trailing-strings' - allow incomplete JSON, and include the last incomplete string in the output catch_duplicate_keys: if True, raise an exception if objects contain the same key multiple times - lossless_floats: if True, preserve full detail on floats using `LosslessFloat` + float_mode: How to return floats: as a `float`, `Decimal` or `LosslessFloat` Returns: Python object built from the JSON input. diff --git a/crates/jiter-python/src/lib.rs b/crates/jiter-python/src/lib.rs index 7983311..8488982 100644 --- a/crates/jiter-python/src/lib.rs +++ b/crates/jiter-python/src/lib.rs @@ -19,7 +19,7 @@ pub fn get_jiter_version() -> &'static str { mod jiter_python { use pyo3::prelude::*; - use jiter::{map_json_error, LosslessFloat, PartialMode, PythonParse, StringCacheMode}; + use jiter::{map_json_error, FloatMode, LosslessFloat, PartialMode, PythonParse, StringCacheMode}; use super::get_jiter_version; @@ -33,7 +33,7 @@ mod jiter_python { cache_mode=StringCacheMode::All, partial_mode=PartialMode::Off, catch_duplicate_keys=false, - lossless_floats=false, + float_mode=FloatMode::Float, ) )] pub fn from_json<'py>( @@ -43,14 +43,14 @@ mod jiter_python { cache_mode: StringCacheMode, partial_mode: PartialMode, catch_duplicate_keys: bool, - lossless_floats: bool, + float_mode: FloatMode, ) -> PyResult> { let parse_builder = PythonParse { allow_inf_nan, cache_mode, partial_mode, catch_duplicate_keys, - lossless_floats, + float_mode, }; parse_builder .python_parse(py, json_data) diff --git a/crates/jiter-python/tests/test_jiter.py b/crates/jiter-python/tests/test_jiter.py index c0e6abb..61c61ab 100644 --- a/crates/jiter-python/tests/test_jiter.py +++ b/crates/jiter-python/tests/test_jiter.py @@ -240,13 +240,17 @@ def test_lossless_floats(): assert isinstance(f, float) assert f == 12.3 - f = jiter.from_json(b'12.3', lossless_floats=True) + f = jiter.from_json(b'12.3', float_mode='float') + assert isinstance(f, float) + assert f == 12.3 + + f = jiter.from_json(b'12.3', float_mode='lossless-float') assert isinstance(f, jiter.LosslessFloat) assert str(f) == '12.3' assert float(f) == 12.3 assert f.as_decimal() == Decimal('12.3') - f = jiter.from_json(b'123.456789123456789e45', lossless_floats=True) + f = jiter.from_json(b'123.456789123456789e45', float_mode='lossless-float') assert isinstance(f, jiter.LosslessFloat) assert 123e45 < float(f) < 124e45 assert f.as_decimal() == Decimal('1.23456789123456789E+47') @@ -254,11 +258,40 @@ def test_lossless_floats(): assert str(f) == '123.456789123456789e45' assert repr(f) == 'LosslessFloat(123.456789123456789e45)' + f = jiter.from_json(b'123', float_mode='lossless-float') + assert isinstance(f, int) + assert f == 123 + + with pytest.raises(ValueError, match='expected value at line 1 column 1'): + jiter.from_json(b'wrong', float_mode='lossless-float') + + with pytest.raises(ValueError, match='trailing characters at line 1 column 2'): + jiter.from_json(b'1wrong', float_mode='lossless-float') + + + +def test_decimal_floats(): + f = jiter.from_json(b'12.3') + assert isinstance(f, float) + assert f == 12.3 + + f = jiter.from_json(b'12.3', float_mode='decimal') + assert isinstance(f, Decimal) + assert f == Decimal('12.3') + + f = jiter.from_json(b'123.456789123456789e45', float_mode='decimal') + assert isinstance(f, Decimal) + assert f == Decimal('1.23456789123456789E+47') + + f = jiter.from_json(b'123', float_mode='decimal') + assert isinstance(f, int) + assert f == 123 + + with pytest.raises(ValueError, match='expected value at line 1 column 1'): + jiter.from_json(b'wrong', float_mode='decimal') -def test_lossless_floats_int(): - v = jiter.from_json(b'123', lossless_floats=True) - assert isinstance(v, int) - assert v == 123 + with pytest.raises(ValueError, match='trailing characters at line 1 column 2'): + jiter.from_json(b'1wrong', float_mode='decimal') def test_unicode_roundtrip(): diff --git a/crates/jiter/src/errors.rs b/crates/jiter/src/errors.rs index 2068ec9..23d0c59 100644 --- a/crates/jiter/src/errors.rs +++ b/crates/jiter/src/errors.rs @@ -10,6 +10,9 @@ pub enum JsonErrorType { /// duplicate keys in an object DuplicateKey(String), + /// happens when getting the `Decimal` type or constructing a decimal fails + InternalError(String), + /// NOTE: all errors from here on are copied from serde_json /// [src/error.rs](https://github.com/serde-rs/json/blob/v1.0.107/src/error.rs#L236) /// with `Io` and `Message` removed @@ -81,6 +84,7 @@ impl std::fmt::Display for JsonErrorType { match self { Self::FloatExpectingInt => f.write_str("float value was found where an int was expected"), Self::DuplicateKey(s) => write!(f, "Detected duplicate key {s:?}"), + Self::InternalError(s) => write!(f, "Internal error: {s:?}"), Self::EofWhileParsingList => f.write_str("EOF while parsing a list"), Self::EofWhileParsingObject => f.write_str("EOF while parsing an object"), Self::EofWhileParsingString => f.write_str("EOF while parsing a string"), diff --git a/crates/jiter/src/lib.rs b/crates/jiter/src/lib.rs index b9eea50..dff79a6 100644 --- a/crates/jiter/src/lib.rs +++ b/crates/jiter/src/lib.rs @@ -24,7 +24,7 @@ pub use parse::Peek; pub use value::{JsonArray, JsonObject, JsonValue}; #[cfg(feature = "python")] -pub use py_lossless_float::LosslessFloat; +pub use py_lossless_float::{FloatMode, LosslessFloat}; #[cfg(feature = "python")] pub use py_string_cache::{cache_clear, cache_usage, cached_py_string, pystring_fast_new, StringCacheMode}; #[cfg(feature = "python")] diff --git a/crates/jiter/src/py_lossless_float.rs b/crates/jiter/src/py_lossless_float.rs index fed5c50..1d1f14b 100644 --- a/crates/jiter/src/py_lossless_float.rs +++ b/crates/jiter/src/py_lossless_float.rs @@ -1,10 +1,40 @@ -use pyo3::exceptions::PyValueError; +use pyo3::exceptions::{PyTypeError, PyValueError}; use pyo3::prelude::*; use pyo3::sync::GILOnceCell; use pyo3::types::PyType; use crate::Jiter; +#[derive(Debug, Clone, Copy)] +pub enum FloatMode { + Float, + Decimal, + LosslessFloat, +} + +impl Default for FloatMode { + fn default() -> Self { + Self::Float + } +} + +const FLOAT_ERROR: &str = "Invalid float mode, should be `'float'`, `'decimal'` or `'lossless-float'`"; + +impl<'py> FromPyObject<'py> for FloatMode { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(str_mode) = ob.extract::<&str>() { + match str_mode { + "float" => Ok(Self::Float), + "decimal" => Ok(Self::Decimal), + "lossless-float" => Ok(Self::LosslessFloat), + _ => Err(PyValueError::new_err(FLOAT_ERROR)), + } + } else { + Err(PyTypeError::new_err(FLOAT_ERROR)) + } + } +} + /// Represents a float from JSON, by holding the underlying bytes representing a float from JSON. #[derive(Debug, Clone)] #[pyclass(module = "jiter")] diff --git a/crates/jiter/src/python.rs b/crates/jiter/src/python.rs index 0bb49d1..e8cc0dc 100644 --- a/crates/jiter/src/python.rs +++ b/crates/jiter/src/python.rs @@ -12,6 +12,7 @@ use smallvec::SmallVec; use crate::errors::{json_err, json_error, JsonError, JsonResult, DEFAULT_RECURSION_LIMIT}; use crate::number_decoder::{AbstractNumberDecoder, NumberAny, NumberRange}; use crate::parse::{Parser, Peek}; +use crate::py_lossless_float::{get_decimal_type, FloatMode}; use crate::py_string_cache::{StringCacheAll, StringCacheKeys, StringCacheMode, StringMaybeCache, StringNoCache}; use crate::string_decoder::{StringDecoder, Tape}; use crate::{JsonErrorType, LosslessFloat}; @@ -27,8 +28,9 @@ pub struct PythonParse { pub partial_mode: PartialMode, /// Whether to catch duplicate keys in objects. pub catch_duplicate_keys: bool, - /// Whether to preserve full detail on floats using [`LosslessFloat`] - pub lossless_floats: bool, + /// How to return floats: as a `float` (`'float'`), `Decimal` (`'decimal'`) or + /// [`LosslessFloat`] (`'lossless-float'`) + pub float_mode: FloatMode, } impl PythonParse { @@ -56,11 +58,13 @@ impl PythonParse { } macro_rules! ppp_group { ($string_cache:ident) => { - match (self.catch_duplicate_keys, self.lossless_floats) { - (true, true) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossless), - (true, false) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossy), - (false, true) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossless), - (false, false) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossy), + match (self.catch_duplicate_keys, self.float_mode) { + (true, FloatMode::Float) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossy), + (true, FloatMode::Decimal) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberDecimal), + (true, FloatMode::LosslessFloat) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossless), + (false, FloatMode::Float) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossy), + (false, FloatMode::Decimal) => ppp!($string_cache, NoopKeyCheck, ParseNumberDecimal), + (false, FloatMode::LosslessFloat) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossless), } }; } @@ -378,3 +382,42 @@ impl MaybeParseNumber for ParseNumberLossless { } } } + +struct ParseNumberDecimal; + +impl MaybeParseNumber for ParseNumberDecimal { + fn parse_number<'py>( + py: Python<'py>, + parser: &mut Parser, + peek: Peek, + allow_inf_nan: bool, + ) -> JsonResult> { + match parser.consume_number::(peek.into_inner(), allow_inf_nan) { + Ok(number_range) => { + let bytes = parser.slice(number_range.range).unwrap(); + if number_range.is_int { + let obj = NumberAny::decode(bytes, 0, peek.into_inner(), allow_inf_nan)? + .0 + .to_object(py); + Ok(obj.into_bound(py)) + } else { + let decimal_type = get_decimal_type(py) + .map_err(|e| JsonError::new(JsonErrorType::InternalError(e.to_string()), parser.index))?; + // SAFETY: NumberRange::decode has already confirmed that bytes are a valid JSON number, + // and therefore valid str + let float_str = unsafe { std::str::from_utf8_unchecked(bytes) }; + decimal_type + .call1((float_str,)) + .map_err(|e| JsonError::new(JsonErrorType::InternalError(e.to_string()), parser.index)) + } + } + Err(e) => { + if !peek.is_num() { + Err(json_error!(ExpectedSomeValue, parser.index)) + } else { + Err(e) + } + } + } + } +}