Skip to content

Commit

Permalink
fix: add some types
Browse files Browse the repository at this point in the history
  • Loading branch information
dhdaines committed Nov 20, 2023
1 parent 2fce4c5 commit efb7ab3
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 26 deletions.
10 changes: 6 additions & 4 deletions alexi/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_element_bbox(page: Page, el: PDFStructElement, mcids: Iterable[int]) ->
return geometry.objects_to_bbox(mcid_objs)


def get_rgb(c: dict) -> str:
def get_rgb(c: T_obj) -> str:
"""Extraire la couleur d'un objet en 3 chiffres hexadécimaux"""
couleur = c.get("non_stroking_color", c.get("stroking_color"))
if couleur is None:
Expand All @@ -146,11 +146,11 @@ def get_rgb(c: dict) -> str:


def get_word_features(
word: dict,
word: T_obj,
page: Page,
chars: dict[tuple[int, int], T_obj],
elmap: dict[int, str],
) -> dict:
) -> T_obj:
# Extract things from first character (we do not use
# extra_attrs because otherwise extract_words will
# insert word breaks)
Expand Down Expand Up @@ -183,7 +183,9 @@ class Converteur:
y_tolerance: int

def __init__(
self, path_or_fp: Union[str, Path, BufferedReader, BytesIO], y_tolerance=2
self,
path_or_fp: Union[str, Path, BufferedReader, BytesIO],
y_tolerance: int = 2,
):
self.pdf = PDF.open(path_or_fp)
self.y_tolerance = y_tolerance
Expand Down
2 changes: 1 addition & 1 deletion alexi/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
return parser


def main(args):
def main(args: argparse.Namespace) -> None:
u = urllib.parse.urlparse(args.url)
LOGGER.info("Downloading %s", args.url)
try:
Expand Down
5 changes: 3 additions & 2 deletions alexi/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from whoosh.index import create_in # type: ignore
from whoosh.support.charset import charset_table_to_dict # type: ignore
from whoosh.support.charset import default_charset
from whoosh.writing import IndexWriter # type: ignore

LOGGER = logging.getLogger("index")
CHARMAP = charset_table_to_dict(default_charset)
ANALYZER = StemmingAnalyzer() | CharsetFilter(CHARMAP)


def add_from_dir(writer, document, docdir):
def add_from_dir(writer: IndexWriter, document: str, docdir: Path) -> None:
LOGGER.info("Indexing %s", docdir)
with open(docdir / "index.json") as infh:
element = json.load(infh)
Expand All @@ -30,7 +31,7 @@ def add_from_dir(writer, document, docdir):
)


def index(indir: Path, outdir: Path):
def index(indir: Path, outdir: Path) -> None:
outdir.mkdir(exist_ok=True)
schema = Schema(
document=ID(stored=True),
Expand Down
2 changes: 1 addition & 1 deletion alexi/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from whoosh.qparser import MultifieldParser, OrGroup # type: ignore


def search(indexdir: Path, terms: List[str]):
def search(indexdir: Path, terms: List[str]) -> None:
ix = open_dir(indexdir)
parser = MultifieldParser(
["titre", "contenu"], ix.schema, group=OrGroup.factory(0.9)
Expand Down
27 changes: 14 additions & 13 deletions alexi/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
from enum import Enum
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Union
from typing import Any, Callable, Iterable, Iterator, Union, Optional

import joblib # type: ignore

from alexi.convert import FIELDNAMES
from alexi.types import T_obj

FEATNAMES = [name for name in FIELDNAMES if name not in ("segment", "sequence")]
DEFAULT_MODEL = Path(__file__).parent / "models" / "crf.joblib.gz"
Expand All @@ -26,7 +27,7 @@ class Bullet(Enum):
BULLET = re.compile(r"^([•-])$") # FIXME: need more bullets


def sign(x: Union[int | float]):
def sign(x: Union[int | float]) -> int:
"""Get the sign of a number (should exist...)"""
if x == 0:
return 0
Expand All @@ -36,11 +37,11 @@ def sign(x: Union[int | float]):


def make_visual_structural_literal() -> FeatureFunc:
prev_word = None
prev_line_height = None
prev_line_start = None
prev_word: Optional[T_obj] = None
prev_line_height = 1.0
prev_line_start = 0.0

def visual_one(idx, word):
def visual_one(idx: int, word: T_obj) -> list[str]:
nonlocal prev_word, prev_line_height, prev_line_start
if idx == 0: # page break
prev_word = None
Expand Down Expand Up @@ -68,10 +69,10 @@ def visual_one(idx, word):
]
newline = False
linedelta = 0.0
dx = 1
dy = 0
dh = 0
prev_height = 1
dx = 1.0
dy = 0.0
dh = 0.0
prev_height = 1.0
if prev_word is not None:
height = float(word["bottom"]) - float(word["top"])
prev_height = float(prev_word["bottom"]) - float(prev_word["top"])
Expand Down Expand Up @@ -118,10 +119,10 @@ def visual_one(idx, word):

def make_visual_literal() -> FeatureFunc:
prev_word = None
prev_line_height = None
prev_line_start = None
prev_line_height = 1.0
prev_line_start = 0.0

def visual_one(idx, word):
def visual_one(idx, word) -> list[str]:
nonlocal prev_word, prev_line_height, prev_line_start
if idx == 0: # page break
prev_word = None
Expand Down
2 changes: 1 addition & 1 deletion alexi/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Bloc:
_bbox: Optional[T_bbox] = None
_page_number: Optional[int] = None

def __hash__(self):
def __hash__(self) -> int:
if self._bbox:
return hash((self.type, self._bbox, self._page_number))
else:
Expand Down
6 changes: 4 additions & 2 deletions scripts/title_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pathlib import Path


def main():
def main() -> None:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("indir", help="Repertoire d'entrée", type=Path)
parser.add_argument("outdir", help="Repertoire de sortie", type=Path)
Expand All @@ -20,6 +20,7 @@ def main():
for p in args.indir.glob("*.csv"):
with open(p, "rt") as infh, open(args.outdir / p.name, "wt") as outfh:
reader = csv.DictReader(infh)
assert reader.fieldnames is not None
fieldnames = list(reader.fieldnames)
fieldnames.insert(0, "seqtag")
writer = csv.DictWriter(outfh, fieldnames=fieldnames)
Expand All @@ -32,7 +33,8 @@ def main():
if last_page is None:
writer.writerows(contents)
last_page = contents
writer.writerows(last_page)
if last_page is not None:
writer.writerows(last_page)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions test/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
DATADIR = Path(__file__).parent / "data"


def test_convert():
def test_convert() -> None:
with open(DATADIR / "pdf_structure.pdf", "rb") as infh:
conv = Converteur(infh)
words = list(conv.extract_words())
Expand All @@ -19,7 +19,7 @@ def test_convert():
assert len(words) == len(ref_words)


def test_extract_tables_and_figures():
def test_extract_tables_and_figures() -> None:
with open(DATADIR / "pdf_figures.pdf", "rb") as infh:
conv = Converteur(infh)
words = list(conv.extract_words())
Expand Down

0 comments on commit efb7ab3

Please sign in to comment.