From efb7ab3a3576bff9d7c119df69e73f41c49d300b Mon Sep 17 00:00:00 2001 From: David Huggins-Daines Date: Mon, 20 Nov 2023 13:56:19 -0500 Subject: [PATCH] fix: add some types --- alexi/convert.py | 10 ++++++---- alexi/download.py | 2 +- alexi/index.py | 5 +++-- alexi/search.py | 2 +- alexi/segment.py | 27 ++++++++++++++------------- alexi/types.py | 2 +- scripts/title_pages.py | 6 ++++-- test/test_convert.py | 4 ++-- 8 files changed, 32 insertions(+), 26 deletions(-) diff --git a/alexi/convert.py b/alexi/convert.py index c62bd55..893bfe4 100644 --- a/alexi/convert.py +++ b/alexi/convert.py @@ -130,7 +130,7 @@ def get_element_bbox(page: Page, el: PDFStructElement, mcids: Iterable[int]) -> return geometry.objects_to_bbox(mcid_objs) -def get_rgb(c: dict) -> str: +def get_rgb(c: T_obj) -> str: """Extraire la couleur d'un objet en 3 chiffres hexadécimaux""" couleur = c.get("non_stroking_color", c.get("stroking_color")) if couleur is None: @@ -146,11 +146,11 @@ def get_rgb(c: dict) -> str: def get_word_features( - word: dict, + word: T_obj, page: Page, chars: dict[tuple[int, int], T_obj], elmap: dict[int, str], -) -> dict: +) -> T_obj: # Extract things from first character (we do not use # extra_attrs because otherwise extract_words will # insert word breaks) @@ -183,7 +183,9 @@ class Converteur: y_tolerance: int def __init__( - self, path_or_fp: Union[str, Path, BufferedReader, BytesIO], y_tolerance=2 + self, + path_or_fp: Union[str, Path, BufferedReader, BytesIO], + y_tolerance: int = 2, ): self.pdf = PDF.open(path_or_fp) self.y_tolerance = y_tolerance diff --git a/alexi/download.py b/alexi/download.py index 5423ad0..0ef24f6 100644 --- a/alexi/download.py +++ b/alexi/download.py @@ -48,7 +48,7 @@ def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser -def main(args): +def main(args: argparse.Namespace) -> None: u = urllib.parse.urlparse(args.url) LOGGER.info("Downloading %s", args.url) try: diff --git a/alexi/index.py b/alexi/index.py index 0a979e3..c47363a 100644 --- a/alexi/index.py +++ b/alexi/index.py @@ -12,13 +12,14 @@ from whoosh.index import create_in # type: ignore from whoosh.support.charset import charset_table_to_dict # type: ignore from whoosh.support.charset import default_charset +from whoosh.writing import IndexWriter # type: ignore LOGGER = logging.getLogger("index") CHARMAP = charset_table_to_dict(default_charset) ANALYZER = StemmingAnalyzer() | CharsetFilter(CHARMAP) -def add_from_dir(writer, document, docdir): +def add_from_dir(writer: IndexWriter, document: str, docdir: Path) -> None: LOGGER.info("Indexing %s", docdir) with open(docdir / "index.json") as infh: element = json.load(infh) @@ -30,7 +31,7 @@ def add_from_dir(writer, document, docdir): ) -def index(indir: Path, outdir: Path): +def index(indir: Path, outdir: Path) -> None: outdir.mkdir(exist_ok=True) schema = Schema( document=ID(stored=True), diff --git a/alexi/search.py b/alexi/search.py index cfb283e..9cb6816 100644 --- a/alexi/search.py +++ b/alexi/search.py @@ -9,7 +9,7 @@ from whoosh.qparser import MultifieldParser, OrGroup # type: ignore -def search(indexdir: Path, terms: List[str]): +def search(indexdir: Path, terms: List[str]) -> None: ix = open_dir(indexdir) parser = MultifieldParser( ["titre", "contenu"], ix.schema, group=OrGroup.factory(0.9) diff --git a/alexi/segment.py b/alexi/segment.py index 276486c..d99586e 100644 --- a/alexi/segment.py +++ b/alexi/segment.py @@ -7,11 +7,12 @@ from enum import Enum from os import PathLike from pathlib import Path -from typing import Any, Callable, Iterable, Iterator, Union +from typing import Any, Callable, Iterable, Iterator, Union, Optional import joblib # type: ignore from alexi.convert import FIELDNAMES +from alexi.types import T_obj FEATNAMES = [name for name in FIELDNAMES if name not in ("segment", "sequence")] DEFAULT_MODEL = Path(__file__).parent / "models" / "crf.joblib.gz" @@ -26,7 +27,7 @@ class Bullet(Enum): BULLET = re.compile(r"^([•-])$") # FIXME: need more bullets -def sign(x: Union[int | float]): +def sign(x: Union[int | float]) -> int: """Get the sign of a number (should exist...)""" if x == 0: return 0 @@ -36,11 +37,11 @@ def sign(x: Union[int | float]): def make_visual_structural_literal() -> FeatureFunc: - prev_word = None - prev_line_height = None - prev_line_start = None + prev_word: Optional[T_obj] = None + prev_line_height = 1.0 + prev_line_start = 0.0 - def visual_one(idx, word): + def visual_one(idx: int, word: T_obj) -> list[str]: nonlocal prev_word, prev_line_height, prev_line_start if idx == 0: # page break prev_word = None @@ -68,10 +69,10 @@ def visual_one(idx, word): ] newline = False linedelta = 0.0 - dx = 1 - dy = 0 - dh = 0 - prev_height = 1 + dx = 1.0 + dy = 0.0 + dh = 0.0 + prev_height = 1.0 if prev_word is not None: height = float(word["bottom"]) - float(word["top"]) prev_height = float(prev_word["bottom"]) - float(prev_word["top"]) @@ -118,10 +119,10 @@ def visual_one(idx, word): def make_visual_literal() -> FeatureFunc: prev_word = None - prev_line_height = None - prev_line_start = None + prev_line_height = 1.0 + prev_line_start = 0.0 - def visual_one(idx, word): + def visual_one(idx, word) -> list[str]: nonlocal prev_word, prev_line_height, prev_line_start if idx == 0: # page break prev_word = None diff --git a/alexi/types.py b/alexi/types.py index b3c1855..b2c9640 100644 --- a/alexi/types.py +++ b/alexi/types.py @@ -15,7 +15,7 @@ class Bloc: _bbox: Optional[T_bbox] = None _page_number: Optional[int] = None - def __hash__(self): + def __hash__(self) -> int: if self._bbox: return hash((self.type, self._bbox, self._page_number)) else: diff --git a/scripts/title_pages.py b/scripts/title_pages.py index 14bd61f..733ca1c 100644 --- a/scripts/title_pages.py +++ b/scripts/title_pages.py @@ -10,7 +10,7 @@ from pathlib import Path -def main(): +def main() -> None: parser = argparse.ArgumentParser(description=__doc__) parser.add_argument("indir", help="Repertoire d'entrée", type=Path) parser.add_argument("outdir", help="Repertoire de sortie", type=Path) @@ -20,6 +20,7 @@ def main(): for p in args.indir.glob("*.csv"): with open(p, "rt") as infh, open(args.outdir / p.name, "wt") as outfh: reader = csv.DictReader(infh) + assert reader.fieldnames is not None fieldnames = list(reader.fieldnames) fieldnames.insert(0, "seqtag") writer = csv.DictWriter(outfh, fieldnames=fieldnames) @@ -32,7 +33,8 @@ def main(): if last_page is None: writer.writerows(contents) last_page = contents - writer.writerows(last_page) + if last_page is not None: + writer.writerows(last_page) if __name__ == "__main__": diff --git a/test/test_convert.py b/test/test_convert.py index 5ba5f48..1c96e5b 100644 --- a/test/test_convert.py +++ b/test/test_convert.py @@ -8,7 +8,7 @@ DATADIR = Path(__file__).parent / "data" -def test_convert(): +def test_convert() -> None: with open(DATADIR / "pdf_structure.pdf", "rb") as infh: conv = Converteur(infh) words = list(conv.extract_words()) @@ -19,7 +19,7 @@ def test_convert(): assert len(words) == len(ref_words) -def test_extract_tables_and_figures(): +def test_extract_tables_and_figures() -> None: with open(DATADIR / "pdf_figures.pdf", "rb") as infh: conv = Converteur(infh) words = list(conv.extract_words())