Skip to content

Commit

Permalink
Merge pull request #42 from allenai/soldni/api
Browse files Browse the repository at this point in the history
adding vila sections in recipe
  • Loading branch information
soldni committed Aug 7, 2023
2 parents 7638bb3 + 765bcef commit c9a4084
Show file tree
Hide file tree
Showing 17 changed files with 280 additions and 95 deletions.
3 changes: 2 additions & 1 deletion papermage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .utils.version import get_version
from .magelib import *
from .utils import get_version, group_by

__version__ = get_version()
95 changes: 64 additions & 31 deletions papermage/magelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,82 @@
"""


from papermage.magelib.image import Image
from papermage.magelib.span import Span
from papermage.magelib.box import Box
from papermage.magelib.metadata import Metadata
from papermage.magelib.annotation import Annotation
from papermage.magelib.entity import Entity
from papermage.magelib.indexer import EntitySpanIndexer, EntityBoxIndexer
from papermage.magelib.document import Document
from papermage.magelib.document import (
MetadataFieldName,
from .annotation import Annotation
from .box import Box
from .document import (
AbstractsFieldName,
AlgorithmsFieldName,
AuthorsFieldName,
BibliographiesFieldName,
BlocksFieldName,
CaptionsFieldName,
Document,
Prediction,
EntitiesFieldName,
SymbolsFieldName,
RelationsFieldName,
EquationsFieldName,
FiguresFieldName,
FootersFieldName,
FootnotesFieldName,
HeadersFieldName,
ImagesFieldName,
KeywordsFieldName,
ListsFieldName,
MetadataFieldName,
PagesFieldName,
TokensFieldName,
ParagraphsFieldName,
RelationsFieldName,
RowsFieldName,
BlocksFieldName,
ImagesFieldName,
WordsFieldName,
SectionsFieldName,
SentencesFieldName,
ParagraphsFieldName
SymbolsFieldName,
TablesFieldName,
TitlesFieldName,
TokensFieldName,
WordsFieldName,
)
from .entity import Entity
from .image import Image
from .indexer import EntityBoxIndexer, EntitySpanIndexer
from .metadata import Metadata
from .span import Span

# Public API of `papermage.magelib`.
#
# `papermage/__init__.py` does `from .magelib import *`, which resolves every
# entry below as an attribute of this module, so each name listed here must be
# bound by the imports above. Keep the list sorted, one name per line, with a
# trailing comma on every entry: a missing comma makes Python silently join two
# adjacent string literals into a single bogus name.
__all__ = [
    "AbstractsFieldName",
    "AlgorithmsFieldName",
    "Annotation",
    "AuthorsFieldName",
    "BibliographiesFieldName",
    "BlocksFieldName",
    "Box",
    "CaptionsFieldName",
    "Document",
    "EntitiesFieldName",
    "Entity",
    "EntityBoxIndexer",
    "EntitySpanIndexer",
    "EquationsFieldName",
    "FiguresFieldName",
    "FootersFieldName",
    "FootnotesFieldName",
    "HeadersFieldName",
    "Image",
    "ImagesFieldName",
    "KeywordsFieldName",
    "ListsFieldName",
    "Metadata",
    "MetadataFieldName",
    "PagesFieldName",
    "ParagraphsFieldName",
    "Prediction",
    "RelationsFieldName",
    "RowsFieldName",
    "SectionsFieldName",
    "SentencesFieldName",
    "Span",
    "SymbolsFieldName",
    "TablesFieldName",
    "TitlesFieldName",
    "TokensFieldName",
    "WordsFieldName",
]
4 changes: 1 addition & 3 deletions papermage/magelib/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
"""

import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, Dict, List, Optional, Union

if TYPE_CHECKING:
# from papermage.types.document import Document
pass
from .document import Document


class Annotation:
Expand Down
2 changes: 1 addition & 1 deletion papermage/magelib/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np

from papermage.magelib import Span
from .span import Span


