Merge pull request #39 from allenai/soldni/more_eval
Fixes for grobid eval
soldni authored Aug 7, 2023
2 parents 4b3187e + 10045f2 commit 1637601
Showing 3 changed files with 181 additions and 117 deletions.
145 changes: 77 additions & 68 deletions papermage/parsers/grobid_parser.py
@@ -1,21 +1,21 @@
 """
-@geli-gel
+@geli-gel, @amanpreet692, @soldni
 """
 import json
 import os
 import re
 import warnings
 import xml.etree.ElementTree as et
 from collections import defaultdict
 from tempfile import NamedTemporaryFile
-from typing import Any, Dict, List, Optional, cast
+from typing import Any, Dict, List, Optional, Tuple, cast
 
 import numpy as np
 from grobid_client.grobid_client import GrobidClient
 
 from papermage.magelib import (
     Annotation,
     Document,
     Entity,
     Metadata,
@@ -67,9 +67,9 @@
 }
 
 
-def find_zero_spans(array):
-    # Add a sentinel value at the end to not miss end of last span
-    array = np.append(array, 1)
+def find_contiguous_ones(array):
+    # Add a sentinel value at the beginning/end
+    array = np.concatenate([[0], array, [0]])
 
     # Find the indexes where the array changes
     diff_indices = np.where(np.diff(array) != 0)[0] + 1
@@ -80,7 +80,7 @@ def find_zero_spans(array):
     spans = list(zip(zero_start_indices, zero_end_indices))
 
     # Exclude the spans with no element
-    spans = [(start, end + 1) for start, end in spans if end - start >= 0]
+    spans = [(start - 1, end) for start, end in spans if end - start >= 0]
 
     return spans
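For intuition, here is what the renamed helper is meant to guarantee: given a 0/1 (or boolean) array, it returns half-open (start, end) index pairs covering each maximal run of ones, so callers can turn unreserved stretches directly into spans. The middle of the function is elided in this hunk, so the sketch below is a hypothetical standalone equivalent of that contract, not the committed code:

    import numpy as np

    def runs_of_ones(mask):
        # pad with zeros so runs touching either edge still produce transitions
        padded = np.concatenate([[0], np.asarray(mask).astype(int), [0]])
        diff = np.diff(padded)
        starts = np.where(diff == 1)[0]   # index where each run of ones begins
        ends = np.where(diff == -1)[0]    # one past the last index of each run
        return list(zip(starts, ends))

    assert runs_of_ones([0, 1, 1, 0, 1]) == [(1, 3), (4, 5)]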

@@ -95,13 +95,14 @@ class GrobidFullParser(Parser):
     > docker run -t --rm -p 8070:8070 lfoppiano/grobid:0.7.2
     """
 
-    def __init__(self, grobid_config: Optional[dict] = None, check_server: bool = True):
-        self.grobid_config = grobid_config or {
+    def __init__(self, check_server: bool = True, **grobid_config: Any):
+        self.grobid_config = {
             "grobid_server": "http://localhost:8070",
             "batch_size": 1000,
             "sleep_time": 5,
             "timeout": 60,
             "coordinates": sorted(set((*GROBID_VILA_MAP.keys(), "s", "ref", "body", "item", "persName"))),
+            **grobid_config,
         }
         assert "coordinates" in self.grobid_config, "Grobid config must contain 'coordinates' key"

@@ -113,14 +114,19 @@ def __init__(self, grobid_config: Optional[dict] = None, check_server: bool = True):
 
         os.remove(config_path)
 
-    def parse(self, input_pdf_path: str, doc: Document, xml_out_dir: Optional[str] = None) -> Document:
+    def parse(  # type: ignore
+        self,
+        input_pdf_path: str,
+        doc: Document,
+        xml_out_dir: Optional[str] = None
+    ) -> Document:
         assert doc.symbols != ""
         for field in REQUIRED_DOCUMENT_FIELDS:
             assert field in doc.fields
 
         (_, _, xml) = self.client.process_pdf(
-            "processFulltextDocument",
-            input_pdf_path,
+            service="processFulltextDocument",
+            pdf_file=input_pdf_path,
             generateIDs=False,
             consolidate_header=False,
             consolidate_citations=False,
@@ -137,16 +143,16 @@ def parse(self, input_pdf_path: str, doc: Document, xml_out_dir: Optional[str] = None) -> Document:
             with open(xmlfile, "w") as f_out:
                 f_out.write(xml)
 
-        self._parse_xml_onto_doc(xml, doc)
+        self._parse_xml_onto_doc(xml=xml, doc=doc)
 
-        for p in doc.p:
+        for p in getattr(doc, "p", []):
             grobid_text_elems = [s.metadata["grobid_text"] for s in p.s]
             grobid_text = " ".join(filter(lambda text: isinstance(text, str), grobid_text_elems))
             p.metadata["grobid_text"] = grobid_text
 
-        # add vila-like entities
-        doc.annotate_entity(entities=self._make_vila_groups(doc), field_name="vila_entities")
+        vila_entities = self._make_vila_groups(doc)
+        doc.annotate_entity(entities=vila_entities, field_name="vila_entities")
 
         return doc
 
     def _make_spans_from_boxes(self, doc: Document, entity: Entity) -> List[Span]:
@@ -173,18 +179,20 @@ def _make_entities_of_type(
         ]
         return entities
 
-    def _update_reserved_positions(self, reserved_positions: np.ndarray, entities: List[Entity]) -> List[Entity]:
+    def _update_reserved_positions(
+        self, reserved_positions: np.ndarray, entities: List[Entity]
+    ) -> Tuple[List[Entity], np.ndarray]:
         new_entities: List[Entity] = []
         for ent in entities:
             new_spans = []
             for span in ent.spans:
                 already_reserved = reserved_positions[span.start : span.end]
-                for start, end in find_zero_spans(already_reserved):
-                    new_spans.append(Span(start=start, end=end))
+                for start, end in find_contiguous_ones(~already_reserved):
+                    new_spans.append(Span(start=start + span.start, end=end + span.start))
                 reserved_positions[span.start : span.end] = True
             if new_spans:
                 new_entities.append(Entity(spans=new_spans, boxes=ent.boxes, metadata=ent.metadata))
-        return new_entities
+        return new_entities, reserved_positions
 
     def _make_vila_groups(self, doc: Document) -> List[Entity]:
         ents: List[Entity] = []
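The reservation logic is first-come-first-served over character positions: earlier entity types claim positions, and later ones keep only the still-free stretches, now correctly shifted back by span.start. A toy illustration using the runs_of_ones sketch from earlier (hypothetical values; papermage types elided):

    import numpy as np

    reserved = np.zeros(10, dtype=bool)
    reserved[3:5] = True                    # positions 3-4 already claimed
    free = runs_of_ones(~reserved[2:8])     # an entity span covering positions 2-7
    # free == [(0, 1), (3, 6)]; adding the span offset 2 back gives the
    # absolute spans (2, 3) and (5, 8), which is the offset fix in this hunk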
@@ -196,106 +204,104 @@ def _make_vila_groups(self, doc: Document) -> List[Entity]:
 
         if h := getattr(doc, "author", []):
             h_ = self._make_entities_of_type(doc=doc, entities=h, entity_type="Author", id_offset=len(ents))
-            h_ = self._update_reserved_positions(reserved_positions, h_)
-            ents.extend(h_)
+            h__, reserved_positions = self._update_reserved_positions(reserved_positions, h_)
+            ents.extend(h__)
 
         if a := getattr(doc, "abstract", []):
             a_ = self._make_entities_of_type(doc=doc, entities=a, entity_type="Abstract", id_offset=len(ents))
-            a_ = self._update_reserved_positions(reserved_positions, a_)
-            ents.extend(a_)
+            a__, reserved_positions = self._update_reserved_positions(reserved_positions, a_)
+            ents.extend(a__)
 
         if _ := getattr(doc, "keywords", []):
             # keywords has no coordinates, so we can't recover their positions!
             pass
 
         if s := getattr(doc, "head", []):
             s_ = self._make_entities_of_type(doc=doc, entities=s, entity_type="Section", id_offset=len(ents))
-            s_ = self._update_reserved_positions(reserved_positions, s_)
-            ents.extend(s_)
+            s__, reserved_positions = self._update_reserved_positions(reserved_positions, s_)
+            ents.extend(s__)
 
-        if l := getattr(doc, "list", []):
-            l_ = self._make_entities_of_type(doc=doc, entities=l, entity_type="List", id_offset=len(ents))
-            l_ = self._update_reserved_positions(reserved_positions, l_)
-            ents.extend(l_)
+        if t := getattr(doc, "list", []):
+            t_ = self._make_entities_of_type(doc=doc, entities=t, entity_type="List", id_offset=len(ents))
+            t__, reserved_positions = self._update_reserved_positions(reserved_positions, t_)
+            ents.extend(t__)
 
         if b := getattr(doc, "biblStruct", []):
             b_ = self._make_entities_of_type(doc=doc, entities=b, entity_type="Bibliography", id_offset=len(ents))
-            b_ = self._update_reserved_positions(reserved_positions, b_)
-            ents.extend(b_)
+            b__, reserved_positions = self._update_reserved_positions(reserved_positions, b_)
+            ents.extend(b__)
 
         if e := getattr(doc, "formula", []):
             e_ = self._make_entities_of_type(doc=doc, entities=e, entity_type="Equation", id_offset=len(ents))
-            e_ = self._update_reserved_positions(reserved_positions, e_)
-            ents.extend(e_)
+            e__, reserved_positions = self._update_reserved_positions(reserved_positions, e_)
+            ents.extend(e__)
 
-        if figures := getattr(doc, "figure", []):
-            for figure in figures:
-                current_boxes = [Box(l=b.l, t=b.t, w=b.w, h=b.h, page=b.page) for b in figure.boxes]
+        if figs := getattr(doc, "figure", []):
+            for fig in figs:
+                current_boxes = [Box(l=b.l, t=b.t, w=b.w, h=b.h, page=b.page) for b in fig.boxes]
 
                 if "figDesc" in doc.fields:
-                    caption_boxes = [b for d in doc.find_by_box(figure, "figDesc") for b in d.boxes]
+                    caption_boxes = [b for d in doc.find_by_box(fig, "figDesc") for b in d.boxes]
                     current_boxes = [b for b in current_boxes if b not in caption_boxes]
 
                 if "table" in doc.fields:
-                    table_boxes = [b for d in doc.find_by_box(figure, "table") for b in d.boxes]
+                    table_boxes = [b for d in doc.find_by_box(fig, "table") for b in d.boxes]
                     current_boxes = [b for b in current_boxes if b not in table_boxes]
 
                 if not current_boxes:
                     continue
 
-                new_figure = Entity(
-                    spans=None,
+                new_fig = Entity(
+                    spans=self._make_spans_from_boxes(doc, Entity(boxes=current_boxes)),
                     boxes=current_boxes,
-                    metadata=Metadata(**figure.metadata.to_json(), type="Figure", id=len(ents)),
+                    metadata=Metadata(**fig.metadata.to_json(), label="Figure", id=len(ents)),
                 )
-                ents.append(new_figure)
+                new_figs, reserved_positions = self._update_reserved_positions(reserved_positions, [new_fig])
+                ents.extend(new_figs)
 
         if t := getattr(doc, "table", []):
             t_ = self._make_entities_of_type(doc=doc, entities=t, entity_type="Table", id_offset=len(ents))
-            t_ = self._update_reserved_positions(reserved_positions, t_)
-            ents.extend(t_)
+            t__, reserved_positions = self._update_reserved_positions(reserved_positions, t_)
+            ents.extend(t__)
 
         if c := getattr(doc, "figDesc", []):
             c_ = self._make_entities_of_type(doc=doc, entities=c, entity_type="Caption", id_offset=len(ents))
-            c_ = self._update_reserved_positions(reserved_positions, c_)
-            ents.extend(c_)
+            c__, reserved_positions = self._update_reserved_positions(reserved_positions, c_)
+            ents.extend(c__)
 
         if _ := getattr(doc, "note", []):
             # notes have no coordinates, so we can't recover their positions!
             pass
 
         if p := getattr(doc, "p", []):
             p_ = self._make_entities_of_type(doc=doc, entities=p, entity_type="Paragraph", id_offset=len(ents))
-            p_ = self._update_reserved_positions(reserved_positions, p_)
-            ents.extend(p_)
+            p__, reserved_positions = self._update_reserved_positions(reserved_positions, p_)
+            ents.extend(p__)
 
         return ents
 
     def _parse_xml_onto_doc(self, xml: str, doc: Document) -> Document:
-        xml_root = et.fromstring(xml)
+        try:
+            xml_root = et.fromstring(xml)
+        except Exception as e:
+            if xml == "[GENERAL] An exception occurred while running Grobid.":
+                warnings.warn("Grobid returned an error; check server logs")
+                return doc
+            raise e
 
         all_box_groups = self._get_box_groups(xml_root)
         for field, box_groups in all_box_groups.items():
-            # span_groups = box_groups_to_span_groups(
-            #     box_groups=box_groups, doc=doc, center=True
-            # )
-            # assert len(box_groups) == len(span_groups), (
-            #     f"Annotations and SpanGroups for {field} are not the same length"
-            # )
-            # for bg, sg in zip(box_groups, span_groups):
-            #     sg.metadata = bg.metadata
-            #
-            # # note for if/when adding in relations between mention sources and
-            # # bib targets: big_entries metadata contains original grobid id
-            # # attached to the Annotation.
             doc.annotate_entity(field_name=field, entities=box_groups)
 
         return doc
 
-    def _xml_coords_to_boxes(self, coords_attribute: str, page_sizes: dict):
+    def _xml_coords_to_boxes(self, coords_attribute: str, page_sizes: dict) -> List[Box]:
         coords_list = coords_attribute.split(";")
         boxes = []
         for coords in coords_list:
+            if coords == "":
+                # this page has no coordinates
+                continue
             pg, x, y, w, h = coords.split(",")
             proper_page = int(pg) - 1
             boxes.append(
@@ -305,7 +311,7 @@ def _xml_coords_to_boxes(self, coords_attribute: str, page_sizes: dict):
             )
         return boxes
 
-    def _get_box_groups(self, root: et.Element) -> Dict[str, List[Annotation]]:
+    def _get_box_groups(self, root: et.Element) -> Dict[str, List[Entity]]:
         page_size_root = root.find(".//tei:facsimile", NS)
         assert page_size_root is not None, "No facsimile found in Grobid XML"
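_xml_coords_to_boxes, closed out just above, consumes Grobid's TEI coordinate attributes, which (per the code) are semicolon-separated page,x,y,w,h tuples with 1-indexed pages. A minimal standalone sketch of that parse, with papermage Box construction and page-size scaling elided and a made-up coords string:

    coords = "2,53.78,60.62,180.11,10.01;2,53.78,72.00,120.45,10.01"
    for part in coords.split(";"):
        if part == "":        # a page with no coordinates contributes nothing
            continue
        pg, x, y, w, h = part.split(",")
        page = int(pg) - 1    # Grobid pages are 1-indexed, papermage's are 0-indexed
        print(page, float(x), float(y), float(w), float(h))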

@@ -314,7 +320,7 @@ def _get_box_groups(self, root: et.Element) -> Dict[str, List[Annotation]]:
         for data in page_size_data:
             page_sizes[int(data.attrib["n"]) - 1] = [float(data.attrib["lrx"]), float(data.attrib["lry"])]
 
-        all_boxes: Dict[str, List[Annotation]] = defaultdict(list)
+        all_boxes: Dict[str, List[Entity]] = defaultdict(list)
 
         for field in self.grobid_config["coordinates"]:
             structs = root.findall(f".//tei:{field}", NS)
@@ -326,7 +332,11 @@ def _get_box_groups(self, root: et.Element) -> Dict[str, List[Annotation]]:
                 if coords_str == "":
                     continue
 
-                boxes = self._xml_coords_to_boxes(coords_str, page_sizes)
+                if not (boxes := self._xml_coords_to_boxes(coords_str, page_sizes)):
+                    # we check if the boxes are empty because sometimes users monkey-patch
+                    # _xml_coords_to_boxes to filter by page
+                    continue
+
                 metadata_dict: Dict[str, Any] = {
                     f"grobid_{re.sub(r'[^a-zA-Z0-9_]+', '_', k)}": v
                     for k, v in struct.attrib.items()
@@ -336,7 +346,6 @@ def _get_box_groups(self, root: et.Element) -> Dict[str, List[Annotation]]:
                 metadata_dict["grobid_text"] = struct.text
                 metadata = Metadata.from_json(metadata_dict)
                 box_group = Entity(boxes=boxes, metadata=metadata)
-                print(box_group)
                 all_boxes[field].append(box_group)
 
         return all_boxes
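The new walrus-guarded continue above exists because _xml_coords_to_boxes may legitimately come back empty; the in-code comment mentions users monkey-patching it to filter by page. A hypothetical sketch of that pattern (names assumed, not part of the commit):

    # hypothetical: keep only boxes from the first page
    _orig_coords_to_boxes = GrobidFullParser._xml_coords_to_boxes

    def _first_page_only(self, coords_attribute, page_sizes):
        boxes = _orig_coords_to_boxes(self, coords_attribute, page_sizes)
        return [b for b in boxes if b.page == 0]

    GrobidFullParser._xml_coords_to_boxes = _first_page_only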
@@ -354,6 +363,6 @@ def _get_box_groups(self, root: et.Element) -> Dict[str, List[Annotation]]:
     doc = PDFPlumberParser().parse(opts.pdf_path)
     doc = GrobidFullParser().parse(opts.pdf_path, doc)
 
-    for p in doc.p:
+    for p in getattr(doc, "p", []):
         for s in p.s:
             print(s.metadata.grobid_text)
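Putting it together, the __main__ block above shows the intended pipeline: a PDFPlumber pass first (to build symbols and tokens), then the Grobid pass layered on top. A condensed, hypothetical library-level sketch of the same flow (import path and file name assumed):

    from papermage.parsers import PDFPlumberParser

    doc = PDFPlumberParser().parse("paper.pdf")       # builds symbols/tokens
    doc = GrobidFullParser().parse("paper.pdf", doc)  # layers Grobid fields on top
    for p in getattr(doc, "p", []):                   # paragraphs, if Grobid found any
        print(p.metadata["grobid_text"])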
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = 'papermage'
-version = '0.10.0'
+version = '0.11.0'
 description = 'Papermage. Casting magic over scientific PDFs.'
 license = {text = 'Apache-2.0'}
 readme = 'README.md'