Skip to content

Commit

Permalink
rfc8785: ensure IntEnum works on < 3.11 (#5)
Browse files Browse the repository at this point in the history
* rfc8785: ensure IntEnum works on < 3.11

Signed-off-by: William Woodruff <william@trailofbits.com>

* rfc8785: generalize coercion

Signed-off-by: William Woodruff <william@trailofbits.com>

* test_impl: test StrEnum as well

Signed-off-by: William Woodruff <william@trailofbits.com>

* test_impl: skip StrEnum test below 3.11

Since StrEnum does not exist.

Signed-off-by: William Woodruff <william@trailofbits.com>

* test_impl: fix import

Signed-off-by: William Woodruff <william@trailofbits.com>

---------

Signed-off-by: William Woodruff <william@trailofbits.com>
  • Loading branch information
woodruffw authored Mar 7, 2024
1 parent 18fb92d commit 98e7670
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/rfc8785/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ def dump(obj: _Value, sink: IO[bytes]) -> None:
else:
sink.write(b"false")
case int():
# Annoyance: int can be subclassed by types like IntEnum,
# which then break or change `int.__str__`. Rather than plugging
# these individually, we coerce back to `int`.
obj = int(obj)

if obj < _INT_MIN or obj > _INT_MAX:
raise IntegerDomainError(obj)
sink.write(str(obj).encode("utf-8"))
Expand Down
15 changes: 15 additions & 0 deletions test/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import gzip
import json
import struct
import sys
from enum import IntEnum
from io import BytesIO

Expand Down Expand Up @@ -112,3 +113,17 @@ class X(IntEnum):

raw = impl.dumps([X.A, X.B, X.C])
assert json.loads(raw) == [1, 2, 9001]


@pytest.mark.skipif(sys.version_info < (3, 11), reason="StrEnum added in 3.11+")
def test_dumps_strenum():
from enum import StrEnum

# StrEnum is a subclass of str, so this should work transparently.
class X(StrEnum):
A = "foo"
B = "bar"
C = "baz"

raw = impl.dumps([X.A, X.B, X.C])
assert json.loads(raw) == ["foo", "bar", "baz"]

0 comments on commit 98e7670

Please sign in to comment.