Merge pull request #47 from allenai/kylel/merge
dedupe two implementations of same merging function; add tests
kyleclo authored Sep 11, 2023
2 parents 1dc71bb + b347e07 commit 7652fd5
Showing 8 changed files with 259 additions and 219 deletions.
127 changes: 8 additions & 119 deletions papermage/magelib/span.py
@@ -7,13 +7,12 @@
"""

from typing import List, Dict, List

from collections import defaultdict
from typing import Dict, List


class Span:
    __slots__ = ['start', 'end']
    __slots__ = ["start", "end"]

    def __init__(self, start: int, end: int):
        self.start = start
@@ -29,14 +28,14 @@ def from_json(cls, span_json: List) -> "Span":
        return Span(start=span_json[0], end=span_json[-1])

    def __repr__(self):
        return f'Span{self.to_json()}'
        return f"Span{self.to_json()}"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Span):
            return False
        return self.start == other.start and self.end == other.end

    def __lt__(self, other: 'Span'):
    def __lt__(self, other: "Span"):
        """Useful for sort(). Orders according to the start index.
        If ties, then order according to the end index."""
        if self.start == other.start:
@@ -46,19 +45,15 @@ def __lt__(self, other: 'Span'):
    def __hash__(self) -> int:
        return hash((self.start, self.end))

    def is_overlap(self, other: 'Span') -> bool:
    def is_overlap(self, other: "Span") -> bool:
        """Whether self overlaps with the other Span object."""
        return (
            self.start <= other.start < self.end
            or other.start <= self.start < other.end
            or self == other
        )
        return self.start <= other.start < self.end or other.start <= self.start < other.end or self == other

    @classmethod
    def create_enclosing_span(cls, spans: List['Span']) -> 'Span':
    def create_enclosing_span(cls, spans: List["Span"]) -> "Span":
        """Create the narrowest Span that completely encloses all the input Spans."""
        if not spans:
            raise ValueError(f'`spans` should be non-empty.')
            raise ValueError(f"`spans` should be non-empty.")
        start = spans[0].start
        end = spans[0].end
        for span in spans[1:]:
@@ -67,109 +62,3 @@ def create_enclosing_span(cls, spans: List['Span']) -> 'Span':
            if span.end > end:
                end = span.end
        return Span(start=start, end=end)


class MergeClusterSpans:
    """
    Merge neighboring spans which are index distance apart
    Inspired by https://leetcode.com/problems/merge-intervals/
    Originally @egork, Revised @kylel
    """

    def __init__(
        self,
        spans: List[Span],
        index_distance: int = 1
    ) -> None:
        """
        Args
            index_distance (int): Distance between the spans
        """
        self._spans = spans
        self._index_distance = index_distance
        self._graph = self._build_graph(spans=spans, index_distance=index_distance)
        self._clusters = None

    @property
    def spans(self) -> List[Span]:
        return self._spans

    @property
    def index_distance(self) -> int:
        return self._index_distance

    @index_distance.setter
    def index_distance(self, d: int):
        """If modify this distance, everything that's been computed before
        should be recomputed."""
        self._index_distance = d
        self._graph = self._build_graph(spans=self.spans, index_distance=d)
        if self._clusters:
            self._clusters = self._cluster(spans=self.spans)

    @property
    def clusters(self) -> List[List[Span]]:
        if not self._clusters:
            self._clusters = self._cluster(spans=self.spans)
        return self._clusters

    @staticmethod
    def _is_neighboring_spans(span1: Span, span2: Span, index_distance: int) -> bool:
        """Whether two spans are considered neighboring"""
        return min(
            abs(span1.start - span2.end), abs(span1.end - span2.start)
        ) <= index_distance

    def _build_graph(self, spans: List[Span], index_distance: int) -> Dict[int, List[int]]:
        """
        Build graph, each node is the position within the input list of spans.
        Spans are considered overlapping if they are index_distance apart
        """
        graph = defaultdict(list)
        for i, span_i in enumerate(spans):
            for j in range(i + 1, len(spans)):
                if self._is_neighboring_spans(span1=span_i, span2=spans[j], index_distance=index_distance):
                    graph[i].append(j)
                    graph[j].append(i)
        return graph

    def _cluster(self, spans: List[Span]) -> List[List[Span]]:
        """Cluster nodes (i.e. spans) by finding connected components"""
        if len(spans) == 0:
            return [[]]

        visited = set()
        num_components = 0
        component_id_to_members = defaultdict(list)

        def _dfs(start: int):
            stack = [start]
            while stack:
                pos = stack.pop()
                if pos not in visited:
                    visited.add(pos)
                    component_id_to_members[num_components].append(pos)
                    stack.extend(self._graph[pos])

        # mark all nodes in the same connected component with the same integer.
        for i, span in enumerate(spans):
            if i not in visited:
                _dfs(start=i)
                num_components += 1

        return [
            [spans[member_id] for member_id in sorted(component_id_to_members[n])]
            for n in range(num_components)
        ]

    def merge(self) -> List[Span]:
        """
        For each of the lists of the connected nodes, merge into bigger Spans
        """
        merged_spans = []
        for cluster in self.clusters:
            if cluster:
                merged_span = Span.create_enclosing_span(spans=cluster)
                merged_spans.append(merged_span)
        return merged_spans
14 changes: 7 additions & 7 deletions papermage/parsers/grobid_parser.py
@@ -25,8 +25,8 @@
    TokensFieldName,
)
from papermage.magelib.box import Box
from papermage.magelib.span import MergeClusterSpans
from papermage.parsers.parser import Parser
from papermage.utils.merge import cluster_and_merge_neighbor_spans

REQUIRED_DOCUMENT_FIELDS = [PagesFieldName, RowsFieldName, TokensFieldName]
NS = {"tei": "http://www.tei-c.org/ns/1.0"}
@@ -115,10 +115,7 @@ def __init__(self, check_server: bool = True, **grobid_config: Any):
        os.remove(config_path)

    def parse(  # type: ignore
        self,
        input_pdf_path: str,
        doc: Document,
        xml_out_dir: Optional[str] = None
        self, input_pdf_path: str, doc: Document, xml_out_dir: Optional[str] = None
    ) -> Document:
        assert doc.symbols != ""
        for field in REQUIRED_DOCUMENT_FIELDS:
@@ -157,8 +154,11 @@ def parse(  # type: ignore

    def _make_spans_from_boxes(self, doc: Document, entity: Entity) -> List[Span]:
        tokens = [cast(Entity, t) for match in doc.find_by_box(entity, "tokens") for t in match.tokens]
        spans = MergeClusterSpans(sorted(set(s for t in tokens for s in t.spans), key=lambda x: x.start)).merge()
        return spans
        results = cluster_and_merge_neighbor_spans(
            spans=sorted(set(s for t in tokens for s in t.spans), key=lambda x: x.start)
        )
        merged_spans = results.merged
        return merged_spans

    def _make_spans_from_boxes_if_not_found(self, doc: Document, entity: Entity) -> List[Span]:
        spans = [Span(start=s.start, end=s.end) for s in entity.spans]
52 changes: 4 additions & 48 deletions papermage/predictors/spacy_predictors/sentence_predictor.py
@@ -13,60 +13,16 @@
import pysbd

from papermage.magelib import (
    Annotation,
    Document,
    Entity,
    PagesFieldName,
    Span,
    TokensFieldName,
    WordsFieldName,
    Annotation
)
from papermage.predictors.base_predictor import BasePredictor


def merge_neighbor_spans(spans: List[Span], distance=1) -> List[Span]:
    """Merge neighboring spans in a list of un-overlapped spans:
    when the gaps between neighboring spans is not larger than the
    specified distance, they are considered as the neighbors.
    Args:
        spans (List[Span]): The input list of spans.
        distance (int, optional):
            The upper bound of interval gaps between two neighboring spans.
            Defaults to 1.
    Returns:
        List[Span]: A list of merged spans
    """

    is_neighboring_spans = (
        lambda span1, span2: min(abs(span1.start - span2.end), abs(span1.end - span2.start)) <= distance
    )

    # It assumes non-overlapped intervals within the list
    def merge_neighboring_spans(span1, span2):
        return Span(min(span1.start, span2.start), max(span1.end, span2.end))

    spans = sorted(spans, key=lambda ele: ele.start)
    # When sorted, only one iteration round is needed.

    if len(spans) == 0:
        return []
    if len(spans) == 1:
        return spans

    cur_merged_spans = [spans[0]]

    for cur_span in spans[1:]:
        prev_span = cur_merged_spans.pop()
        if is_neighboring_spans(cur_span, prev_span):
            cur_merged_spans.append(merge_neighboring_spans(prev_span, cur_span))
        else:
            # In this case, the prev_span should be moved to the
            # bottom of the stack
            cur_merged_spans.extend([prev_span, cur_span])

    return cur_merged_spans
from papermage.utils.merge import cluster_and_merge_neighbor_spans


class PysbdSentencePredictor(BasePredictor):
@@ -142,7 +98,7 @@ def _predict(self, doc: Document) -> List[Annotation]:
            cur_spans = getattr(doc, attr_name)[start:end]

            all_token_spans = list(itertools.chain.from_iterable([ele.spans for ele in cur_spans]))

            sentence_spans.append(Entity(spans=merge_neighbor_spans(all_token_spans)))
            results = cluster_and_merge_neighbor_spans(all_token_spans)
            sentence_spans.append(Entity(spans=results.merged))

        return sentence_spans
52 changes: 52 additions & 0 deletions papermage/utils/README.md
@@ -0,0 +1,52 @@
# utilities

Note from @kylel: Honestly, I kind of hate this; feels like some of these utilities should be baked into `magelib` directly, but not sure where. Will get to it eventually. For now, this README helps keep things a bit more organized.


### merge

Methods that make it easier to combine Entities, which is a fairly common operation when defining new Predictors. For example, many of the Spans of larger Entities like sentences are derived from Token spans merged together.
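
A minimal usage sketch, assuming only what the call sites in `grobid_parser.py` and `sentence_predictor.py` show: `cluster_and_merge_neighbor_spans` accepts a list of `Span`s via `spans=` and returns a result object whose `merged` attribute holds the merged `Span`s. The example spans, the default neighbor distance of 1 (carried over from the two original implementations), and the printed output are illustrative, not taken from the diff:

```
from papermage.magelib import Span
from papermage.utils.merge import cluster_and_merge_neighbor_spans

# token-level spans: the first two are only 1 index apart, the third sits far away
token_spans = [Span(start=0, end=4), Span(start=5, end=9), Span(start=20, end=25)]

# neighboring spans are clustered, then each cluster is merged into one enclosing Span
results = cluster_and_merge_neighbor_spans(spans=token_spans)
print(results.merged)  # expected: something like [Span[0, 9], Span[20, 25]]
```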


#### Sep 2023 - Deduping implementations

We had two competing implementations of merging spans. This is just some documentation to help keep track of why we dropped one vs. the other.

Benchmarking:

```
import random
import time

from papermage.magelib import Span
from papermage.utils.merge import *

# generate random spans, some overlapping some disjoint
many_spans = []
start = 0
for _ in range(10000):
    increment_start = random.choices(population=[-1, 0, 1], weights=[0.2, 0.3, 0.5])
    is_save = random.choices(population=[False, True], weights=[0.8, 0.2])
    if increment_start:
        start += increment_start[0]
    if is_save:
        end = start + random.choices(population=[1, 2, 3, 4, 5],
                                     weights=[0.2, 0.2, 0.2, 0.2, 0.2])[0]
        new_span = Span(start=start, end=end)
        many_spans.append(new_span)
        start = end

# Elapsed: 10.210576057434082
start = time.time()
mcs = MergeClusterSpans(spans=many_spans)
mcs.merge()
end = time.time()
print(f"Elapsed: {end - start}")

# Elapsed: 0.004858255386352539
start = time.time()
results = merge_neighbor_spans(spans=many_spans)
end = time.time()
print(f"Elapsed: {end - start}")
```

MergeClusterSpans is super inefficient: its `_build_graph` compares every pair of spans (quadratic in the number of spans), while `merge_neighbor_spans` just sorts once and merges in a single pass. Let's kill it.
