Skip to content

Commit

Permalink
support float_mode='decimal' (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Jul 2, 2024
1 parent b613634 commit 6c0a187
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 58 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,7 @@ jobs:
env:
RUST_BACKTRACE: 1

- run: python crates/jiter-python/bench.py
env:
FAST: 1
- run: python crates/jiter-python/bench.py --fast

- run: coverage-prepare lcov $(python -c 'import jiter.jiter;print(jiter.jiter.__file__)')

Expand Down
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ python-dev:
python-test: python-dev
pytest crates/jiter-python/tests

.PHONY: python-dev-release
python-dev-release:
maturin develop -m crates/jiter-python/Cargo.toml --release

.PHONY: python-bench
python-bench: python-dev-release
python crates/jiter-python/bench.py

.PHONY: bench
bench:
cargo bench -p jiter -F python
Expand Down
4 changes: 2 additions & 2 deletions crates/jiter-python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def from_json(
cache_mode: Literal[True, False, "all", "keys", "none"] = "all",
partial_mode: Literal[True, False, "off", "on", "trailing-strings"] = False,
catch_duplicate_keys: bool = False,
lossless_floats: bool = False,
float_mode: Literal["float", "decimal", "lossless-float"] = False,
) -> Any:
"""
Parse input bytes into a JSON object.
Expand All @@ -36,7 +36,7 @@ def from_json(
- True / 'on' - allow incomplete JSON but discard the last string if it is incomplete
- 'trailing-strings' - allow incomplete JSON, and include the last incomplete string in the output
catch_duplicate_keys: if True, raise an exception if objects contain the same key multiple times
lossless_floats: if True, preserve full detail on floats using `LosslessFloat`
float_mode: How to return floats: as a `float`, `Decimal` or `LosslessFloat`
Returns:
Python object built from the JSON input.
Expand Down
56 changes: 24 additions & 32 deletions crates/jiter-python/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,28 @@

import json

FAST = bool(os.getenv('FAST'))
THIS_DIR = Path(__file__).parent

cases = [
("medium_response", (THIS_DIR / "../jiter/benches/medium_response.json").read_bytes()),
(
"massive_ints_array",
(THIS_DIR / "../jiter/benches/massive_ints_array.json").read_bytes(),
),
("array_short_strings", "[{}]".format(", ".join('"123"' for _ in range(100_000)))),
(
"object_short_strings",
"{%s}" % ", ".join(f'"{i}": "{i}x"' for i in range(100_000)),
),
(
"array_short_arrays",
"[{}]".format(", ".join('["a", "b", "c", "d"]' for _ in range(10_000))),
),
("one_long_string", json.dumps("x" * 100)),
("one_short_string", b'"foobar"'),
("1m_strings", json.dumps([str(i) for i in range(1_000_000)])),
]


def run_bench(func, d):
CASES = {
"array_short_strings": "[{}]".format(", ".join('"123"' for _ in range(100_000))),
"object_short_strings": "{%s}" % ", ".join(f'"{i}": "{i}x"' for i in range(100_000)),
"array_short_arrays": "[{}]".format(", ".join('["a", "b", "c", "d"]' for _ in range(10_000))),
"one_long_string": json.dumps("x" * 100),
"one_short_string": b'"foobar"',
"1m_strings": json.dumps([str(i) for i in range(1_000_000)]),
}

BENCHES_DIR = Path(__file__).parent.parent / "jiter/benches/"

for p in BENCHES_DIR.glob('*.json'):
CASES[p.stem] = p.read_bytes()


def run_bench(func, d, fast: bool):
if isinstance(d, str):
d = d.encode()
timer = timeit.Timer(
"func(json_data)", setup="", globals={"func": func, "json_data": d}
)
if FAST:
if fast:
return timer.timeit(1)
else:
n, t = timer.autorange()
Expand Down Expand Up @@ -85,20 +76,21 @@ def setup_json():

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--case", default="all", choices=[*CASES.keys(), "all"])
parser.add_argument("--fast", action="store_true", default=False)
parser.add_argument(
"parsers", nargs="*", default="all", choices=[*PARSERS.keys(), "all"]
)
args = parser.parse_args()

if "all" in args.parsers:
args.parsers = [*PARSERS.keys()]
parsers = [*PARSERS.keys()] if "all" in args.parsers else args.parsers
cases = [*CASES.keys()] if args.case == "all" else [args.case]

for name, json_data in cases:
for name in cases:
print(f"Case: {name}")

times = [
(parser, run_bench(PARSERS[parser](), json_data)) for parser in args.parsers
]
json_data = CASES[name]
times = [(parser, run_bench(PARSERS[parser](), json_data, args.fast)) for parser in parsers]

times.sort(key=lambda x: x[1])
best = times[0][1]
Expand Down
4 changes: 2 additions & 2 deletions crates/jiter-python/jiter.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def from_json(
cache_mode: Literal[True, False, "all", "keys", "none"] = "all",
partial_mode: Literal[True, False, "off", "on", "trailing-strings"] = False,
catch_duplicate_keys: bool = False,
lossless_floats: bool = False,
float_mode: Literal["float", "decimal", "lossless-float"] = False,
) -> Any:
"""
Parse input bytes into a JSON object.
Expand All @@ -27,7 +27,7 @@ def from_json(
- True / 'on' - allow incomplete JSON but discard the last string if it is incomplete
- 'trailing-strings' - allow incomplete JSON, and include the last incomplete string in the output
catch_duplicate_keys: if True, raise an exception if objects contain the same key multiple times
lossless_floats: if True, preserve full detail on floats using `LosslessFloat`
float_mode: How to return floats: as a `float`, `Decimal` or `LosslessFloat`
Returns:
Python object built from the JSON input.
Expand Down
8 changes: 4 additions & 4 deletions crates/jiter-python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn get_jiter_version() -> &'static str {
mod jiter_python {
use pyo3::prelude::*;

use jiter::{map_json_error, LosslessFloat, PartialMode, PythonParse, StringCacheMode};
use jiter::{map_json_error, FloatMode, LosslessFloat, PartialMode, PythonParse, StringCacheMode};

use super::get_jiter_version;

Expand All @@ -33,7 +33,7 @@ mod jiter_python {
cache_mode=StringCacheMode::All,
partial_mode=PartialMode::Off,
catch_duplicate_keys=false,
lossless_floats=false,
float_mode=FloatMode::Float,
)
)]
pub fn from_json<'py>(
Expand All @@ -43,14 +43,14 @@ mod jiter_python {
cache_mode: StringCacheMode,
partial_mode: PartialMode,
catch_duplicate_keys: bool,
lossless_floats: bool,
float_mode: FloatMode,
) -> PyResult<Bound<'py, PyAny>> {
let parse_builder = PythonParse {
allow_inf_nan,
cache_mode,
partial_mode,
catch_duplicate_keys,
lossless_floats,
float_mode,
};
parse_builder
.python_parse(py, json_data)
Expand Down
45 changes: 39 additions & 6 deletions crates/jiter-python/tests/test_jiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,58 @@ def test_lossless_floats():
assert isinstance(f, float)
assert f == 12.3

f = jiter.from_json(b'12.3', lossless_floats=True)
f = jiter.from_json(b'12.3', float_mode='float')
assert isinstance(f, float)
assert f == 12.3

f = jiter.from_json(b'12.3', float_mode='lossless-float')
assert isinstance(f, jiter.LosslessFloat)
assert str(f) == '12.3'
assert float(f) == 12.3
assert f.as_decimal() == Decimal('12.3')

f = jiter.from_json(b'123.456789123456789e45', lossless_floats=True)
f = jiter.from_json(b'123.456789123456789e45', float_mode='lossless-float')
assert isinstance(f, jiter.LosslessFloat)
assert 123e45 < float(f) < 124e45
assert f.as_decimal() == Decimal('1.23456789123456789E+47')
assert bytes(f) == b'123.456789123456789e45'
assert str(f) == '123.456789123456789e45'
assert repr(f) == 'LosslessFloat(123.456789123456789e45)'

f = jiter.from_json(b'123', float_mode='lossless-float')
assert isinstance(f, int)
assert f == 123

with pytest.raises(ValueError, match='expected value at line 1 column 1'):
jiter.from_json(b'wrong', float_mode='lossless-float')

with pytest.raises(ValueError, match='trailing characters at line 1 column 2'):
jiter.from_json(b'1wrong', float_mode='lossless-float')



def test_decimal_floats():
f = jiter.from_json(b'12.3')
assert isinstance(f, float)
assert f == 12.3

f = jiter.from_json(b'12.3', float_mode='decimal')
assert isinstance(f, Decimal)
assert f == Decimal('12.3')

f = jiter.from_json(b'123.456789123456789e45', float_mode='decimal')
assert isinstance(f, Decimal)
assert f == Decimal('1.23456789123456789E+47')

f = jiter.from_json(b'123', float_mode='decimal')
assert isinstance(f, int)
assert f == 123

with pytest.raises(ValueError, match='expected value at line 1 column 1'):
jiter.from_json(b'wrong', float_mode='decimal')

def test_lossless_floats_int():
v = jiter.from_json(b'123', lossless_floats=True)
assert isinstance(v, int)
assert v == 123
with pytest.raises(ValueError, match='trailing characters at line 1 column 2'):
jiter.from_json(b'1wrong', float_mode='decimal')


def test_unicode_roundtrip():
Expand Down
4 changes: 4 additions & 0 deletions crates/jiter/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ pub enum JsonErrorType {
/// duplicate keys in an object
DuplicateKey(String),

/// happens when getting the `Decimal` type or constructing a decimal fails
InternalError(String),

/// NOTE: all errors from here on are copied from serde_json
/// [src/error.rs](https://github.com/serde-rs/json/blob/v1.0.107/src/error.rs#L236)
/// with `Io` and `Message` removed
Expand Down Expand Up @@ -81,6 +84,7 @@ impl std::fmt::Display for JsonErrorType {
match self {
Self::FloatExpectingInt => f.write_str("float value was found where an int was expected"),
Self::DuplicateKey(s) => write!(f, "Detected duplicate key {s:?}"),
Self::InternalError(s) => write!(f, "Internal error: {s:?}"),
Self::EofWhileParsingList => f.write_str("EOF while parsing a list"),
Self::EofWhileParsingObject => f.write_str("EOF while parsing an object"),
Self::EofWhileParsingString => f.write_str("EOF while parsing a string"),
Expand Down
2 changes: 1 addition & 1 deletion crates/jiter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub use parse::Peek;
pub use value::{JsonArray, JsonObject, JsonValue};

#[cfg(feature = "python")]
pub use py_lossless_float::LosslessFloat;
pub use py_lossless_float::{FloatMode, LosslessFloat};
#[cfg(feature = "python")]
pub use py_string_cache::{cache_clear, cache_usage, cached_py_string, pystring_fast_new, StringCacheMode};
#[cfg(feature = "python")]
Expand Down
32 changes: 31 additions & 1 deletion crates/jiter/src/py_lossless_float.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,40 @@
use pyo3::exceptions::PyValueError;
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::PyType;

use crate::Jiter;

#[derive(Debug, Clone, Copy)]
pub enum FloatMode {
Float,
Decimal,
LosslessFloat,
}

impl Default for FloatMode {
fn default() -> Self {
Self::Float
}
}

const FLOAT_ERROR: &str = "Invalid float mode, should be `'float'`, `'decimal'` or `'lossless-float'`";

impl<'py> FromPyObject<'py> for FloatMode {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
if let Ok(str_mode) = ob.extract::<&str>() {
match str_mode {
"float" => Ok(Self::Float),
"decimal" => Ok(Self::Decimal),
"lossless-float" => Ok(Self::LosslessFloat),
_ => Err(PyValueError::new_err(FLOAT_ERROR)),
}
} else {
Err(PyTypeError::new_err(FLOAT_ERROR))
}
}
}

/// Represents a float from JSON, by holding the underlying bytes representing a float from JSON.
#[derive(Debug, Clone)]
#[pyclass(module = "jiter")]
Expand Down
57 changes: 50 additions & 7 deletions crates/jiter/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use smallvec::SmallVec;
use crate::errors::{json_err, json_error, JsonError, JsonResult, DEFAULT_RECURSION_LIMIT};
use crate::number_decoder::{AbstractNumberDecoder, NumberAny, NumberRange};
use crate::parse::{Parser, Peek};
use crate::py_lossless_float::{get_decimal_type, FloatMode};
use crate::py_string_cache::{StringCacheAll, StringCacheKeys, StringCacheMode, StringMaybeCache, StringNoCache};
use crate::string_decoder::{StringDecoder, Tape};
use crate::{JsonErrorType, LosslessFloat};
Expand All @@ -27,8 +28,9 @@ pub struct PythonParse {
pub partial_mode: PartialMode,
/// Whether to catch duplicate keys in objects.
pub catch_duplicate_keys: bool,
/// Whether to preserve full detail on floats using [`LosslessFloat`]
pub lossless_floats: bool,
/// How to return floats: as a `float` (`'float'`), `Decimal` (`'decimal'`) or
/// [`LosslessFloat`] (`'lossless-float'`)
pub float_mode: FloatMode,
}

impl PythonParse {
Expand Down Expand Up @@ -56,11 +58,13 @@ impl PythonParse {
}
macro_rules! ppp_group {
($string_cache:ident) => {
match (self.catch_duplicate_keys, self.lossless_floats) {
(true, true) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossless),
(true, false) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossy),
(false, true) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossless),
(false, false) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossy),
match (self.catch_duplicate_keys, self.float_mode) {
(true, FloatMode::Float) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossy),
(true, FloatMode::Decimal) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberDecimal),
(true, FloatMode::LosslessFloat) => ppp!($string_cache, DuplicateKeyCheck, ParseNumberLossless),
(false, FloatMode::Float) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossy),
(false, FloatMode::Decimal) => ppp!($string_cache, NoopKeyCheck, ParseNumberDecimal),
(false, FloatMode::LosslessFloat) => ppp!($string_cache, NoopKeyCheck, ParseNumberLossless),
}
};
}
Expand Down Expand Up @@ -378,3 +382,42 @@ impl MaybeParseNumber for ParseNumberLossless {
}
}
}

