Skip to content

Commit

Permalink
Merge branch 'main' into python311-support
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleclo authored Sep 6, 2023
2 parents 311e240 + 8dcb080 commit 3922ee6
Showing 1 changed file with 109 additions and 0 deletions.
109 changes: 109 additions & 0 deletions tests/test_predictors/test_bio_tagger_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""
@kylel, benjaminn
"""

import json
import pathlib
import unittest

import transformers

from papermage.magelib import Document, Entity, Span
from papermage.parsers import PDFPlumberParser
from papermage.predictors.hf_predictors.bio_tagger_predictor import HFBIOTaggerPredictor, BIOPrediction

TEST_SCIBERT_WEIGHTS = "allenai/scibert_scivocab_uncased"


class TestBioTaggerPredictor(unittest.TestCase):
def setUp(self):
transformers.set_seed(407)
self.fixture_path = pathlib.Path(__file__).parent.parent / "fixtures"
with open(self.fixture_path / "entity_classification_predictor_test_doc_papermage.json", "r") as f:
test_doc_json = json.load(f)
self.doc = Document.from_json(doc_json=test_doc_json)
ent1 = Entity(spans=[Span(start=86, end=456)])
ent2 = Entity(spans=[Span(start=457, end=641)])
self.doc.annotate_entity(field_name="bibs", entities=[ent1, ent2])

# self.predictor = HFBIOTaggerPredictor.from_pretrained(
# model_name_or_path=TEST_SCIBERT_WEIGHTS, entity_name="tokens", context_name="pages"
# )
self.id2label = {0: "O", 1: "B_Label", 2: "I_Label"}
self.label2id = {label: id_ for id_, label in self.id2label.items()}
self.predictor = HFBIOTaggerPredictor.from_pretrained(
model_name_or_path=TEST_SCIBERT_WEIGHTS,
entity_name="tokens",
context_name="pages",
**{"num_labels": len(self.id2label), "id2label": self.id2label, "label2id": self.label2id},
)

def test_predict_pages_tokens(self):
predictor = HFBIOTaggerPredictor.from_pretrained(
model_name_or_path=TEST_SCIBERT_WEIGHTS,
entity_name="tokens",
context_name="pages",
**{"num_labels": len(self.id2label), "id2label": self.id2label, "label2id": self.label2id},
)
token_tags = predictor.predict(doc=self.doc)
# import pytest

# pytest.set_trace()
assert len(token_tags) == 340

self.doc.annotate_entity(field_name="token_tags", entities=token_tags)
for token_tag in token_tags:
assert isinstance(token_tag.metadata.label, str)
assert isinstance(token_tag.metadata.score, float)

def test_predict_bibs_tokens(self):
self.predictor.context_name = "bibs"
token_tags = self.predictor.predict(doc=self.doc)
assert len(token_tags) == 38

def test_missing_fields(self):
self.predictor.entity_name = "OHNO"
with self.assertRaises(AssertionError) as e:
self.predictor.predict(doc=self.doc)
assert "OHNO" in e.exception

self.predictor.entity_name = "tokens"
self.predictor.context_name = "BLABLA"
with self.assertRaises(AssertionError) as e:
self.predictor.predict(doc=self.doc)
assert "BLABLA" in e.exception

self.predictor.context_name = "pages"

def test_predict_pages_tokens_roberta(self):
predictor = HFBIOTaggerPredictor.from_pretrained(
model_name_or_path="roberta-base",
entity_name="tokens",
context_name="pages",
add_prefix_space=True, # Needed for roberta
**{"num_labels": len(self.id2label), "id2label": self.id2label, "label2id": self.label2id},
)
token_tags = predictor.predict(doc=self.doc)
assert len(token_tags) == 924

self.doc.annotate_entity(field_name="token_tags", entities=token_tags)
for token_tag in token_tags:
assert isinstance(token_tag.metadata.label, str)
assert isinstance(token_tag.metadata.score, float)

# def test_postprocess(self):
# self.predictor.postprocess(
# doc=self.doc,
# context_name="pages",
# preds=[
# BIOPrediction(context_id=0, entity_id=0, label="B-Label", score=0.4),
# BIOPrediction(context_id=0, entity_id=1, label="I-Label", score=0.2),
# BIOPrediction(context_id=0, entity_id=2, label="O", score=0.3),
# BIOPrediction(context_id=0, entity_id=3, label=None, score=None),
# BIOPrediction(context_id=0, entity_id=4, label="B-Label", score=0.4),
# BIOPrediction(context_id=0, entity_id=5, label=None, score=None),
# BIOPrediction(context_id=0, entity_id=6, label="I-Label", score=0.2),
# ],
# )

0 comments on commit 3922ee6

Please sign in to comment.