Skip to content

Commit

Permalink
Add a serializer for the IOBES format (#12)
Browse files Browse the repository at this point in the history
* Add IOBESSerializer

* Add a test case for IOBESSerializer

* Bump version
  • Loading branch information
yasufumy authored Oct 4, 2021
1 parent 6be660d commit 6189d84
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "seqlabel"
version = "0.2.3"
version = "0.2.4"
description = "Rule-based Text Labeling Framework Aiming at Flexibility"
authors = ["Yasufumi Taniguchi <yasufumi.taniguchi@gmail.com>"]
license = "MIT"
Expand Down
35 changes: 35 additions & 0 deletions seqlabel/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,41 @@ def save(self, text: StringSequence, entities: List[Entity]) -> str:
return "\n".join(result)


class IOBESSerializer(Serializer):
    """Serializer that writes a text and its entities in the IOBES tagging scheme.

    An entity covering a single position is tagged ``S-<label>``. A longer
    entity is tagged ``B-<label>`` at its first position, ``I-<label>`` at
    interior positions and ``E-<label>`` at its last position. Every position
    not covered by an entity is tagged ``O``.
    """

    def save(self, text: StringSequence, entities: List[Entity]) -> str:
        """Render a text and its entities as an IOBES-format string.

        Args:
            text: The text to serialize.
            entities: Entities that appear in the given text.

        Returns:
            One line per item of ``text``, each formatted as ``item<TAB>tag``,
            with the lines joined by newlines.

        Raises:
            ValueError: If an entity covers a position that already carries a
                tag from a previously processed entity.
        """
        items = list(text)
        labels = ["O" for _ in items]

        for entity in entities:
            first, last = text.align_offsets(entity.start_offset, entity.end_offset)

            # ``last`` is used as an inclusive index below, so the overlap
            # check has to look at ``last`` itself as well.
            covered = labels[first : last + 1]
            if not all(label == "O" for label in covered):
                raise ValueError("Overlapping spans are found.")

            if first == last:
                labels[first] = f"S-{entity.label}"
            else:
                labels[first] = f"B-{entity.label}"
                labels[first + 1 : last] = [f"I-{entity.label}"] * (last - first - 1)
                labels[last] = f"E-{entity.label}"

        return "\n".join(f"{item}\t{label}" for item, label in zip(items, labels))


class BILOUSerializer(Serializer):
"""BILOU format Serializer."""

Expand Down
16 changes: 15 additions & 1 deletion tests/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from seqlabel.core import Entity, StringSequence
from seqlabel.serializers import BILOUSerializer, IOB2Serializer, JSONLSerializer
from seqlabel.serializers import BILOUSerializer, IOB2Serializer, IOBESSerializer, JSONLSerializer


@pytest.mark.parametrize(
Expand Down Expand Up @@ -35,6 +35,20 @@ def test_iob2_serializer_save(text_ja: StringSequence, entities: List[Entity], e
assert serializer.save(text_ja, entities) == expected


@pytest.mark.parametrize(
    "entities,expected",
    [
        (
            [Entity(6, 8, "LOC")],
            # NOTE(review): most non-ASCII characters in this expected string
            # appear to have been lost in this copy — confirm against the repository.
            "日\tO\n\tO\n\tO\n\tO\n\tO\n\tO\n\tB-LOC\n\tI-LOC\n\tE-LOC\n\tO\n\tO\n\tO",
        ),
    ],
)
def test_iobes_serializer_save(text_ja: StringSequence, entities: List[Entity], expected: str) -> None:
    """Check that IOBESSerializer.save tags a three-position LOC entity as B/I/E."""
    actual = IOBESSerializer().save(text_ja, entities)
    assert actual == expected


@pytest.mark.parametrize(
"entities,expected",
[
Expand Down

0 comments on commit 6189d84

Please sign in to comment.