Skip to content

Commit

Permalink
Kylel/layers (#50)
Browse files Browse the repository at this point in the history
* add layer

* remove unused vila pred

* add layer impl

* remove ANnotation; migrate code to Entity

* remove all occurrences of Annotation in favor of Entity
  • Loading branch information
kyleclo authored Oct 10, 2023
1 parent aecadfc commit 3a318fc
Show file tree
Hide file tree
Showing 18 changed files with 261 additions and 245 deletions.
4 changes: 2 additions & 2 deletions papermage/magelib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
"""


from .annotation import Annotation
from .box import Box
from .layer import Layer
from .document import (
AbstractsFieldName,
AlgorithmsFieldName,
Expand Down Expand Up @@ -49,7 +49,6 @@
__all__ = [
"AbstractsFieldName",
"AlgorithmsFieldName",
"Annotation",
"AuthorsFieldName",
"BibliographiesFieldName",
"BlocksFieldName",
Expand All @@ -70,6 +69,7 @@
"ImagesFieldName",
"KeywordsFieldName",
"KeywordsFieldName",
"Layer",
"ListsFieldName",
"Metadata",
"MetadataFieldName",
Expand Down
95 changes: 0 additions & 95 deletions papermage/magelib/annotation.py

This file was deleted.

73 changes: 69 additions & 4 deletions papermage/magelib/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
"""

from typing import Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from .annotation import Annotation
from .box import Box
from .image import Image
from .metadata import Metadata
from .span import Span

if TYPE_CHECKING:
from .document import Document

class Entity(Annotation):

class Entity:
def __init__(
self,
spans: Optional[List[Span]] = None,
Expand All @@ -27,7 +29,8 @@ def __init__(
self.boxes = boxes if boxes else []
self.images = images if images else []
self.metadata = metadata if metadata else Metadata()
super().__init__()
self._id = None
self._doc = None

def __repr__(self):
if self.doc:
Expand All @@ -51,6 +54,45 @@ def from_json(cls, entity_json: Dict) -> "Entity":
metadata=Metadata.from_json(entity_json.get("metadata", {})),
)

@property
def doc(self) -> Optional["Document"]:
return self._doc

@doc.setter
def doc(self, doc: Optional["Document"]) -> None:
"""This method attaches a Document to this Entity, allowing the Entity
to access things beyond itself within the Document (e.g. neighboring entities)"""
if self.doc and doc:
raise AttributeError(
"Already has an attached Document. Since Entity should be"
"specific to a given Document, we recommend creating a new"
"Entity from scratch and then attaching your Document."
)
self._doc = doc

@property
def id(self) -> Optional[int]:
return self._id

@id.setter
def id(self, id: int) -> None:
"""This method assigns an ID to an Entity. Requires a Document to be attached
to this Entity. ID basically gives the Entity itself awareness of its
position within the broader Document."""
if self.id:
raise AttributeError(f"This Entity already has an ID: {self.id}")
if not self.doc:
raise AttributeError("This Entity is missing a Document")
self._id = id

def __getattr__(self, field: str) -> List["Entity"]:
"""This Overloading is convenient syntax since the `entity.layer` operation is intuitive for folks."""
try:
return self.find_by_span(field=field)
except ValueError:
# maybe users just want some attribute of the Entity object
return self.__getattribute__(field)

@property
def start(self) -> Union[int, float]:
return min([span.start for span in self.spans]) if len(self.spans) > 0 else float("-inf")
Expand Down Expand Up @@ -101,3 +143,26 @@ def __lt__(self, other: "Entity"):
return self.id < other.id
else:
return self.start < other.start

def find_by_span(self, field: str) -> List["Entity"]:
"""This method allows you to access overlapping Entities
within the Document based on Span"""
if self.doc is None:
raise ValueError("This entity is not attached to a document")

if field in self.doc.fields:
return self.doc.find_by_span(self, field)
else:
raise ValueError(f"Field {field} not found in Document")

def find_by_box(self, field: str) -> List["Entity"]:
"""This method allows you to access overlapping Entities
within the Document based on Box"""

if self.doc is None:
raise ValueError("This entity is not attached to a document")

if field in self.doc.fields:
return self.doc.find_by_box(self, field)
else:
raise ValueError(f"Field {field} not found in Document")
7 changes: 3 additions & 4 deletions papermage/magelib/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,17 @@
import numpy as np
from ncls import NCLS

from .annotation import Annotation
from .box import Box
from .entity import Entity


class Indexer:
"""Stores an index for a particular collection of Annotations.
"""Stores an index for a particular collection of Entities.
Indexes in this library focus on *INTERSECT* relations."""

@abstractmethod
def find(self, query: Annotation) -> List[Annotation]:
"""Returns all matching Annotations given a suitable query"""
def find(self, query: Entity) -> List[Entity]:
"""Returns all matching Entities given a suitable query"""
raise NotImplementedError()


Expand Down
43 changes: 43 additions & 0 deletions papermage/magelib/layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Layers are collections of Entities. Supports indexing and slicing.
@kylel
"""

from typing import List

from .entity import Entity


class Layer:
"""Views into a document. Immutable. Lightweight."""

__slots__ = ["entities"]

def __init__(self, entities: List[Entity]):
self.entities = entities

def __repr__(self):
entity_repr = "\n".join([f"\t{e}" for e in self.entities])
return f"Layer with {len(self)} Entities:\n{entity_repr}"

def __getitem__(self, key):
return self.entities[key]

def __len__(self):
return len(self.entities)

def __iter__(self):
return iter(self.entities)

def __contains__(self, item):
return item in self.entities

def to_json(self):
return [entity.to_json() for entity in self.entities]

@classmethod
def from_json(cls, layer_json):
return cls(entities=[Entity.from_json(entity_json) for entity_json in layer_json])
2 changes: 1 addition & 1 deletion papermage/parsers/grobid_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def find_contiguous_ones(array):
class GrobidFullParser(Parser):
"""Grobid parser that uses Grobid python client to hit a running
Grobid server and convert resulting grobid XML TEI coordinates into
PaperMage Annotations to annotate an existing Document.
PaperMage Entities to annotate an existing Document.
Run a Grobid server (from https://grobid.readthedocs.io/en/latest/Grobid-docker/):
> docker pull lfoppiano/grobid:0.7.2
Expand Down
6 changes: 3 additions & 3 deletions papermage/predictors/base_predictors/base_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import abstractmethod
from typing import Any, Dict, List, Union

from papermage.magelib import Annotation, Document
from papermage.magelib import Document, Entity


class BasePredictor:
Expand All @@ -31,13 +31,13 @@ def _doc_field_checker(self, doc: Document) -> None:
field in doc.fields
), f"The input Document object {doc} doesn't contain the required field {field}"

def predict(self, doc: Document) -> List[Annotation]:
def predict(self, doc: Document) -> List[Entity]:
"""For all the predictors, the input is a document object, and
the output is a list of annotations.
"""
self._doc_field_checker(doc)
return self._predict(doc=doc)

@abstractmethod
def _predict(self, doc: Document) -> List[Annotation]:
def _predict(self, doc: Document) -> List[Entity]:
raise NotImplementedError
Loading

0 comments on commit 3a318fc

Please sign in to comment.