Skip to content

Commit

Permalink
rename layer stuff (#53)
Browse files Browse the repository at this point in the history
* rename layer stuff

* bugfix

* woah; big change; basically migrate all crossreferencing operation in Doc and Entity to Layer

* increment version
  • Loading branch information
kyleclo authored Oct 14, 2023
1 parent 309f7c9 commit 3781ae0
Show file tree
Hide file tree
Showing 23 changed files with 329 additions and 221 deletions.
138 changes: 78 additions & 60 deletions papermage/magelib/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,12 @@
from itertools import chain
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

from .span import Span
from .box import Box
from .entity import Entity
from .image import Image
from .layer import Layer
from .metadata import Metadata
from .entity import Entity
from .indexer import EntitySpanIndexer, EntityBoxIndexer

from .span import Span

# document field names
SymbolsFieldName = "symbols"
Expand Down Expand Up @@ -55,74 +54,78 @@ class Prediction(NamedTuple):


class Document:
SPECIAL_FIELDS = [SymbolsFieldName, ImagesFieldName, MetadataFieldName, EntitiesFieldName, RelationsFieldName]

def __init__(self, symbols: str, metadata: Optional[Metadata] = None):
self.symbols = symbols
tokens: Layer
rows: Layer
blocks: Layer
words: Layer
sentences: Layer
paragraphs: Layer
pages: Layer

SPECIAL_FIELDS = [SymbolsFieldName, ImagesFieldName, MetadataFieldName]

def __init__(
self,
symbols: Optional[str] = None,
images: Optional[List[Image]] = None,
metadata: Optional[Metadata] = None,
):
self.symbols = symbols if symbols else None
self.images = images if images else None
if not self.symbols and not self.images:
raise ValueError("Document must have at least one of `symbols` or `images`")
self.metadata = metadata if metadata else Metadata()
self.__entity_span_indexers: Dict[str, EntitySpanIndexer] = {}
self.__entity_box_indexers: Dict[str, EntityBoxIndexer] = {}
self._layers: List[str] = []

@property
def fields(self) -> List[str]:
# Names of every annotated field (the keys of the span-indexer dict),
# followed by the reserved SPECIAL_FIELDS names.
return list(self.__entity_span_indexers.keys()) + self.SPECIAL_FIELDS

def find(self, query: Union[Span, Box], field_name: str) -> List[Entity]:
# Dispatches on the query type: a Span goes to the span indexer registered
# under `field_name`, a Box goes to the box indexer. The bare Span/Box is
# wrapped in a temporary Entity because the indexers' `find` is called with
# an Entity as `query`.
# Raises TypeError for any other query type. The indexers are plain dicts,
# so an unknown `field_name` surfaces as a KeyError from the lookup.
if isinstance(query, Span):
return self.__entity_span_indexers[field_name].find(query=Entity(spans=[query]))
elif isinstance(query, Box):
return self.__entity_box_indexers[field_name].find(query=Entity(boxes=[query]))
else:
raise TypeError(f"Unsupported query type {type(query)}")

def find_by_span(self, query: Entity, field_name: str) -> List[Entity]:
# Delegates to the span indexer registered under `field_name`, passing the
# Entity through unchanged as the query.
# TODO: will rename this to `intersect_by_span`
return self.__entity_span_indexers[field_name].find(query=query)
def layers(self) -> List[str]:
# Reserved SPECIAL_FIELDS names first, then the annotated layer names in the
# order they were appended to `_layers`.
# NOTE(review): other methods read this as an attribute (`self.layers`), so it
# is presumably decorated with @property -- confirm against the real file.
return self.SPECIAL_FIELDS + self._layers

def find_by_box(self, query: Entity, field_name: str) -> List[Entity]:
# Delegates to the box indexer registered under `field_name`, passing the
# Entity through unchanged as the query.
# TODO: will rename this to `intersect_by_box`
return self.__entity_box_indexers[field_name].find(query=query)
def validate_layer_name_availability(self, name: str) -> None:
    """Raise if `name` cannot be used for a new layer on this Document.

    Args:
        name: candidate layer name.

    Raises:
        AssertionError: if `name` is one of `SPECIAL_FIELDS`, is already an
            annotated layer, or collides with any existing attribute of the
            Document (as reported by `dir(self)`).
    """
    if name in self.SPECIAL_FIELDS:
        raise AssertionError(f"{name} not allowed Document.SPECIAL_FIELDS.")
    # Test the annotated layers only. `self.layers` also contains
    # SPECIAL_FIELDS, so checking it twice made this branch unreachable and
    # reported existing layers with the SPECIAL_FIELDS message instead of
    # the `remove_layer` hint.
    if name in self._layers:
        raise AssertionError(f'{name} already exists. Try `doc.remove_layer("{name}")` first.')
    if name in dir(self):
        raise AssertionError(f"{name} clashes with Document class properties.")

def check_field_name_availability(self, field_name: str) -> None:
# Guards annotation of a new field. Raises AssertionError when the name is
# reserved (SPECIAL_FIELDS), is already registered in the span-indexer dict,
# or matches any existing attribute reported by `dir(self)`.
# Returns None when the name is free.
if field_name in self.SPECIAL_FIELDS:
raise AssertionError(f"{field_name} not allowed Document.SPECIAL_FIELDS.")
if field_name in self.__entity_span_indexers.keys():
raise AssertionError(f'{field_name} already exists. Try `doc.remove_entity("{field_name}")` first.')
if field_name in dir(self):
raise AssertionError(f"{field_name} clashes with Document class properties.")

def get_entity(self, field_name: str) -> List[Entity]:
# Plain attribute lookup by name; an unknown `field_name` raises AttributeError.
return getattr(self, field_name)
def get_layer(self, name: str) -> Layer:
"""Gets a layer by name. For example, `doc.get_layer("sentences")` returns sentences."""
# Implemented as a plain attribute lookup, so any attribute name resolves
# (not only annotated layers) and an unknown name raises AttributeError.
return getattr(self, name)

def annotate(self, *predictions: Union[Prediction, Tuple[Prediction, ...]]) -> None:
"""Annotates the document with predictions."""
all_preds = chain.from_iterable([p] if isinstance(p, Prediction) else p for p in predictions)
for prediction in all_preds:
self.annotate_entity(field_name=prediction.name, entities=prediction.entities)

def annotate_entity(self, field_name: str, entities: List[Entity]) -> None:
self.check_field_name_availability(field_name=field_name)
self.annotate_layer(name=prediction.name, entities=prediction.entities)

for i, entity in enumerate(entities):
entity.doc = self
entity.id = i
def annotate_layer(self, name: str, entities: Union[List[Entity], Layer]) -> None:
self.validate_layer_name_availability(name=name)

self.__entity_span_indexers[field_name] = EntitySpanIndexer(entities=entities)
self.__entity_box_indexers[field_name] = EntityBoxIndexer(entities=entities)
setattr(self, field_name, entities)
if isinstance(entities, list):
layer = Layer(entities=entities)
else:
layer = entities

def remove_entity(self, field_name: str):
for entity in getattr(self, field_name):
entity.doc = None
layer.doc = self
layer.name = name
setattr(self, name, layer)
self._layers.append(name)

delattr(self, field_name)
del self.__entity_span_indexers[field_name]
del self.__entity_box_indexers[field_name]
def remove_layer(self, name: str):
    """Detach and delete the annotated layer called `name`.

    Clears the layer's `doc` and `name` back-references, deletes the
    attribute from the Document and drops `name` from `_layers`. Does
    nothing when `name` is not an annotated layer.

    Args:
        name: name of the layer to remove.
    """
    # Membership is tested against `_layers`, not `layers`: the latter also
    # lists SPECIAL_FIELDS (symbols, images, metadata), which are not Layer
    # objects and must not be detached or deleted here.
    if name not in self._layers:
        return
    layer = getattr(self, name)
    layer.doc = None
    layer.name = None
    delattr(self, name)
    self._layers.remove(name)

def get_relation(self, name: str) -> List["Relation"]:
# Placeholder: always raises NotImplementedError; `name` is unused.
raise NotImplementedError

def annotate_relation(self, name: str) -> None:
self.check_field_name_availability(field_name=name)
self.validate_layer_name_availability(name=name)
raise NotImplementedError

def remove_relation(self, name: str) -> None:
Expand All @@ -145,7 +148,7 @@ def annotate_images(self, images: List[Image]) -> None:
def remove_images(self) -> None:
# Placeholder: always raises NotImplementedError.
raise NotImplementedError

def to_json(self, field_names: Optional[List[str]] = None, with_images: bool = False) -> Dict:
def to_json(self, layers: Optional[List[str]] = None, with_images: bool = False) -> Dict:
"""Returns a dictionary that's suitable for serialization
Use `layers` to specify a subset of layers in the Document to include (e.g. 'sentences')
Expand All @@ -166,10 +169,10 @@ def to_json(self, field_names: Optional[List[str]] = None, with_images: bool = F
RelationsFieldName: {},
}

# 2) serialize each field to JSON
field_names = list(self.__entity_span_indexers.keys()) if field_names is None else field_names
for field_name in field_names:
doc_dict[EntitiesFieldName][field_name] = [entity.to_json() for entity in getattr(self, field_name)]
# 2) serialize each layer to JSON
layers = self._layers if layers is None else layers
for layer in layers:
doc_dict[EntitiesFieldName][layer] = [entity.to_json() for entity in getattr(self, layer)]

# 3) serialize images if `with_images == True`
if with_images:
Expand All @@ -186,6 +189,21 @@ def from_json(cls, doc_json: Dict) -> "Document":
# 2) instantiate entities
for field_name, entity_jsons in doc_json[EntitiesFieldName].items():
entities = [Entity.from_json(entity_json=entity_json) for entity_json in entity_jsons]
doc.annotate_entity(field_name=field_name, entities=entities)
doc.annotate_layer(name=field_name, entities=entities)

return doc

def __repr__(self):
# One-line summary: the length of `self.layers` followed by the list itself.
# `layers` includes the SPECIAL_FIELDS names, so they are part of the count.
return f"Document with {len(self.layers)} layers: {self.layers}"

def find(self, query: Union[Span, Box], name: str) -> List[Entity]:
"""Finds all entities in the layer called `name` that intersect with the Span or Box query"""
# Thin wrapper: looks the layer up with `get_layer` and delegates to its `find`.
return self.get_layer(name=name).find(query=query)

def intersect_by_span(self, query: Entity, name: str) -> List[Entity]:
"""Finds all entities in the layer called `name` that intersect with the query Entity, matching by span"""
# Thin wrapper: looks the layer up with `get_layer` and delegates to its
# `intersect_by_span`.
return self.get_layer(name=name).intersect_by_span(query=query)

def intersect_by_box(self, query: Entity, name: str) -> List[Entity]:
"""Finds all entities in the layer called `name` that intersect with the query Entity, matching by box"""
# Thin wrapper: looks the layer up with `get_layer` and delegates to its
# `intersect_by_box`.
return self.get_layer(name=name).intersect_by_box(query=query)
127 changes: 71 additions & 56 deletions papermage/magelib/entity.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
An annotated "unit" on a Document.
An annotated "unit" in a Layer.
"""

Expand All @@ -12,11 +12,12 @@
from .span import Span

if TYPE_CHECKING:
from .document import Document
from .document import TokensFieldName
from .layer import Layer


class Entity:
__slots__ = ["spans", "boxes", "images", "metadata", "_id", "_doc"]
__slots__ = ["spans", "boxes", "images", "metadata", "_id", "_layer"]

def __init__(
self,
Expand All @@ -27,15 +28,18 @@ def __init__(
):
if not spans and not boxes:
raise ValueError(f"At least one of `spans` or `boxes` must be set.")
self.spans = spans if spans else []
self.boxes = boxes if boxes else []
self.images = images if images else []
self.spans = spans if spans is not None else []
self.boxes = boxes if boxes is not None else []
self.images = images if images is not None else []
self.metadata = metadata if metadata else Metadata()
# TODO: it's confusing that `id` is both reading order as well as direct reference
# TODO: maybe Layer() should house reading order, and Entity() should have a unique ID
# TODO: hashing would be interesting, but Metadata() is allowed to mutate so that's a problem
self._id = None
self._doc = None
self._layer = None

def __repr__(self):
if self.doc:
if self.layer:
return f"Annotated Entity:\tID: {self.id}\tSpans: {True if self.spans else False}\tBoxes: {True if self.boxes else False}\tText: {self.text}"
return f"Unannotated Entity: {self.to_json()}"

Expand All @@ -57,20 +61,20 @@ def from_json(cls, entity_json: Dict) -> "Entity":
)

@property
def doc(self) -> Optional["Document"]:
return self._doc

@doc.setter
def doc(self, doc: Optional["Document"]) -> None:
"""This method attaches a Document to this Entity, allowing the Entity
to access things beyond itself within the Document (e.g. neighboring entities)"""
if self.doc and doc:
def layer(self) -> Optional["Layer"]:
return self._layer

@layer.setter
def layer(self, layer: Optional["Layer"]) -> None:
"""This method attaches a Layer to this Entity, allowing the Entity
to access things beyond itself within the Layer (e.g. neighboring Entities)"""
if self.layer and layer:
raise AttributeError(
"Already has an attached Document. Since Entity should be"
"specific to a given Document, we recommend creating a new"
"Entity from scratch and then attaching your Document."
"Already has an attached Layer. Since Entity should correspond"
"to only a specific Layer, we recommend creating a new"
"Entity from scratch and then attaching your Layer."
)
self._doc = doc
self._layer = layer

@property
def id(self) -> Optional[int]:
Expand All @@ -83,17 +87,39 @@ def id(self, id: int) -> None:
position within the broader Document."""
if self.id:
raise AttributeError(f"This Entity already has an ID: {self.id}")
if not self.doc:
if not self.layer:
raise AttributeError("This Entity is missing a Document")
self._id = id

def __getattr__(self, field: str) -> List["Entity"]:
def __getattr__(self, name: str) -> List["Entity"]:
"""This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks."""
try:
return self.find_by_span(field=field)
return self.intersect_by_span(name=name)
except ValueError:
# maybe users just want some attribute of the Entity object
return self.__getattribute__(field)
return self.__getattribute__(name)

def intersect_by_span(self, name: str) -> List["Entity"]:
"""This method allows you to access overlapping Entities
within the Document based on Span"""
# The Document is reached through the owning Layer (entity -> layer -> doc),
# so both links must be present. Both failures raise ValueError, which is the
# exception type `__getattr__` catches before falling back to a normal
# attribute lookup.
if self.layer is None:
raise ValueError("This Entity is not attached to a Layer")

if self.layer.doc is None:
raise ValueError("This Entity's Layer is not attached to a Document")

# This Entity itself is the query; `name` selects the layer to search.
return self.layer.doc.intersect_by_span(query=self, name=name)

def intersect_by_box(self, name: str) -> List["Entity"]:
"""This method allows you to access overlapping Entities
within the Document based on Box"""
# The Document is reached through the owning Layer (entity -> layer -> doc),
# so both links must be present; either one missing raises ValueError.
if self.layer is None:
raise ValueError("This Entity is not attached to a Layer")

if self.layer.doc is None:
raise ValueError("This Entity's Layer is not attached to a Document")

# This Entity itself is the query; `name` selects the layer to search.
return self.layer.doc.intersect_by_box(query=self, name=name)

@property
def start(self) -> Union[int, float]:
Expand All @@ -105,18 +131,30 @@ def end(self) -> Union[int, float]:

@property
def symbols_from_spans(self) -> List[str]:
if self.doc is not None:
return [self.doc.symbols[span.start : span.end] for span in self.spans]
else:
return []
if self.layer is None:
raise ValueError("This Entity is not attached to a Layer")

if self.layer.doc is None:
raise ValueError("This Entity's Layer is not attached to a Document")

if self.layer.doc.symbols is None:
raise ValueError("This Entity's Document is missing symbols")

return [self.layer.doc.symbols[span.start : span.end] for span in self.spans]

@property
def symbols_from_boxes(self) -> List[str]:
if self.doc is not None:
matched_tokens = self.doc.find_by_box(query=self, field_name="tokens")
return [self.doc.symbols[span.start : span.end] for t in matched_tokens for span in t.spans]
else:
return []
if self.layer is None:
raise ValueError("This Entity is not attached to a Layer")

if self.layer.doc is None:
raise ValueError("This Entity's Layer is not attached to a Document")

if self.layer.doc.symbols is None:
raise ValueError("This Entity's Document is missing symbols")

matched_tokens = self.intersect_by_box(name=TokensFieldName)
return [self.layer.doc.symbols[span.start : span.end] for t in matched_tokens for span in t.spans]

@property
def text(self) -> str:
Expand Down Expand Up @@ -145,26 +183,3 @@ def __lt__(self, other: "Entity"):
return self.id < other.id
else:
return self.start < other.start

def find_by_span(self, field: str) -> List["Entity"]:
"""This method allows you to access overlapping Entities
within the Document based on Span"""
# Both failure modes (no attached document, unknown field) raise ValueError.
if self.doc is None:
raise ValueError("This entity is not attached to a document")

# `doc.fields` is checked first so that an unknown name produces a ValueError
# here instead of whatever the Document lookup would raise.
if field in self.doc.fields:
return self.doc.find_by_span(self, field)
else:
raise ValueError(f"Field {field} not found in Document")

def find_by_box(self, field: str) -> List["Entity"]:
"""This method allows you to access overlapping Entities
within the Document based on Box"""

# Both failure modes (no attached document, unknown field) raise ValueError.
if self.doc is None:
raise ValueError("This entity is not attached to a document")

# `doc.fields` is checked first so that an unknown name produces a ValueError
# here instead of whatever the Document lookup would raise.
if field in self.doc.fields:
return self.doc.find_by_box(self, field)
else:
raise ValueError(f"Field {field} not found in Document")
Loading

0 comments on commit 3781ae0

Please sign in to comment.