Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ff #40

Merged
merged 7 commits into from
Aug 7, 2023
Merged

ff #40

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions examples/quick_start_demo.ipynb

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions papermage/magelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
RowsFieldName,
BlocksFieldName,
ImagesFieldName,
WordsFieldName
WordsFieldName,
SentencesFieldName,
ParagraphsFieldName
)

__all__ = [
Expand All @@ -47,5 +49,7 @@
"TokensFieldName",
"RowsFieldName",
"BlocksFieldName",
"WordsFieldName"
"WordsFieldName",
"SentencesFieldName",
"ParagraphsFieldName",
]
32 changes: 20 additions & 12 deletions papermage/magelib/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
RowsFieldName = "rows"
BlocksFieldName = "blocks"
WordsFieldName = "words"
SentencesFieldName = "sentences"
ParagraphsFieldName = "paragraphs"


class Document:
Expand All @@ -44,43 +46,49 @@ def __init__(self, symbols: str, metadata: Optional[Metadata] = None):
def fields(self) -> List[str]:
return list(self.__entity_span_indexers.keys()) + self.SPECIAL_FIELDS

def find_by_span(self, query: Union[Entity, Span], field_name: str) -> List[Entity]:
if isinstance(query, Entity):
return self.__entity_span_indexers[field_name].find(query=query)
elif isinstance(query, Span):
def find(self, query: Union[Span, Box], field_name: str) -> List[Entity]:
if isinstance(query, Span):
return self.__entity_span_indexers[field_name].find(query=Entity(spans=[query]))
else:
raise TypeError(f"Unsupported query type {type(query)}")

def find_by_box(self, query: Union[Entity, Box], field_name: str) -> List[Entity]:
if isinstance(query, Entity):
return self.__entity_box_indexers[field_name].find(query=query)
elif isinstance(query, Box):
return self.__entity_box_indexers[field_name].find(query=Entity(boxes=[query]))
else:
raise TypeError(f"Unsupported query type {type(query)}")

# Return the entities stored under `field_name` that the span indexer for
# that field finds for `query`. The query is passed to the indexer's `find`
# unchanged, so it is expected to be an Entity, as annotated.
def find_by_span(self, query: Entity, field_name: str) -> List[Entity]:
# TODO: will rename this to `intersect_by_span`
return self.__entity_span_indexers[field_name].find(query=query)

# Return the entities stored under `field_name` that the box indexer for
# that field finds for `query`. The query is passed to the indexer's `find`
# unchanged, so it is expected to be an Entity, as annotated.
def find_by_box(self, query: Entity, field_name: str) -> List[Entity]:
# TODO: will rename this to `intersect_by_box`
return self.__entity_box_indexers[field_name].find(query=query)

def check_field_name_availability(self, field_name: str) -> None:
if field_name in self.SPECIAL_FIELDS:
raise AssertionError(f"{field_name} not allowed Document.SPECIAL_FIELDS.")
if field_name in self.__entity_span_indexers.keys():
raise AssertionError(f"{field_name} already exists. Try `is_overwrite=True`")
raise AssertionError(f'{field_name} already exists. Try `doc.remove_entity("{field_name}")` first.')
if field_name in dir(self):
raise AssertionError(f"{field_name} clashes with Document class properties.")

# Return whatever is stored on this Document under the attribute named
# `field_name`. `getattr` is called without a default, so a name that was
# never set is not silently mapped to an empty list.
def get_entity(self, field_name: str) -> List[Entity]:
return getattr(self, field_name)

# Attach `entities` to the document under `field_name`.
# Delegates to `annotate_entity` when every item is an Entity; an empty list
# also takes that path because `all()` over no items is True. Any non-Entity
# item makes the call raise NotImplementedError listing the item types.
def annotate(self, field_name: str, entities: List[Entity]) -> None:
if all(isinstance(e, Entity) for e in entities):
self.annotate_entity(field_name=field_name, entities=entities)
else:
raise NotImplementedError(f"entity list contains non-entities: {[type(e) for e in entities]}")

def annotate_entity(self, field_name: str, entities: List[Entity]) -> None:
self.check_field_name_availability(field_name=field_name)

for i, entity in enumerate(entities):
entity.doc = self
entity.id = i

setattr(self, field_name, entities)
self.__entity_span_indexers[field_name] = EntitySpanIndexer(entities=entities)
self.__entity_box_indexers[field_name] = EntityBoxIndexer(entities=entities)
setattr(self, field_name, entities)

def remove_entity(self, field_name: str):
for entity in getattr(self, field_name):
Expand Down
4 changes: 2 additions & 2 deletions papermage/magelib/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _ensure_disjoint(self) -> None:

def find(self, query: Entity) -> List[Entity]:
if not isinstance(query, Entity):
raise ValueError(f"EntityIndexer only works with `query` that is Entity type")
raise TypeError(f"EntityIndexer only works with `query` that is Entity type")

if not query.spans:
return []
Expand Down Expand Up @@ -159,7 +159,7 @@ def _ensure_disjoint(self) -> None:

