Skip to content

Commit

Permalink
Kylel/cleanup predictors (#48)
Browse files Browse the repository at this point in the history
* move span qa predictor

* HUGE reorganization of predictors into different files; fix recipes and tests

* increment pyproject; add sys.path debug to github workflow

* print sys path after

* move base predictor up

* making import relative

* basepredictors

* commit predictors

* remove circular imports
  • Loading branch information
kyleclo committed Sep 21, 2023
1 parent e8e08cf commit aecadfc
Show file tree
Hide file tree
Showing 34 changed files with 265 additions and 90 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/papermage-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,17 @@ jobs:
python-version: ${{ matrix.python-version }}
architecture: x64

- name: Test with Python ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get update
sudo apt-get -y install poppler-utils
pip install --upgrade pip
pip install -e .[dev,predictors,visualizers]
- name: Print sys.path
run: |
python -c "import sys; print(sys.path)"
- name: Test with Python ${{ matrix.python-version }}
run: |
pytest --cov-fail-under=42 --log-disable=pdfminer.psparser --log-disable=pdfminer.pdfinterp --log-disable=pdfminer.cmapdb --log-disable=pdfminer.pdfdocument --log-disable=pdfminer.pdffont --log-disable=pdfminer.pdfparser --log-disable=pdfminer.converter --log-disable=pdfminer.converter --log-disable=pdfminer.pdfpage
26 changes: 26 additions & 0 deletions papermage/predictors/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,30 @@
# Predictors

Some rules about Predictors:

* Each Predictor is named after the type of Entities it produces. For example, `sentences` come from the `SentencePredictor`.

* Organization looks like:

```
predictors/
|-- base_predictors/
|-- hf_predictors.py
|-- lp_predictors.py
|-- api_predictors.py
|-- spacy_predictors.py
|-- sklearn_predictors.py
|-- block_predictors.py
|-- paragraph_predictors.py
|-- sentence_predictors.py
|-- word_predictors.py
|-- token_predictors.py
```
* Note that `base_predictors/` contains reusable implementations; users never have to import them, but developers may want to reuse these. Users of the library instead import the desired predictor for a given emitted entity type.


* We try to name our predictors `[Framework][Model][Dataset][Entity]`.

## `SpanQAPredictor` (Using GPT3 as a Predictor)

The `span_qa.predictor.py` file includes an example of using the `decontext` library to use GPT3 as a predictor. The example involves span-based classification: for example, a user can highlight a span of text in a paper and ask a question about it. (The span is a field, and the question is metadata on the field.) The predictor runs retrieval over the specified units and feeds the question, context, and highlighted span to GPT3 to answer the question. See `tests/test_predictors/test_span_qa_predictor.py` for examples of how this predictor is used.
28 changes: 14 additions & 14 deletions papermage/predictors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from papermage.predictors.api_predictors.span_qa_predictor import APISpanQAPredictor
from papermage.predictors.hf_predictors.bio_tagger_predictor import HFBIOTaggerPredictor
from papermage.predictors.hf_predictors.vila_predictor import (
IVILATokenClassificationPredictor,
)
from papermage.predictors.hf_predictors.whitespace_predictor import WhitespacePredictor
from papermage.predictors.lp_predictors.block_predictor import LPBlockPredictor
from papermage.predictors.sklearn_predictors.word_predictor import SVMWordPredictor
from papermage.predictors.spacy_predictors.sentence_predictor import (
PysbdSentencePredictor,
)
from papermage.predictors.base_predictors.base_predictor import BasePredictor
from papermage.predictors.base_predictors.hf_predictors import HFBIOTaggerPredictor
from papermage.predictors.block_predictors import LPEffDetPubLayNetBlockPredictor
from papermage.predictors.formula_predictors import LPEffDetFormulaPredictor
from papermage.predictors.sentence_predictors import PysbdSentencePredictor
from papermage.predictors.span_qa_predictors import APISpanQAPredictor
from papermage.predictors.token_predictors import HFWhitspaceTokenPredictor
from papermage.predictors.vila_predictors import IVILATokenClassificationPredictor
from papermage.predictors.word_predictors import SVMWordPredictor

__all__ = [
"HFBIOTaggerPredictor",
"APISpanQAPredictor",
"LPBlockPredictor",
"IVILATokenClassificationPredictor",
"WhitespacePredictor",
"HFWhitspaceTokenPredictor",
"SVMWordPredictor",
"PysbdSentencePredictor",
"LPEffDetPubLayNetBlockPredictor",
"LPEffDetFormulaPredictor",
"APISpanQAPredictor",
"BasePredictor",
]
5 changes: 5 additions & 0 deletions papermage/predictors/base_predictors/api_predictors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
@kylel
"""
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
class BasePredictor:
@property
@abstractmethod
def REQUIRED_BACKENDS(self):
raise NotImplementedError

@property
@abstractmethod
def REQUIRED_DOCUMENT_FIELDS(self):
def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]:
"""Due to the dynamic nature of the document class as well the
models, we require the model creator to provide a list of required
fields in the document class. If not None, the predictor class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

Expand All @@ -19,7 +20,10 @@
)

