Skip to content

Commit

Permalink
fix vila predictor to not rely on deprecated methods; fix core recipe to drop word predictor which is bugging out
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleclo committed Mar 14, 2024
1 parent bd4c9f1 commit 2385c23
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 38 deletions.
2 changes: 1 addition & 1 deletion papermage/magelib/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __getattr__(self, name: str) -> List["Entity"]:
# add method deprecation warning
logger = logging.getLogger(__name__)
logger.warning(
"Entity.layer is deprecated due to ambiguity and will be removed in a future release."
"Entity.__getattr__ is deprecated due to ambiguity and will be removed in a future release."
"Please use Entity.intersect_by_span or Entity.intersect_by_box instead."
)
try:
Expand Down
57 changes: 35 additions & 22 deletions papermage/predictors/vila_predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@shannons, @kylel
"""

import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand All @@ -18,7 +19,7 @@

import torch
from tqdm import tqdm
from vila.predictors import LayoutIndicatorPDFPredictor, SimplePDFPredictor
from vila.predictors import LayoutIndicatorPDFPredictor

from papermage.magelib import (
BlocksFieldName,
Expand Down Expand Up @@ -102,16 +103,23 @@ def shift_index_sequence_to_zero_start(sequence):

# util
def get_visual_group_id(token: "Entity", field_name: str, defaults: int = -1) -> int:
    """Return the id of the first entity in layer `field_name` that intersects `token`.

    Looks the token up via `Entity.intersect_by_span` (the supported API)
    rather than attribute access, which goes through the deprecated
    `Entity.__getattr__` path.

    Args:
        token: the token entity to resolve against a visual-group layer.
        field_name: name of the layer to intersect (e.g. rows or blocks).
        defaults: value returned when no intersecting entity with an id exists.

    Returns:
        The id of the first intersecting entity, or `defaults` when the token
        intersects nothing in that layer or the first hit has no id.
    """
    field_value = token.intersect_by_span(name=field_name)
    if not field_value:
        return defaults
    if field_value[0].id is None:
        return defaults
    return field_value[0].id


# util
def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_height: int) -> Dict[str, List]:
def convert_document_page_to_pdf_dict(page: Entity, page_width: int, page_height: int) -> Dict[str, List]:
"""Convert a document to a dictionary of the form:
{
'words': ['word1', 'word2', ...],
Expand Down Expand Up @@ -146,7 +154,7 @@ def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_heigh
get_visual_group_id(token, RowsFieldName, -1), # line_ids
get_visual_group_id(token, BlocksFieldName, -1), # block_ids
)
for token in doc.tokens
for token in page.intersect_by_span(name=TokensFieldName)
]