struct ParseNumberDecimal;

impl MaybeParseNumber for ParseNumberDecimal {
fn parse_number<'py>(
py: Python<'py>,
parser: &mut Parser,
peek: Peek,
allow_inf_nan: bool,
) -> JsonResult<Bound<'py, PyAny>> {
match parser.consume_number::<NumberRange>(peek.into_inner(), allow_inf_nan) {
Ok(number_range) => {
let bytes = parser.slice(number_range.range).unwrap();
if number_range.is_int {
let obj = NumberAny::decode(bytes, 0, peek.into_inner(), allow_inf_nan)?
.0
.to_object(py);
Ok(obj.into_bound(py))
} else {
let decimal_type = get_decimal_type(py)
.map_err(|e| JsonError::new(JsonErrorType::InternalError(e.to_string()), parser.index))?;
// SAFETY: NumberRange::decode has already confirmed that bytes are a valid JSON number,
// and therefore valid str
let float_str = unsafe { std::str::from_utf8_unchecked(bytes) };
decimal_type
.call1((float_str,))
.map_err(|e| JsonError::new(JsonErrorType::InternalError(e.to_string()), parser.index))
}
}
Err(e) => {
if !peek.is_num() {
Err(json_error!(ExpectedSomeValue, parser.index))
} else {
Err(e)
}
}
}
}
}

0 comments on commit 6c0a187

Please sign in to comment.