Skip to content

Commit

Permalink
bugfix! figures missing? (#73)
Browse files Browse the repository at this point in the history
* bugfix! figures were not  plottin properly because of weird interactino between getattr() and spans

* revert a fix so dont break apis; add deprecation warnings

* fix vila predictor to not rely on deprecated methods; fix core recipe to drop word predictor which is bugging out

* move names into separate file

* make default text operation smoother
  • Loading branch information
kyleclo committed Mar 14, 2024
1 parent 6a0a4a2 commit 4cb681e
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 82 deletions.
14 changes: 6 additions & 8 deletions papermage/magelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,20 @@
"""


from .box import Box
from .document import Document, Prediction
from .entity import Entity
from .image import Image
from .indexer import EntityBoxIndexer, EntitySpanIndexer
from .layer import Layer
from .document import (
from .metadata import Metadata
from .names import (
AbstractsFieldName,
AlgorithmsFieldName,
AuthorsFieldName,
BibliographiesFieldName,
BlocksFieldName,
CaptionsFieldName,
Document,
Prediction,
EntitiesFieldName,
EquationsFieldName,
FiguresFieldName,
Expand All @@ -40,10 +42,6 @@
TokensFieldName,
WordsFieldName,
)
from .entity import Entity
from .image import Image
from .indexer import EntityBoxIndexer, EntitySpanIndexer
from .metadata import Metadata
from .span import Span

__all__ = [
Expand Down
53 changes: 20 additions & 33 deletions papermage/magelib/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

import logging
from itertools import chain
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

Expand All @@ -13,40 +14,16 @@
from .image import Image
from .layer import Layer
from .metadata import Metadata
from .names import (
EntitiesFieldName,
ImagesFieldName,
MetadataFieldName,
RelationsFieldName,
SymbolsFieldName,
TokensFieldName,
)
from .span import Span

# document field names
SymbolsFieldName = "symbols"
ImagesFieldName = "images"
MetadataFieldName = "metadata"
EntitiesFieldName = "entities"
RelationsFieldName = "relations"

PagesFieldName = "pages"
TokensFieldName = "tokens"
RowsFieldName = "rows"
BlocksFieldName = "blocks"
WordsFieldName = "words"
SentencesFieldName = "sentences"
ParagraphsFieldName = "paragraphs"

# these come from vila
TitlesFieldName = "titles"
AuthorsFieldName = "authors"
AbstractsFieldName = "abstracts"
KeywordsFieldName = "keywords"
SectionsFieldName = "sections"
ListsFieldName = "lists"
BibliographiesFieldName = "bibliographies"
EquationsFieldName = "equations"
AlgorithmsFieldName = "algorithms"
FiguresFieldName = "figures"
TablesFieldName = "tables"
CaptionsFieldName = "captions"
HeadersFieldName = "headers"
FootersFieldName = "footers"
FootnotesFieldName = "footnotes"


class Prediction(NamedTuple):
name: str
Expand Down Expand Up @@ -198,7 +175,17 @@ def __repr__(self):

def find(self, query: Union[Span, Box], name: str) -> List[Entity]:
"""Finds all entities that intersect with the query"""
return self.get_layer(name=name).find(query=query)
logger = logging.getLogger(__name__)
logger.warning(
"This method is deprecated due to ambiguity and will be removed in a future release."
"Please use Document.intersect_by_span or Document.intersect_by_box instead."
)
if isinstance(query, Span):
return self.intersect_by_span(query=Entity(spans=[query]), name=name)
elif isinstance(query, Box):
return self.intersect_by_box(query=Entity(boxes=[query]), name=name)
else:
raise TypeError(f"Unsupported query type {type(query)}")

def intersect_by_span(self, query: Entity, name: str) -> List[Entity]:
"""Finds all entities that intersect with the query"""
Expand Down
13 changes: 10 additions & 3 deletions papermage/magelib/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
"""

import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from .box import Box
from .image import Image
from .metadata import Metadata
from .names import TokensFieldName
from .span import Span

if TYPE_CHECKING:
from .document import TokensFieldName
from .layer import Layer


Expand Down Expand Up @@ -93,6 +94,12 @@ def id(self, id: int) -> None:

def __getattr__(self, name: str) -> List["Entity"]:
"""This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks."""
# add method deprecation warning
logger = logging.getLogger(__name__)
logger.warning(
"Entity.__getattr__ is deprecated due to ambiguity and will be removed in a future release."
"Please use Entity.intersect_by_span or Entity.intersect_by_box instead."
)
try:
return self.intersect_by_span(name=name)
except ValueError:
Expand Down Expand Up @@ -164,10 +171,10 @@ def text(self) -> str:
return maybe_text
# return derived from symbols
if self.symbols_from_spans:
return " ".join(self.symbols_from_spans)
return " ".join(self.symbols_from_spans).replace("\n", " ")
# return derived from boxes and tokens
if self.symbols_from_boxes:
return " ".join(self.symbols_from_boxes)
return " ".join(self.symbols_from_boxes).replace("\n", " ")
return ""

@text.setter
Expand Down
6 changes: 6 additions & 0 deletions papermage/magelib/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from .box import Box
Expand Down Expand Up @@ -63,6 +64,11 @@ def from_json(cls, layer_json):
return cls(entities=[Entity.from_json(entity_json) for entity_json in layer_json])

def find(self, query: Union[Span, Box]) -> List[Entity]:
logger = logging.getLogger(__name__)
logger.warning(
"This method is deprecated due to ambiguity and will be removed in a future release."
"Please use Layer.intersect_by_span or Layer.intersect_by_box instead."
)
if isinstance(query, Span):
return self.intersect_by_span(query=Entity(spans=[query]))
elif isinstance(query, Box):
Expand Down
37 changes: 37 additions & 0 deletions papermage/magelib/names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
@kylel
"""

# document field names
SymbolsFieldName = "symbols"
ImagesFieldName = "images"
MetadataFieldName = "metadata"
EntitiesFieldName = "entities"
RelationsFieldName = "relations"

PagesFieldName = "pages"
TokensFieldName = "tokens"
RowsFieldName = "rows"
BlocksFieldName = "blocks"
WordsFieldName = "words"
SentencesFieldName = "sentences"
ParagraphsFieldName = "paragraphs"

# these come from vila
TitlesFieldName = "titles"
AuthorsFieldName = "authors"
AbstractsFieldName = "abstracts"
KeywordsFieldName = "keywords"
SectionsFieldName = "sections"
ListsFieldName = "lists"
BibliographiesFieldName = "bibliographies"
EquationsFieldName = "equations"
AlgorithmsFieldName = "algorithms"
FiguresFieldName = "figures"
TablesFieldName = "tables"
CaptionsFieldName = "captions"
HeadersFieldName = "headers"
FootersFieldName = "footers"
FootnotesFieldName = "footnotes"
57 changes: 35 additions & 22 deletions papermage/predictors/vila_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@shannons, @kylel
"""

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand All @@ -18,7 +19,7 @@

import torch
from tqdm import tqdm
from vila.predictors import LayoutIndicatorPDFPredictor, SimplePDFPredictor
from vila.predictors import LayoutIndicatorPDFPredictor

from papermage.magelib import (
BlocksFieldName,
Expand Down Expand Up @@ -102,16 +103,23 @@ def shift_index_sequence_to_zero_start(sequence):

# util
def get_visual_group_id(token: Entity, field_name: str, defaults=-1) -> int:
if not hasattr(token, field_name):
field_value = token.intersect_by_span(name=field_name)
if not field_value:
return defaults
field_value = getattr(token, field_name)
if len(field_value) == 0 or field_value[0].id is None:
return defaults
return field_value[0].id

# if not hasattr(token, field_name):
# return defaults
# field_value = getattr(token, field_name)
# if len(field_value) == 0 or field_value[0].id is None:
# return defaults
# return field_value[0].id


# util
def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_height: int) -> Dict[str, List]:
def convert_document_page_to_pdf_dict(page: Entity, page_width: int, page_height: int) -> Dict[str, List]:
"""Convert a document to a dictionary of the form:
{
'words': ['word1', 'word2', ...],
Expand Down Expand Up @@ -146,7 +154,7 @@ def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_heigh
get_visual_group_id(token, RowsFieldName, -1), # line_ids
get_visual_group_id(token, BlocksFieldName, -1), # block_ids
)
for token in doc.tokens
for token in page.intersect_by_span(name=TokensFieldName)
]

words, bbox, line_ids, block_ids = (list(l) for l in zip(*token_data))
Expand Down Expand Up @@ -227,38 +235,43 @@ def from_pretrained(
def _predict(self, doc: Document, subpage_per_run: Optional[int] = None) -> List[Entity]:
page_prediction_results = []
for page_id, page in enumerate(doc.pages):
if page.tokens:
page_width, page_height = doc.images[page_id].pilimage.size

pdf_dict = self.preprocess(page, page_width=page_width, page_height=page_height)
# skip pages without tokens
tokens_on_page = page.intersect_by_span(name=TokensFieldName)
if not tokens_on_page:
continue

page_width, page_height = doc.images[page_id].pilimage.size

pdf_dict = self.preprocess(page=page, page_width=page_width, page_height=page_height)

model_predictions = self.predictor.predict(
page_data=pdf_dict,
page_size=(page_width, page_height),
batch_size=subpage_per_run or self.subpage_per_run,
return_type="list",
)
model_predictions = self.predictor.predict(
page_data=pdf_dict,
page_size=(page_width, page_height),
batch_size=subpage_per_run or self.subpage_per_run,
return_type="list",
)

assert len(model_predictions) == len(
page.tokens
), f"Model predictions and tokens are not the same length ({len(model_predictions)} != {len(page.tokens)}) for page {page_id}"
assert len(model_predictions) == len(
tokens_on_page
), f"Model predictions and tokens are not the same length ({len(model_predictions)} != {len(tokens_on_page)}) for page {page_id}"

page_prediction_results.extend(self.postprocess(page, model_predictions))
page_prediction_results.extend(self.postprocess(page=page, model_predictions=model_predictions))

return page_prediction_results

def preprocess(self, page: Document, page_width: float, page_height: float) -> Dict:
def preprocess(self, page: Entity, page_width: float, page_height: float) -> Dict:
# In the latest vila implementations (after 0.4.0), the predictor will
# handle all other preprocessing steps given the pdf_dict input format.

return convert_document_page_to_pdf_dict(page, page_width=page_width, page_height=page_height)
return convert_document_page_to_pdf_dict(page=page, page_width=page_width, page_height=page_height)

def postprocess(self, doc: Document, model_predictions) -> List[Entity]:
def postprocess(self, page: Entity, model_predictions) -> List[Entity]:
token_prediction_spans = convert_sequence_tagging_to_spans(model_predictions)

prediction_spans = []
for token_start, token_end, label in token_prediction_spans:
cur_spans = doc.tokens[token_start:token_end]
cur_spans = page.intersect_by_span(name=TokensFieldName)[token_start:token_end]

start = min([ele.start for ele in cur_spans])
end = max([ele.end for ele in cur_spans])
Expand Down
Loading

0 comments on commit 4cb681e

Please sign in to comment.