class _BoxSpan(Span):
Expand Down
52 changes: 36 additions & 16 deletions papermage/magelib/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
"""

from typing import Dict, Iterable, List, Optional, Union

from papermage.magelib import (
Box,
Entity,
EntityBoxIndexer,
EntitySpanIndexer,
Image,
Metadata,
Span,
)
from itertools import chain
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

from .span import Span
from .box import Box
from .image import Image
from .metadata import Metadata
from .entity import Entity
from .indexer import EntitySpanIndexer, EntityBoxIndexer


# document field names
SymbolsFieldName = "symbols"
Expand All @@ -32,6 +31,28 @@
SentencesFieldName = "sentences"
ParagraphsFieldName = "paragraphs"

# field names for the labels used by the VILA model (see VILA_LABELS in predictors/hf_predictors/vila_predictor.py)
TitlesFieldName = "titles"
AuthorsFieldName = "authors"
AbstractsFieldName = "abstracts"
KeywordsFieldName = "keywords"
SectionsFieldName = "sections"
ListsFieldName = "lists"
BibliographiesFieldName = "bibliographies"
EquationsFieldName = "equations"
AlgorithmsFieldName = "algorithms"
FiguresFieldName = "figures"
TablesFieldName = "tables"
CaptionsFieldName = "captions"
HeadersFieldName = "headers"
FootersFieldName = "footers"
FootnotesFieldName = "footnotes"


class Prediction(NamedTuple):
    """A field name paired with the entities to store under it.

    `Document.annotate` accepts these and forwards `name` and `entities` to
    `Document.annotate_entity` as `field_name` and `entities` respectively.
    """

    # Field name the entities are annotated under.
    name: str
    # Entities to attach to the document under `name`.
    entities: List[Entity]


class Document:
SPECIAL_FIELDS = [SymbolsFieldName, ImagesFieldName, MetadataFieldName, EntitiesFieldName, RelationsFieldName]
Expand Down Expand Up @@ -73,11 +94,10 @@ def check_field_name_availability(self, field_name: str) -> None:
def get_entity(self, field_name: str) -> List[Entity]:
    """Return the value stored on this document under the attribute `field_name`.

    Raises:
        AttributeError: if the document has no attribute with that name.
    """
    stored: List[Entity] = getattr(self, field_name)
    return stored

def annotate(self, field_name: str, entities: List[Entity]) -> None:
    """Attach `entities` to the document under `field_name`.

    Every element must be an `Entity`; the actual attachment is delegated to
    `annotate_entity`.

    Raises:
        NotImplementedError: if any element of `entities` is not an `Entity`.
    """
    # Reject the whole batch as soon as one non-Entity element is seen.
    for candidate in entities:
        if not isinstance(candidate, Entity):
            raise NotImplementedError(f"entity list contains non-entities: {[type(e) for e in entities]}")

    self.annotate_entity(field_name=field_name, entities=entities)
def annotate(self, *predictions: Union[Prediction, Tuple[Prediction, ...]]) -> None:
    """Annotate the document with one or more predictions.

    Each positional argument is either a single `Prediction` or an iterable of
    `Prediction`s; they are processed in the order given, and each one is
    handed to `annotate_entity` using its `name` and `entities`.
    """
    for item in predictions:
        # `Prediction` is a NamedTuple and therefore iterable itself, so a
        # lone prediction has to be recognised before treating `item` as a group.
        group = (item,) if isinstance(item, Prediction) else item
        for single in group:
            self.annotate_entity(field_name=single.name, entities=single.entities)

def annotate_entity(self, field_name: str, entities: List[Entity]) -> None:
self.check_field_name_availability(field_name=field_name)
Expand Down
11 changes: 6 additions & 5 deletions papermage/magelib/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
"""

from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

from papermage.magelib import Annotation, Box, Image, Metadata, Span

if TYPE_CHECKING:
from papermage.magelib.document import Document
from .annotation import Annotation
from .box import Box
from .image import Image
from .metadata import Metadata
from .span import Span


class Entity(Annotation):
Expand Down
1 change: 0 additions & 1 deletion papermage/magelib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

import base64
import logging
import os
from io import BytesIO

Expand Down
4 changes: 3 additions & 1 deletion papermage/magelib/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import numpy as np
from ncls import NCLS

from papermage.magelib import Annotation, Box, Entity
from .annotation import Annotation
from .box import Box
from .entity import Entity


class Indexer:
Expand Down
2 changes: 1 addition & 1 deletion papermage/magelib/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from typing import List, Optional, Dict, Tuple, List
from typing import List, Dict, List

from collections import defaultdict

Expand Down
3 changes: 3 additions & 0 deletions papermage/predictors/hf_predictors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

# Runs when the `hf_predictors` package is imported, i.e. before any of its
# predictor modules are loaded.
# NOTE(review): presumably this turns off the Hugging Face `tokenizers`
# internal parallelism to avoid its fork-related warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
20 changes: 20 additions & 0 deletions papermage/predictors/hf_predictors/vila_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,26 @@
MAX_PAGE_WIDTH = 1000
MAX_PAGE_HEIGHT = 1000

