-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port over EntityClassifierPredictor inference
- Loading branch information
bnewm0609
committed
Jul 21, 2023
1 parent
91ef873
commit 061133e
Showing
6 changed files
with
533 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Union

from papermage.types import Annotation, Document
|
||
|
||
class BasePredictor(ABC):
    """Abstract base class for all predictors.

    A predictor consumes a ``Document`` and produces a list of
    ``Annotation`` objects. Subclasses declare which backend libraries
    they require and which document fields must be present before
    inference runs.

    Fix: the original declared ``@abstractmethod`` members but did not
    inherit from ``abc.ABC``, so the abstract-method machinery was a
    silent no-op and incomplete subclasses could be instantiated.
    Inheriting ``ABC`` makes the existing decorators actually enforce
    the contract.
    """

    ###################################################################
    ##################### Necessary Model Variables ###################
    ###################################################################

    # TODO[Shannon] Add the check for required backends in the future.
    # So different models might require different backends:
    # For example, LayoutLM only needs transformers, but LayoutLMv2
    # needs transformers and Detectron2. It is the model creators'
    # responsibility to check the required backends.
    @property
    @abstractmethod
    def REQUIRED_BACKENDS(self):
        """Names of the backend libraries this predictor needs at runtime."""
        return None

    @property
    @abstractmethod
    def REQUIRED_DOCUMENT_FIELDS(self):
        """Due to the dynamic nature of the document class as well the
        models, we require the model creator to provide a list of required
        fields in the document class. If not None, the predictor class
        will perform the check to ensure that the document contains all
        the specified fields.
        """
        return None

    ###################################################################
    ######################### Core Methods ############################
    ###################################################################

    def _doc_field_checker(self, document: Document) -> None:
        """Verify that `document` contains every required field.

        NOTE(review): ``assert`` is stripped under ``python -O``; kept
        as-is so callers that expect ``AssertionError`` keep working —
        consider promoting to an explicit ``raise`` in a follow-up.
        """
        if self.REQUIRED_DOCUMENT_FIELDS is not None:
            for field in self.REQUIRED_DOCUMENT_FIELDS:
                assert (
                    field in document.fields
                ), f"The input Document object {document} doesn't contain the required field {field}"

    # TODO[Shannon] Allow for some preprocessed document input
    # representation for better performance?
    @abstractmethod
    def predict(self, document: Document) -> List[Annotation]:
        """For all the predictors, the input is a document object, and
        the output is a list of annotations.

        Subclasses should call ``super().predict(document)`` (or
        ``self._doc_field_checker(document)``) to run the field check.
        """
        self._doc_field_checker(document)
        return []
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from abc import abstractmethod | ||
from typing import Union, List, Dict, Any | ||
|
||
from transformers import AutoTokenizer, AutoConfig, AutoModel | ||
|
||
from papermage.types import Annotation, Document | ||
from papermage.predictors.base_predictors.base_predictor import BasePredictor | ||
|
||
|
||
class BaseHFPredictor(BasePredictor):
    """Base class for predictors backed by a HuggingFace model.

    Holds a (model, config, tokenizer) triple and defines the
    preprocess/postprocess hooks every HF-based predictor implements.
    """

    REQUIRED_BACKENDS = ["transformers", "torch"]

    def __init__(self, model: Any, config: Any, tokenizer: Any):
        """Store the HF model, its config, and its tokenizer."""
        self.model = model
        self.config = config
        self.tokenizer = tokenizer

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *args, **kwargs):
        """Alternate constructor: load all three HF components from a
        model name or a local checkpoint path.

        Extra ``*args`` / ``**kwargs`` are forwarded to
        ``AutoModel.from_pretrained``.
        """
        hf_config = AutoConfig.from_pretrained(model_name_or_path)
        hf_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        hf_model = AutoModel.from_pretrained(
            model_name_or_path, *args, config=hf_config, **kwargs
        )
        return cls(hf_model, hf_config, hf_tokenizer)

    @abstractmethod
    def preprocess(self, document: Document) -> List:
        """Convert the input document into the format that is required
        by the model.
        """
        raise NotImplementedError

    @abstractmethod
    def postprocess(self, model_outputs: Any) -> List[Annotation]:
        """Convert the model outputs into the Annotation format"""
        raise NotImplementedError
Oops, something went wrong.