def find(self, query: Entity) -> List[Entity]:
if not isinstance(query, Entity):
raise ValueError(f"EntityBoxIndexer only works with `query` that is Entity type")
raise TypeError(f"EntityBoxIndexer only works with `query` that is Entity type")

if not query.boxes:
return []
Expand Down
38 changes: 34 additions & 4 deletions papermage/recipes/core_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,27 @@
"""

import logging
from pathlib import Path
from typing import Union

logger = logging.getLogger(__name__)

from papermage.magelib import Document, Entity
from papermage.magelib import (
BlocksFieldName,
Document,
EntitiesFieldName,
Entity,
ImagesFieldName,
MetadataFieldName,
PagesFieldName,
ParagraphsFieldName,
RelationsFieldName,
RowsFieldName,
SentencesFieldName,
SymbolsFieldName,
TokensFieldName,
WordsFieldName,
)
from papermage.parsers.pdfplumber_parser import PDFPlumberParser
from papermage.predictors import (
HFBIOTaggerPredictor,
Expand Down Expand Up @@ -46,21 +63,34 @@ def __init__(
self.sent_predictor = PysbdSentencePredictor()
logger.info("Finished instantiating recipe")

def run(self, pdf: Union[str, Path, "Document"]) -> "Document":
    """Run the recipe on a PDF and return the resulting Document.

    Args:
        pdf: Location of a PDF file, as a ``str`` or ``Path``. A ``Document``
            is accepted by the signature but is not supported yet.

    Returns:
        The value returned by ``from_path`` for the given PDF path.

    Raises:
        AssertionError: If the path does not exist, or if ``pdf`` is not a
            ``str``, ``Path`` or ``Document``.
        NotImplementedError: If ``pdf`` is a ``Document``.
    """
    if isinstance(pdf, str):
        pdf = Path(pdf)

    if isinstance(pdf, Path):
        # Only paths have ``exists()``; checking it before dispatching on the
        # type made non-path input fail with AttributeError instead of the
        # intended error below.
        if not pdf.exists():
            # Raised explicitly (not via ``assert``) so the check survives -O.
            raise AssertionError(f"File {pdf} does not exist.")
        # Return the parsed document; previously the result was discarded and
        # the method implicitly returned None.
        return self.from_path(str(pdf))

    if isinstance(pdf, Document):
        raise NotImplementedError("Document input not yet supported.")

    raise AssertionError(
        f"Unsupported type {type(pdf)} for pdf; should be a Document or a path to a PDF file."
    )

def from_path(self, pdfpath: str) -> Document:
logger.info("Parsing document...")
doc = self.parser.parse(input_pdf_path=pdfpath)

logger.info("Rasterizing document...")
images = self.rasterizer.rasterize(input_pdf_path=pdfpath, dpi=72)
doc.annotate_images(images=list(images))
self.rasterizer.attach_images(images=images, doc=doc)

logger.info("Predicting words...")
words = self.word_predictor.predict(doc=doc)
doc.annotate_entity(field_name="words", entities=words)
doc.annotate_entity(field_name=WordsFieldName, entities=words)

logger.info("Predicting sentences...")
sentences = self.sent_predictor.predict(doc=doc)
doc.annotate_entity(field_name="sentences", entities=sentences)
doc.annotate_entity(field_name=SentencesFieldName, entities=sentences)

logger.info("Predicting blocks...")
layout = self.effdet_publaynet_predictor.predict(doc=doc)
Expand All @@ -75,7 +105,7 @@ def from_path(self, pdfpath: str) -> Document:
# blocks are used by IVILA, so we need to annotate them as well
# copy the entities because they already have docs attached
blocks = [Entity.from_json(ent.to_json()) for ent in layout + equations]
doc.annotate_entity(field_name="blocks", entities=blocks)
doc.annotate_entity(field_name=BlocksFieldName, entities=blocks)

logger.info("Predicting vila...")
vila_entities = self.ivila_predictor.predict(doc=doc)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'papermage'
version = '0.9.0'
version = '0.10.0'
description = 'Papermage. Casting magic over scientific PDFs.'
license = {text = 'Apache-2.0'}
readme = 'README.md'
Expand Down
6 changes: 4 additions & 2 deletions tests/test_magelib/test_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,15 +203,17 @@ def test_query(self):
# test query by span
self.assertListEqual(
doc.find_by_span(query=doc.chunks[0], field_name="tokens"),
doc.find_by_span(query=doc.chunks[0].spans[0], field_name="tokens"),
doc.find(query=doc.chunks[0].spans[0], field_name="tokens"),
)
# test query by box
self.assertListEqual(
doc.find_by_box(query=doc.chunks[0], field_name="tokens"),
doc.find_by_box(query=doc.chunks[0].boxes[0], field_name="tokens"),
doc.find(query=doc.chunks[0].boxes[0], field_name="tokens"),
)
# calling wrong method w input type should fail
with self.assertRaises(TypeError):
doc.find_by_box(query=doc.chunks[0].spans[0], field_name="tokens")
with self.assertRaises(TypeError):
doc.find_by_span(query=doc.chunks[0].boxes[0], field_name="tokens")
with self.assertRaises(TypeError):
doc.find(query=doc.chunks[0], field_name="tokens")
Loading