Skip to content

Commit

Permalink
Port over EntityClassifierPredictor inference
Browse files Browse the repository at this point in the history
  • Loading branch information
bnewm0609 committed Jul 21, 2023
1 parent 91ef873 commit 061133e
Show file tree
Hide file tree
Showing 6 changed files with 533 additions and 0 deletions.
54 changes: 54 additions & 0 deletions papermage/predictors/base_predictors/base_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from dataclasses import dataclass
from abc import abstractmethod
from typing import Union, List, Dict, Any

from papermage.types import Annotation, Document


class BasePredictor:

###################################################################
##################### Necessary Model Variables ###################
###################################################################

# TODO[Shannon] Add the check for required backends in the future.
# So different models might require different backends:
# For example, LayoutLM only needs transformers, but LayoutLMv2
# needs transformers and Detectron2. It is the model creators'
# responsibility to check the required backends.
@property
@abstractmethod
def REQUIRED_BACKENDS(self):
return None

@property
@abstractmethod
def REQUIRED_DOCUMENT_FIELDS(self):
"""Due to the dynamic nature of the document class as well the
models, we require the model creator to provide a list of required
fields in the document class. If not None, the predictor class
will perform the check to ensure that the document contains all
the specified fields.
"""
return None

###################################################################
######################### Core Methods ############################
###################################################################

def _doc_field_checker(self, document: Document) -> None:
if self.REQUIRED_DOCUMENT_FIELDS is not None:
for field in self.REQUIRED_DOCUMENT_FIELDS:
assert (
field in document.fields
), f"The input Document object {document} doesn't contain the required field {field}"

# TODO[Shannon] Allow for some preprocessed document input
# representation for better performance?
@abstractmethod
def predict(self, document: Document) -> List[Annotation]:
"""For all the predictors, the input is a document object, and
the output is a list of annotations.
"""
self._doc_field_checker(document)
return []
Empty file.
38 changes: 38 additions & 0 deletions papermage/predictors/hf_predictors/base_hf_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from abc import abstractmethod
from typing import Union, List, Dict, Any

from transformers import AutoTokenizer, AutoConfig, AutoModel

from papermage.types import Annotation, Document
from papermage.predictors.base_predictors.base_predictor import BasePredictor


class BaseHFPredictor(BasePredictor):
REQUIRED_BACKENDS = ["transformers", "torch"]

def __init__(self, model: Any, config: Any, tokenizer: Any):

self.model = model
self.config = config
self.tokenizer = tokenizer

@classmethod
def from_pretrained(cls, model_name_or_path: str, *args, **kwargs):
config = AutoConfig.from_pretrained(model_name_or_path)
model = AutoModel.from_pretrained(
model_name_or_path, config=config, *args, **kwargs
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
return cls(model, config, tokenizer)

@abstractmethod
def preprocess(self, document: Document) -> List:
"""Convert the input document into the format that is required
by the model.
"""
raise NotImplementedError

@abstractmethod
def postprocess(self, model_outputs: Any) -> List[Annotation]:
"""Convert the model outputs into the Annotation format"""
raise NotImplementedError
Loading

0 comments on commit 061133e

Please sign in to comment.