diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1f80c42..99bfc08 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: trailing-whitespace - id: end-of-file-fixer - - repo: https://gitlab.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 3.8.4 hooks: - id: flake8 diff --git a/src/parsy/__init__.py b/src/parsy/__init__.py index cf96e5c..b210d2b 100644 --- a/src/parsy/__init__.py +++ b/src/parsy/__init__.py @@ -5,7 +5,7 @@ import re from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, FrozenSet +from typing import Any, AnyStr, Callable, FrozenSet __version__ = "2.1" @@ -166,10 +166,23 @@ def combine_dict(self, combine_fn: Callable) -> Parser: def concat(self) -> Parser: """ - Returns a parser that concatenates together (as a string) the previously - produced values. + Returns a parser that concatenates together the previously produced values. + + This parser will join the values using the type of the input stream, so + when feeding bytes to the parser, the items to be joined must also be bytes. """ - return self.map("".join) + + @Parser + def parser(stream: bytes | str, index: int) -> Result: + joiner = type(stream)() + result = self(stream, index) + if result.status: + next_parser: Parser = success(joiner.join(result.value)) + return next_parser(stream, result.index).aggregate(result) + else: + return result + + return parser def then(self, other: Parser) -> Parser: """ @@ -516,13 +529,16 @@ def fail(expected: str) -> Parser: return Parser(lambda _, index: Result.failure(index, expected)) -def string(expected_string: str, transform: Callable[[str], str] = noop) -> Parser: +def string(expected_string: AnyStr, transform: Callable[[AnyStr], AnyStr] = noop) -> Parser: """ Returns a parser that expects the ``expected_string`` and produces that string value. Optionally, a transform function can be passed, which will be used on both the expected string and tested string. + + This parser can also be instantiated with a bytes value, in which case it can + should be applied to a stream of bytes. """ slen = len(expected_string) diff --git a/tests/test_parsy.py b/tests/test_parsy.py index ba08f56..f73af48 100644 --- a/tests/test_parsy.py +++ b/tests/test_parsy.py @@ -32,13 +32,13 @@ class TestParser(unittest.TestCase): - def test_string(self): + def test_string_str(self): parser = string("x") self.assertEqual(parser.parse("x"), "x") self.assertRaises(ParseError, parser.parse, "y") - def test_string_transform(self): + def test_string_transform_str(self): parser = string("x", transform=lambda s: s.lower()) self.assertEqual(parser.parse("x"), "x") self.assertEqual(parser.parse("X"), "x") @@ -53,6 +53,19 @@ def test_string_transform_2(self): self.assertRaises(ParseError, parser.parse, "dog") + def test_string_bytes(self): + parser = string(b"x") + self.assertEqual(parser.parse(b"x"), b"x") + + self.assertRaises(ParseError, parser.parse, b"y") + + def test_string_transform_bytes(self): + parser = string(b"x", transform=lambda s: s.lower()) + self.assertEqual(parser.parse(b"x"), b"x") + self.assertEqual(parser.parse(b"X"), b"x") + + self.assertRaises(ParseError, parser.parse, b"y") + def test_regex_str(self): parser = regex(r"[0-9]") @@ -157,11 +170,16 @@ def test_combine_dict_skip_underscores(self): ).combine_dict(Pair) self.assertEqual(parser.parse("ABC 123"), Pair(word="ABC", number=123)) - def test_concat(self): + def test_concat_str(self): parser = letter.many().concat() self.assertEqual(parser.parse(""), "") self.assertEqual(parser.parse("abc"), "abc") + def test_concat_bytes(self): + parser = any_char.many().concat() + self.assertEqual(parser.parse(b""), b"") + self.assertEqual(parser.parse(b"abc"), b"abc") + def test_generate(self): x = y = None