From 4cb681e436136668f556d88433980e58b42e1752 Mon Sep 17 00:00:00 2001 From: Kyle Lo Date: Thu, 14 Mar 2024 13:15:11 -0700 Subject: [PATCH] bugfix! figures missing? (#73) * bugfix! figures were not plottin properly because of weird interactino between getattr() and spans * revert a fix so dont break apis; add deprecation warnings * fix vila predictor to not rely on deprecated methods; fix core recipe to drop word predictor which is bugging out * move names into separate file * make default text operation smoother --- papermage/magelib/__init__.py | 14 +++--- papermage/magelib/document.py | 53 +++++++++-------------- papermage/magelib/entity.py | 13 ++++-- papermage/magelib/layer.py | 6 +++ papermage/magelib/names.py | 37 ++++++++++++++++ papermage/predictors/vila_predictors.py | 57 +++++++++++++++---------- papermage/recipes/core_recipe.py | 38 ++++++++++------- pyproject.toml | 2 +- 8 files changed, 138 insertions(+), 82 deletions(-) create mode 100644 papermage/magelib/names.py diff --git a/papermage/magelib/__init__.py b/papermage/magelib/__init__.py index dd1ab76..290f64d 100644 --- a/papermage/magelib/__init__.py +++ b/papermage/magelib/__init__.py @@ -6,18 +6,20 @@ """ - from .box import Box +from .document import Document, Prediction +from .entity import Entity +from .image import Image +from .indexer import EntityBoxIndexer, EntitySpanIndexer from .layer import Layer -from .document import ( +from .metadata import Metadata +from .names import ( AbstractsFieldName, AlgorithmsFieldName, AuthorsFieldName, BibliographiesFieldName, BlocksFieldName, CaptionsFieldName, - Document, - Prediction, EntitiesFieldName, EquationsFieldName, FiguresFieldName, @@ -40,10 +42,6 @@ TokensFieldName, WordsFieldName, ) -from .entity import Entity -from .image import Image -from .indexer import EntityBoxIndexer, EntitySpanIndexer -from .metadata import Metadata from .span import Span __all__ = [ diff --git a/papermage/magelib/document.py b/papermage/magelib/document.py index 54b4f85..61dce37 100644 --- a/papermage/magelib/document.py +++ b/papermage/magelib/document.py @@ -5,6 +5,7 @@ """ +import logging from itertools import chain from typing import Dict, List, NamedTuple, Optional, Tuple, Union @@ -13,40 +14,16 @@ from .image import Image from .layer import Layer from .metadata import Metadata +from .names import ( + EntitiesFieldName, + ImagesFieldName, + MetadataFieldName, + RelationsFieldName, + SymbolsFieldName, + TokensFieldName, +) from .span import Span -# document field names -SymbolsFieldName = "symbols" -ImagesFieldName = "images" -MetadataFieldName = "metadata" -EntitiesFieldName = "entities" -RelationsFieldName = "relations" - -PagesFieldName = "pages" -TokensFieldName = "tokens" -RowsFieldName = "rows" -BlocksFieldName = "blocks" -WordsFieldName = "words" -SentencesFieldName = "sentences" -ParagraphsFieldName = "paragraphs" - -# these come from vila -TitlesFieldName = "titles" -AuthorsFieldName = "authors" -AbstractsFieldName = "abstracts" -KeywordsFieldName = "keywords" -SectionsFieldName = "sections" -ListsFieldName = "lists" -BibliographiesFieldName = "bibliographies" -EquationsFieldName = "equations" -AlgorithmsFieldName = "algorithms" -FiguresFieldName = "figures" -TablesFieldName = "tables" -CaptionsFieldName = "captions" -HeadersFieldName = "headers" -FootersFieldName = "footers" -FootnotesFieldName = "footnotes" - class Prediction(NamedTuple): name: str @@ -198,7 +175,17 @@ def __repr__(self): def find(self, query: Union[Span, Box], name: str) -> List[Entity]: """Finds all entities that intersect with the query""" - return self.get_layer(name=name).find(query=query) + logger = logging.getLogger(__name__) + logger.warning( + "This method is deprecated due to ambiguity and will be removed in a future release." + "Please use Document.intersect_by_span or Document.intersect_by_box instead." + ) + if isinstance(query, Span): + return self.intersect_by_span(query=Entity(spans=[query]), name=name) + elif isinstance(query, Box): + return self.intersect_by_box(query=Entity(boxes=[query]), name=name) + else: + raise TypeError(f"Unsupported query type {type(query)}") def intersect_by_span(self, query: Entity, name: str) -> List[Entity]: """Finds all entities that intersect with the query""" diff --git a/papermage/magelib/entity.py b/papermage/magelib/entity.py index 1e374f8..8447019 100644 --- a/papermage/magelib/entity.py +++ b/papermage/magelib/entity.py @@ -4,15 +4,16 @@ """ +import logging from typing import TYPE_CHECKING, Dict, List, Optional, Union from .box import Box from .image import Image from .metadata import Metadata +from .names import TokensFieldName from .span import Span if TYPE_CHECKING: - from .document import TokensFieldName from .layer import Layer @@ -93,6 +94,12 @@ def id(self, id: int) -> None: def __getattr__(self, name: str) -> List["Entity"]: """This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks.""" + # add method deprecation warning + logger = logging.getLogger(__name__) + logger.warning( + "Entity.__getattr__ is deprecated due to ambiguity and will be removed in a future release." + "Please use Entity.intersect_by_span or Entity.intersect_by_box instead." + ) try: return self.intersect_by_span(name=name) except ValueError: @@ -164,10 +171,10 @@ def text(self) -> str: return maybe_text # return derived from symbols if self.symbols_from_spans: - return " ".join(self.symbols_from_spans) + return " ".join(self.symbols_from_spans).replace("\n", " ") # return derived from boxes and tokens if self.symbols_from_boxes: - return " ".join(self.symbols_from_boxes) + return " ".join(self.symbols_from_boxes).replace("\n", " ") return "" @text.setter diff --git a/papermage/magelib/layer.py b/papermage/magelib/layer.py index 2d7ba4f..846369a 100644 --- a/papermage/magelib/layer.py +++ b/papermage/magelib/layer.py @@ -6,6 +6,7 @@ """ +import logging from typing import TYPE_CHECKING, Dict, List, Optional, Union from .box import Box @@ -63,6 +64,11 @@ def from_json(cls, layer_json): return cls(entities=[Entity.from_json(entity_json) for entity_json in layer_json]) def find(self, query: Union[Span, Box]) -> List[Entity]: + logger = logging.getLogger(__name__) + logger.warning( + "This method is deprecated due to ambiguity and will be removed in a future release." + "Please use Layer.intersect_by_span or Layer.intersect_by_box instead." + ) if isinstance(query, Span): return self.intersect_by_span(query=Entity(spans=[query])) elif isinstance(query, Box): diff --git a/papermage/magelib/names.py b/papermage/magelib/names.py new file mode 100644 index 0000000..ae8046f --- /dev/null +++ b/papermage/magelib/names.py @@ -0,0 +1,37 @@ +""" + +@kylel + +""" + +# document field names +SymbolsFieldName = "symbols" +ImagesFieldName = "images" +MetadataFieldName = "metadata" +EntitiesFieldName = "entities" +RelationsFieldName = "relations" + +PagesFieldName = "pages" +TokensFieldName = "tokens" +RowsFieldName = "rows" +BlocksFieldName = "blocks" +WordsFieldName = "words" +SentencesFieldName = "sentences" +ParagraphsFieldName = "paragraphs" + +# these come from vila +TitlesFieldName = "titles" +AuthorsFieldName = "authors" +AbstractsFieldName = "abstracts" +KeywordsFieldName = "keywords" +SectionsFieldName = "sections" +ListsFieldName = "lists" +BibliographiesFieldName = "bibliographies" +EquationsFieldName = "equations" +AlgorithmsFieldName = "algorithms" +FiguresFieldName = "figures" +TablesFieldName = "tables" +CaptionsFieldName = "captions" +HeadersFieldName = "headers" +FootersFieldName = "footers" +FootnotesFieldName = "footnotes" diff --git a/papermage/predictors/vila_predictors.py b/papermage/predictors/vila_predictors.py index 76eaeaf..f7dcc39 100644 --- a/papermage/predictors/vila_predictors.py +++ b/papermage/predictors/vila_predictors.py @@ -7,6 +7,7 @@ @shannons, @kylel """ + import os os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -18,7 +19,7 @@ import torch from tqdm import tqdm -from vila.predictors import LayoutIndicatorPDFPredictor, SimplePDFPredictor +from vila.predictors import LayoutIndicatorPDFPredictor from papermage.magelib import ( BlocksFieldName, @@ -102,16 +103,23 @@ def shift_index_sequence_to_zero_start(sequence): # util def get_visual_group_id(token: Entity, field_name: str, defaults=-1) -> int: - if not hasattr(token, field_name): + field_value = token.intersect_by_span(name=field_name) + if not field_value: return defaults - field_value = getattr(token, field_name) if len(field_value) == 0 or field_value[0].id is None: return defaults return field_value[0].id + # if not hasattr(token, field_name): + # return defaults + # field_value = getattr(token, field_name) + # if len(field_value) == 0 or field_value[0].id is None: + # return defaults + # return field_value[0].id + # util -def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_height: int) -> Dict[str, List]: +def convert_document_page_to_pdf_dict(page: Entity, page_width: int, page_height: int) -> Dict[str, List]: """Convert a document to a dictionary of the form: { 'words': ['word1', 'word2', ...], @@ -146,7 +154,7 @@ def convert_document_page_to_pdf_dict(doc: Document, page_width: int, page_heigh get_visual_group_id(token, RowsFieldName, -1), # line_ids get_visual_group_id(token, BlocksFieldName, -1), # block_ids ) - for token in doc.tokens + for token in page.intersect_by_span(name=TokensFieldName) ] words, bbox, line_ids, block_ids = (list(l) for l in zip(*token_data)) @@ -227,38 +235,43 @@ def from_pretrained( def _predict(self, doc: Document, subpage_per_run: Optional[int] = None) -> List[Entity]: page_prediction_results = [] for page_id, page in enumerate(doc.pages): - if page.tokens: - page_width, page_height = doc.images[page_id].pilimage.size - pdf_dict = self.preprocess(page, page_width=page_width, page_height=page_height) + # skip pages without tokens + tokens_on_page = page.intersect_by_span(name=TokensFieldName) + if not tokens_on_page: + continue + + page_width, page_height = doc.images[page_id].pilimage.size + + pdf_dict = self.preprocess(page=page, page_width=page_width, page_height=page_height) - model_predictions = self.predictor.predict( - page_data=pdf_dict, - page_size=(page_width, page_height), - batch_size=subpage_per_run or self.subpage_per_run, - return_type="list", - ) + model_predictions = self.predictor.predict( + page_data=pdf_dict, + page_size=(page_width, page_height), + batch_size=subpage_per_run or self.subpage_per_run, + return_type="list", + ) - assert len(model_predictions) == len( - page.tokens - ), f"Model predictions and tokens are not the same length ({len(model_predictions)} != {len(page.tokens)}) for page {page_id}" + assert len(model_predictions) == len( + tokens_on_page + ), f"Model predictions and tokens are not the same length ({len(model_predictions)} != {len(tokens_on_page)}) for page {page_id}" - page_prediction_results.extend(self.postprocess(page, model_predictions)) + page_prediction_results.extend(self.postprocess(page=page, model_predictions=model_predictions)) return page_prediction_results - def preprocess(self, page: Document, page_width: float, page_height: float) -> Dict: + def preprocess(self, page: Entity, page_width: float, page_height: float) -> Dict: # In the latest vila implementations (after 0.4.0), the predictor will # handle all other preprocessing steps given the pdf_dict input format. - return convert_document_page_to_pdf_dict(page, page_width=page_width, page_height=page_height) + return convert_document_page_to_pdf_dict(page=page, page_width=page_width, page_height=page_height) - def postprocess(self, doc: Document, model_predictions) -> List[Entity]: + def postprocess(self, page: Entity, model_predictions) -> List[Entity]: token_prediction_spans = convert_sequence_tagging_to_spans(model_predictions) prediction_spans = [] for token_start, token_end, label in token_prediction_spans: - cur_spans = doc.tokens[token_start:token_end] + cur_spans = page.intersect_by_span(name=TokensFieldName)[token_start:token_end] start = min([ele.start for ele in cur_spans]) end = max([ele.end for ele in cur_spans]) diff --git a/papermage/recipes/core_recipe.py b/papermage/recipes/core_recipe.py index 79a5958..f0f4ffc 100644 --- a/papermage/recipes/core_recipe.py +++ b/papermage/recipes/core_recipe.py @@ -65,8 +65,8 @@ "Bibliography": BibliographiesFieldName, "Equation": EquationsFieldName, "Algorithm": AlgorithmsFieldName, - "Figure": FiguresFieldName, - "Table": TablesFieldName, + # "Figure": FiguresFieldName, + # "Table": TablesFieldName, "Caption": CaptionsFieldName, "Header": HeadersFieldName, "Footer": FootersFieldName, @@ -89,17 +89,12 @@ def __init__( self.parser = PDFPlumberParser() self.rasterizer = PDF2ImageRasterizer() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - self.word_predictor = SVMWordPredictor.from_path(svm_word_predictor_path) + # with warnings.catch_warnings(): + # warnings.simplefilter("ignore") + # self.word_predictor = SVMWordPredictor.from_path(svm_word_predictor_path) self.publaynet_block_predictor = LPEffDetPubLayNetBlockPredictor.from_pretrained() self.ivila_predictor = IVILATokenClassificationPredictor.from_pretrained(ivila_predictor_path) - self.bio_roberta_predictor = HFBIOTaggerPredictor.from_pretrained( - bio_roberta_predictor_path, - entity_name="tokens", - context_name="pages", - ) self.sent_predictor = PysbdSentencePredictor() self.logger.info("Finished instantiating recipe") @@ -114,9 +109,9 @@ def from_pdf(self, pdf: Path) -> Document: return self.from_doc(doc=doc) def from_doc(self, doc: Document) -> Document: - self.logger.info("Predicting words...") - words = self.word_predictor.predict(doc=doc) - doc.annotate_layer(name=WordsFieldName, entities=words) + # self.logger.info("Predicting words...") + # words = self.word_predictor.predict(doc=doc) + # doc.annotate_layer(name=WordsFieldName, entities=words) self.logger.info("Predicting sentences...") sentences = self.sent_predictor.predict(doc=doc) @@ -128,7 +123,20 @@ def from_doc(self, doc: Document) -> Document: blocks = self.publaynet_block_predictor.predict(doc=doc) doc.annotate_layer(name=BlocksFieldName, entities=blocks) - self.logger.info("Predicting vila...") + self.logger.info("Predicting figures and tables...") + figures = [] + tables = [] + for block in blocks: + if block.metadata.type == "Figure": + figure = Entity(boxes=block.boxes) + figures.append(figure) + elif block.metadata.type == "Table": + table = Entity(boxes=block.boxes) + tables.append(table) + doc.annotate_layer(name=FiguresFieldName, entities=figures) + doc.annotate_layer(name=TablesFieldName, entities=tables) + + # self.logger.info("Predicting vila...") vila_entities = self.ivila_predictor.predict(doc=doc) doc.annotate_layer(name="vila_entities", entities=vila_entities) @@ -138,7 +146,7 @@ def from_doc(self, doc: Document) -> Document: [b for t in doc.intersect_by_span(entity, name=TokensFieldName) for b in t.boxes] ) ] - entity.text = make_text(entity=entity, document=doc) + # entity.text = make_text(entity=entity, document=doc) preds = group_by(entities=vila_entities, metadata_field="label", metadata_values_map=VILA_LABELS_MAP) doc.annotate(*preds) return doc diff --git a/pyproject.toml b/pyproject.toml index 66f7e1f..e1a6f9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = 'papermage' -version = '0.16.0' +version = '0.17.0' description = 'Papermage. Casting magic over scientific PDFs.' license = {text = 'Apache-2.0'} readme = 'README.md'