from papermage.magelib import Annotation, Box, Document, Entity, Metadata, Span
from papermage.predictors.base_predictor import BasePredictor

from .base_predictor import BasePredictor

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class BIOBatch:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
Base predictors of bounding box detection models from layoutparser
@shannons, @kylel
"""

from typing import Any, Dict, List, Optional, Union

import layoutparser as lp
Expand All @@ -13,20 +21,22 @@
Metadata,
PagesFieldName,
)
from papermage.predictors.base_predictor import BasePredictor

from .base_predictor import BasePredictor


class LPBlockPredictor(BasePredictor):
REQUIRED_BACKENDS = ["layoutparser"]
REQUIRED_DOCUMENT_FIELDS = [PagesFieldName, ImagesFieldName]
class LPPredictor(BasePredictor):
@property
def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]:
return [PagesFieldName, ImagesFieldName]

def __init__(self, model):
self.model = model

@classmethod
def from_pretrained(
cls,
config_path: str = "lp://efficientdet/PubLayNet",
config_path: str,
model_path: Optional[str] = None,
label_map: Optional[Dict] = None,
extra_config: Optional[Dict] = None,
Expand Down
5 changes: 5 additions & 0 deletions papermage/predictors/base_predictors/sklearn_predictors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
@kylel
"""
5 changes: 5 additions & 0 deletions papermage/predictors/base_predictors/spacy_predictors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
@kylel
"""
23 changes: 23 additions & 0 deletions papermage/predictors/block_predictors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Predict visual blocks in a page.
@kylel
"""

from typing import Dict, Optional

from papermage.predictors.base_predictors.lp_predictors import LPPredictor


class LPEffDetPubLayNetBlockPredictor(LPPredictor):
    """Layout-block predictor: an EfficientDet model trained on PubLayNet,
    loaded via layoutparser (LPPredictor handles model loading/inference)."""

    @classmethod
    def from_pretrained(
        cls,
        device: Optional[str] = None,
    ):
        # Delegates to LPPredictor.from_pretrained with the PubLayNet
        # EfficientDet config pinned; only the device remains configurable.
        # NOTE(review): returns whatever LPPredictor.from_pretrained returns —
        # presumably an instance of this class; confirm in base_predictors/lp_predictors.py.
        return super().from_pretrained(
            config_path="lp://efficientdet/PubLayNet",
            device=device,
        )
23 changes: 23 additions & 0 deletions papermage/predictors/formula_predictors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
Predict formula blocks in a page.
@kylel
"""

from typing import Dict, Optional

from papermage.predictors.base_predictors.lp_predictors import LPPredictor


class LPEffDetFormulaPredictor(LPPredictor):
    """Formula-block predictor: an EfficientDet model trained on the MFD
    (Math Formula Detection) dataset, loaded via layoutparser."""

    @classmethod
    def from_pretrained(
        cls,
        device: Optional[str] = None,
    ):
        # Delegates to LPPredictor.from_pretrained with the MFD EfficientDet
        # config pinned; only the device remains configurable.
        # NOTE(review): return type inherited from LPPredictor.from_pretrained —
        # presumably an instance of this class; confirm in base_predictors/lp_predictors.py.
        return super().from_pretrained(
            config_path="lp://efficientdet/MFD",
            device=device,
        )
3 changes: 0 additions & 3 deletions papermage/predictors/hf_predictors/__init__.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
TokensFieldName,
WordsFieldName,
)
from papermage.predictors.base_predictor import BasePredictor
from papermage.predictors import BasePredictor
from papermage.utils.merge import cluster_and_merge_neighbor_spans


Expand All @@ -35,8 +35,9 @@ class PysbdSentencePredictor(BasePredictor):
>>> doc.annotate(sentences=sentence_spans)
"""