# Label names used by the VILA model.
# NOTE(review): presumably the list order matches the model's label ids —
# confirm before reordering or inserting entries.
# Every label except "Paragraph" has a corresponding field-name constant in the
# vila group of papermage/magelib/document.py ("paragraphs" is defined there
# separately).
VILA_LABELS = [
    "Title",
    "Author",
    "Abstract",
    "Keywords",
    "Section",
    "Paragraph",
    "List",
    "Bibliography",
    "Equation",
    "Algorithm",
    "Figure",
    "Table",
    "Caption",
    "Header",
    "Footer",
    "Footnote"
]


# util
def columns_used_in_model_inputs(model):
Expand Down
23 changes: 20 additions & 3 deletions papermage/predictors/sklearn_predictors/word_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,39 @@
import tarfile
import tempfile
from collections import defaultdict
from typing import Dict, List, Set, Tuple
from typing import Dict, List, Set, Tuple, Union
from urllib.parse import urlparse

import numpy as np
import requests
from joblib import load
from scipy.sparse import hstack

from papermage.magelib import Document, Entity, Metadata, Span
from papermage.magelib import Document, Entity, Metadata, Span, WordsFieldName, Annotation
from papermage.parsers.pdfplumber_parser import PDFPlumberParser
from papermage.predictors.base_predictor import BasePredictor
from papermage.predictors.hf_predictors.whitespace_predictor import WhitespacePredictor


logger = logging.getLogger(__name__)


def make_text(entity: Union[Entity, Annotation], document: Document, field: str = WordsFieldName) -> str:
    """Rebuild the text covered by `entity` from the document's `field` entries.

    The entries returned by `document.find_by_span(entity, field)` are
    concatenated in order. Wherever one entry's `end` differs from the next
    entry's `start`, the slice `document.symbols[end:start]` is inserted
    between them so the original separator characters are kept.

    Args:
        entity: the annotation whose text should be reconstructed.
        document: the document providing `find_by_span` and `symbols`.
        field: name of the field to look entries up in.

    Returns:
        The concatenated text; an empty string when no entries are found.

    Raises:
        AssertionError: if an entry that has a successor has a non-int `end`,
            or the successor has a non-int `start`.
    """
    words = document.find_by_span(entity, field)
    pieces: List[str] = []
    last_index = len(words) - 1

    for idx, word in enumerate(words):
        pieces.append(str(word.text))
        if idx == last_index:
            # nothing follows the final word, so there is no gap to fill
            continue

        following = words[idx + 1]
        gap_end = following.start
        gap_start = word.end
        assert isinstance(gap_end, int), f"{following} has no span (non-int start)"
        assert isinstance(gap_start, int), f"{word} has no span (non-int end)"
        if gap_start != gap_end:
            pieces.append(document.symbols[gap_start:gap_end])

    return "".join(pieces)


class IsWordResult:
def __init__(self, original: str, new: str, is_edit: bool) -> None:
self.original = original
Expand Down
7 changes: 4 additions & 3 deletions papermage/predictors/spacy_predictors/sentence_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Span,
TokensFieldName,
WordsFieldName,
Annotation
)
from papermage.predictors.base_predictor import BasePredictor

Expand Down Expand Up @@ -79,7 +80,7 @@ class PysbdSentencePredictor(BasePredictor):
"""

REQUIRED_BACKENDS = ["pysbd"]
REQUIRED_DOCUMENT_FIELDS = [PagesFieldName, TokensFieldName] # type: ignore
REQUIRED_DOCUMENT_FIELDS = [TokensFieldName] # type: ignore

def __init__(self) -> None:
self._segmenter = pysbd.Segmenter(language="en", clean=False, char_span=True)
Expand Down Expand Up @@ -119,7 +120,7 @@ def split_token_based_on_sentences_boundary(self, words: List[str]) -> List[Tupl
token_id_start = token_id_end
return split

def predict(self, doc: Document) -> List[Entity]:
def _predict(self, doc: Document) -> List[Annotation]:
if hasattr(doc, WordsFieldName):
words = [word.text for word in getattr(doc, WordsFieldName)]
attr_name = WordsFieldName
Expand All @@ -131,7 +132,7 @@ def predict(self, doc: Document) -> List[Entity]:

split = self.split_token_based_on_sentences_boundary(words)

sentence_spans = []
sentence_spans: List[Annotation] = []
for start, end in split:
if end - start == 0:
continue
Expand Down
Loading

0 comments on commit c9a4084

Please sign in to comment.