Skip to content

Commit

Permalink
better repr for biotagger; added preprocess test
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleclo committed Sep 7, 2023
1 parent a41c10d commit d61cc18
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 13 deletions.
2 changes: 1 addition & 1 deletion papermage/magelib/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(

def __repr__(self):
if self.doc:
return f"Annotated Entity:\tSpans: {True if self.spans else False}\tBoxes: {True if self.boxes else False}\nText: {self.text}"
return f"Annotated Entity:\tID: {self.id}\tSpans: {True if self.spans else False}\tBoxes: {True if self.boxes else False}\tText: {self.text}"
return f"Unannotated Entity: {self.to_json()}"

def to_json(self) -> Dict:
Expand Down
10 changes: 7 additions & 3 deletions papermage/predictors/hf_predictors/bio_tagger_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def __init__(
self.entity_ids = entity_ids
self.context_id = context_id

def __repr__(self) -> str:
# Debug representation: the literal class name "BIOBatch" wrapped around
# the instance's full attribute dictionary, so every field set on the
# batch shows up when it is printed or logged.
return f"BIOBatch({self.__dict__})"


class BIOPrediction:
def __init__(self, context_id: int, entity_id: int, label: str, score: float):
Expand All @@ -54,6 +57,9 @@ def __init__(self, context_id: int, entity_id: int, label: str, score: float):
self.label = label
self.score = score

def __repr__(self) -> str:
# Debug representation: the literal class name "BIOPrediction" wrapped
# around the instance's full attribute dictionary (the fields assigned
# in __init__, e.g. label and score).
return f"BIOPrediction({self.__dict__})"


class HFBIOTaggerPredictor(BasePredictor):
"""
Expand Down Expand Up @@ -346,9 +352,7 @@ def _predict(self, doc: Document) -> List[Annotation]:
return annotations

def _predict_batch(
self,
batch: BIOBatch,
device: Union[None, str, torch.device] = None
self, batch: BIOBatch, device: Union[None, str, torch.device] = None
) -> List[BIOPrediction]:
#
# preprocessing!! (padding & tensorification)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'papermage'
version = '0.12.0'
version = '0.12.1'
description = 'Papermage. Casting magic over scientific PDFs.'
license = {text = 'Apache-2.0'}
readme = 'README.md'
Expand Down
35 changes: 27 additions & 8 deletions tests/test_predictors/test_bio_tagger_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
import transformers

from papermage.magelib import Document, Entity, Span
from papermage.parsers import PDFPlumberParser
from papermage.predictors.hf_predictors.bio_tagger_predictor import HFBIOTaggerPredictor, BIOPrediction
from papermage.predictors.hf_predictors.bio_tagger_predictor import (
BIOBatch,
BIOPrediction,
HFBIOTaggerPredictor,
)

TEST_SCIBERT_WEIGHTS = "allenai/scibert_scivocab_uncased"

Expand All @@ -21,16 +24,16 @@ class TestBioTaggerPredictor(unittest.TestCase):
def setUp(self):
transformers.set_seed(407)
self.fixture_path = pathlib.Path(__file__).parent.parent / "fixtures"

# setup document
with open(self.fixture_path / "entity_classification_predictor_test_doc_papermage.json", "r") as f:
test_doc_json = json.load(f)
self.doc = Document.from_json(doc_json=test_doc_json)
ent1 = Entity(spans=[Span(start=86, end=456)])
ent2 = Entity(spans=[Span(start=457, end=641)])
self.doc.annotate_entity(field_name="bibs", entities=[ent1, ent2])

# self.predictor = HFBIOTaggerPredictor.from_pretrained(
# model_name_or_path=TEST_SCIBERT_WEIGHTS, entity_name="tokens", context_name="pages"
# )
# setup predictor
self.id2label = {0: "O", 1: "B_Label", 2: "I_Label"}
self.label2id = {label: id_ for id_, label in self.id2label.items()}
self.predictor = HFBIOTaggerPredictor.from_pretrained(
Expand All @@ -40,6 +43,25 @@ def setUp(self):
**{"num_labels": len(self.id2label), "id2label": self.id2label, "label2id": self.label2id},
)

def test_preprocess(self):
# Exercises preprocess() on a tiny hand-built document: annotate tokens
# and one sentence, then check the resulting batch type and that its
# input ids decode back to the expected text.
doc = Document(symbols="This is a test document.")
# One entity per token; the character spans select "This", "is", "a",
# "test", "document" and the trailing "." of the symbols string.
tokens = [
Entity(spans=[Span(start=0, end=4)]),
Entity(spans=[Span(start=5, end=7)]),
Entity(spans=[Span(start=8, end=9)]),
Entity(spans=[Span(start=10, end=14)]),
Entity(spans=[Span(start=15, end=23)]),
Entity(spans=[Span(start=23, end=24)]),
]
doc.annotate_entity(field_name="tokens", entities=tokens)
# A single sentence entity covering the whole 24-character string; "sents"
# is the field handed to preprocess() as the context below.
sents = [Entity(spans=[Span(start=0, end=24)])]
doc.annotate_entity(field_name="sents", entities=sents)

batches = self.predictor.preprocess(doc=doc, context_name="sents")
# The first batch must be a BIOBatch, and decoding its input_ids with the
# predictor's tokenizer must yield exactly one string wrapped in
# [CLS] ... [SEP].
self.assertIsInstance(batches[0], BIOBatch)
decoded_batch = self.predictor.tokenizer.batch_decode(batches[0].input_ids)
# NOTE(review): the expected text is lower-cased, presumably because the
# tokenizer loaded in setUp is uncased -- confirm against the weights used.
self.assertListEqual(decoded_batch, ["[CLS] this is a test document. [SEP]"])

def test_predict_pages_tokens(self):
predictor = HFBIOTaggerPredictor.from_pretrained(
model_name_or_path=TEST_SCIBERT_WEIGHTS,
Expand All @@ -48,9 +70,6 @@ def test_predict_pages_tokens(self):
**{"num_labels": len(self.id2label), "id2label": self.id2label, "label2id": self.label2id},
)
token_tags = predictor.predict(doc=self.doc)
# import pytest

# pytest.set_trace()
assert len(token_tags) == 340

self.doc.annotate_entity(field_name="token_tags", entities=token_tags)
Expand Down

0 comments on commit d61cc18

Please sign in to comment.