REQUIRED_BACKENDS = ["pysbd"]
REQUIRED_DOCUMENT_FIELDS = [TokensFieldName] # type: ignore
@property
def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]:
return [TokensFieldName] # type: ignore

def __init__(self) -> None:
self._segmenter = pysbd.Segmenter(language="en", clean=False, char_span=True)
Expand Down
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
"""
QA given an Entity of interest (with span)
@benjaminn
"""

import json
from typing import List

Expand All @@ -10,12 +18,10 @@
from decontext.step.qa import TemplateRetrievalQAStep

from papermage.magelib import Annotation, Document, Entity
from papermage.predictors.base_predictor import BasePredictor
from papermage.predictors import BasePredictor


class APISpanQAPredictor(BasePredictor):
REQUIRED_BACKENDS = ["transformers", "torch", "decontext"]

def __init__(self, context_unit_name: str = "paragraph", span_name: str = "user_selected_span"):
self.context_unit_name = context_unit_name
self.span_name = span_name
Expand Down Expand Up @@ -81,7 +87,7 @@ def _predict(self, doc: Document) -> List[Annotation]:
new_user_selected_span.metadata["question"] = paper_snippet.qae[0].question
new_user_selected_span.metadata["answer"] = paper_snippet.qae[0].answer
annotations.append(new_user_selected_span)

# add the context with span
context_with_span = getattr(getattr(doc, self.span_name)[0], self.context_unit_name)[0]
# context_with_span = doc.get_entity(self.context_unit_name)[paper_snippet.paragraph_with_snippet.index]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,29 @@
"""

import os
from typing import List, Optional, Set, Tuple

import tokenizers

from papermage.magelib import Document, Entity, Metadata, Span, TokensFieldName
from papermage.predictors.base_predictor import BasePredictor
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class WhitespacePredictor(BasePredictor):
REQUIRED_BACKENDS = None
REQUIRED_DOCUMENT_FIELDS = []
from papermage.magelib import Document, Entity, Metadata, Span
from papermage.predictors import BasePredictor


class HFWhitspaceTokenPredictor(BasePredictor):
@property
def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]:
return []

_dictionary: Optional[Set[str]] = None

def __init__(self) -> None:
self.whitespace_tokenizer = tokenizers.pre_tokenizers.WhitespaceSplit()

def predict(self, doc: Document) -> List[Entity]:
def _predict(self, doc: Document) -> List[Entity]:
self._doc_field_checker(doc)

# 1) whitespace tokenization on symbols. each token is a nested tuple ('text', (start, end))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
@shannons, @kylel
"""
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import inspect
import itertools
Expand All @@ -29,7 +31,7 @@
Span,
TokensFieldName,
)
from papermage.predictors.base_predictor import BasePredictor
from papermage.predictors import BasePredictor

# Two constants for the constraining the size of the page for
# inputs to the model.
Expand All @@ -54,7 +56,7 @@
"Caption",
"Header",
"Footer",
"Footnote"
"Footnote",
]


Expand Down Expand Up @@ -188,8 +190,10 @@ def convert_sequence_tagging_to_spans(


class BaseSinglePageTokenClassificationPredictor(BasePredictor):
REQUIRED_BACKENDS = ["transformers", "torch", "vila"]
REQUIRED_DOCUMENT_FIELDS = [PagesFieldName, TokensFieldName]
@property
def REQUIRED_DOCUMENT_FIELDS(self) -> List[str]:
return [PagesFieldName, TokensFieldName]

DEFAULT_SUBPAGE_PER_RUN = 2 # TODO: Might remove this in the future for longformer-like models

@property
Expand Down Expand Up @@ -221,7 +225,7 @@ def from_pretrained(

return cls(predictor, subpage_per_run)

def predict(self, doc: Document, subpage_per_run: Optional[int] = None) -> List[Annotation]:
def _predict(self, doc: Document, subpage_per_run: Optional[int] = None) -> List[Annotation]:
page_prediction_results = []
for page_id, page in enumerate(doc.pages):
if page.tokens:
Expand Down
Loading

0 comments on commit aecadfc

Please sign in to comment.