words, bbox, line_ids, block_ids = (list(l) for l in zip(*token_data))
Expand Down Expand Up @@ -227,38 +235,43 @@ def from_pretrained(
def _predict(self, doc: "Document", subpage_per_run: Optional[int] = None) -> List["Entity"]:
    """Run the VILA token classifier over every page of `doc`.

    Pages without tokens are skipped entirely. Tokens are fetched via
    `intersect_by_span` (the supported API) instead of the deprecated
    `page.tokens` attribute access.

    Args:
        doc: document with `pages`, `images`, and an annotated tokens layer.
        subpage_per_run: optional batch size override; falls back to
            `self.subpage_per_run` when not given.

    Returns:
        Predicted entities accumulated across all pages, in page order.
    """
    page_prediction_results = []
    for page_id, page in enumerate(doc.pages):
        # skip pages without tokens
        tokens_on_page = page.intersect_by_span(name=TokensFieldName)
        if not tokens_on_page:
            continue

        page_width, page_height = doc.images[page_id].pilimage.size

        pdf_dict = self.preprocess(page=page, page_width=page_width, page_height=page_height)

        model_predictions = self.predictor.predict(
            page_data=pdf_dict,
            page_size=(page_width, page_height),
            batch_size=subpage_per_run or self.subpage_per_run,
            return_type="list",
        )

        # the model must emit exactly one prediction per token on the page;
        # anything else means preprocessing and prediction drifted apart
        assert len(model_predictions) == len(
            tokens_on_page
        ), f"Model predictions and tokens are not the same length ({len(model_predictions)} != {len(tokens_on_page)}) for page {page_id}"

        page_prediction_results.extend(self.postprocess(page=page, model_predictions=model_predictions))

    return page_prediction_results

def preprocess(self, page: "Entity", page_width: float, page_height: float) -> Dict:
    """Convert one page entity into the pdf_dict format consumed by vila.

    In the latest vila implementations (after 0.4.0), the predictor will
    handle all other preprocessing steps given the pdf_dict input format.

    Args:
        page: the page entity whose tokens will be serialized.
        page_width: page width in pixels (from the rasterized page image).
        page_height: page height in pixels (from the rasterized page image).

    Returns:
        A dict of parallel lists (words, bboxes, ids, ...) for vila.
    """
    return convert_document_page_to_pdf_dict(page=page, page_width=page_width, page_height=page_height)

def postprocess(self, doc: Document, model_predictions) -> List[Entity]:
def postprocess(self, page: Entity, model_predictions) -> List[Entity]:
token_prediction_spans = convert_sequence_tagging_to_spans(model_predictions)

prediction_spans = []
for token_start, token_end, label in token_prediction_spans:
cur_spans = doc.tokens[token_start:token_end]
cur_spans = page.intersect_by_span(name=TokensFieldName)[token_start:token_end]

start = min([ele.start for ele in cur_spans])
end = max([ele.end for ele in cur_spans])
Expand Down
38 changes: 23 additions & 15 deletions papermage/recipes/core_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@
"Bibliography": BibliographiesFieldName,
"Equation": EquationsFieldName,
"Algorithm": AlgorithmsFieldName,
"Figure": FiguresFieldName,
"Table": TablesFieldName,
# "Figure": FiguresFieldName,
# "Table": TablesFieldName,
"Caption": CaptionsFieldName,
"Header": HeadersFieldName,
"Footer": FootersFieldName,
Expand All @@ -89,17 +89,12 @@ def __init__(
self.parser = PDFPlumberParser()
self.rasterizer = PDF2ImageRasterizer()

with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.word_predictor = SVMWordPredictor.from_path(svm_word_predictor_path)
# with warnings.catch_warnings():
# warnings.simplefilter("ignore")
# self.word_predictor = SVMWordPredictor.from_path(svm_word_predictor_path)

self.publaynet_block_predictor = LPEffDetPubLayNetBlockPredictor.from_pretrained()
self.ivila_predictor = IVILATokenClassificationPredictor.from_pretrained(ivila_predictor_path)
self.bio_roberta_predictor = HFBIOTaggerPredictor.from_pretrained(
bio_roberta_predictor_path,
entity_name="tokens",
context_name="pages",
)
self.sent_predictor = PysbdSentencePredictor()
self.logger.info("Finished instantiating recipe")

Expand All @@ -114,9 +109,9 @@ def from_pdf(self, pdf: Path) -> Document:
return self.from_doc(doc=doc)

def from_doc(self, doc: Document) -> Document:
self.logger.info("Predicting words...")
words = self.word_predictor.predict(doc=doc)
doc.annotate_layer(name=WordsFieldName, entities=words)
# self.logger.info("Predicting words...")
# words = self.word_predictor.predict(doc=doc)
# doc.annotate_layer(name=WordsFieldName, entities=words)

self.logger.info("Predicting sentences...")
sentences = self.sent_predictor.predict(doc=doc)
Expand All @@ -128,7 +123,20 @@ def from_doc(self, doc: Document) -> Document:
blocks = self.publaynet_block_predictor.predict(doc=doc)
doc.annotate_layer(name=BlocksFieldName, entities=blocks)

self.logger.info("Predicting vila...")
self.logger.info("Predicting figures and tables...")
figures = []
tables = []
for block in blocks:
if block.metadata.type == "Figure":
figure = Entity(boxes=block.boxes)
figures.append(figure)
elif block.metadata.type == "Table":
table = Entity(boxes=block.boxes)
tables.append(table)
doc.annotate_layer(name=FiguresFieldName, entities=figures)
doc.annotate_layer(name=TablesFieldName, entities=tables)

# self.logger.info("Predicting vila...")
vila_entities = self.ivila_predictor.predict(doc=doc)
doc.annotate_layer(name="vila_entities", entities=vila_entities)

Expand All @@ -138,7 +146,7 @@ def from_doc(self, doc: Document) -> Document:
[b for t in doc.intersect_by_span(entity, name=TokensFieldName) for b in t.boxes]
)
]
entity.text = make_text(entity=entity, document=doc)
# entity.text = make_text(entity=entity, document=doc)
preds = group_by(entities=vila_entities, metadata_field="label", metadata_values_map=VILA_LABELS_MAP)
doc.annotate(*preds)
return doc
Expand Down

0 comments on commit 2385c23

Please sign in to comment.