Skip to content

Commit

Permalink
added AddEndToken Operator
Browse files Browse the repository at this point in the history
  • Loading branch information
SoluMilken committed Apr 11, 2019
1 parent e0ddef3 commit 46a7146
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 0 deletions.
64 changes: 64 additions & 0 deletions uttut/pipeline/ops/add_end_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import List, Tuple

from .base import Operator
from .tokens import END_TOKEN
from .add_sos_eos import AddSosEosAligner
from ..edit import lst2lst
from ..edit.replacement import ReplacementGroup


class AddEndToken(Operator):

"""
Add end token (<eos>) to a sequence.
E.g.
>>> from uttut.pipeline.ops.add_sos_eos import AddSosEos
>>> op = AddEndToken()
>>> output_seq, label_aligner = op.transform(
['I', 'have', '10.7', 'dollars', '.'])
>>> output_labels = label_aligner.transform([1, 2, 3, 4, 5])
>>> output_seq
['I', 'have', '10.7', 'dollars', '.', '<eos>']
>>> output_labels
[1, 2, 3, 4, 5, 0]
>>> label_aligner.inverse_transform(output_labels)
[1, 2, 3, 4, 5]
"""

_input_type = list
_output_type = list

def __init__(self, end_token: str = END_TOKEN):
self.end_token = self._validate_end_token(end_token)

def _validate_end_token(self, end_token: str):
if end_token is None:
raise ValueError("The end token can not be None.")

if not isinstance(end_token, str):
raise TypeError("The end token should be a string.")

if len(end_token) == 0:
raise ValueError("The end token should not be an empty string.")

return end_token

def _transform(self, input_sequence: List[str]) -> Tuple[List[str], 'LabelAligner']:
forward_replacement_group = self._gen_forward_replacement_group(input_sequence)
output_sequence = lst2lst.apply(input_sequence, forward_replacement_group)

label_aligner = AddSosEosAligner(
input_sequence=input_sequence,
edit=forward_replacement_group,
output_length=len(output_sequence),
)
return output_sequence, label_aligner

def _gen_forward_replacement_group(self, input_lst: List[str]) -> ReplacementGroup:
seqlen = len(input_lst)
replacement_group = ReplacementGroup.add_all(
[(seqlen, seqlen, [self.end_token], 'add-end-token')],
)
return replacement_group
71 changes: 71 additions & 0 deletions uttut/pipeline/ops/tests/test_add_end_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pytest

from ..add_end_token import AddEndToken
from ..tokens import END_TOKEN
from .common_tests import OperatorTestTemplate, ParamTuple


class TestAddEndToken(OperatorTestTemplate):

params = [
ParamTuple(
['alvin', '喜歡', '吃', '榴槤'],
[1, 2, 3, 4],
['alvin', '喜歡', '吃', '榴槤', '<eos>'],
[1, 2, 3, 4, 0],
'zh',
),
]

@pytest.fixture(scope='class')
def op(self):
return AddEndToken()

def test_not_equal(self, op):
custom_op = AddEndToken(end_token='custom_EOS')
assert custom_op != op


@pytest.mark.parametrize(
"op, expected_configs",
[
pytest.param(
AddEndToken(),
{'end_token': END_TOKEN},
id="default",
),
pytest.param(
AddEndToken('1'),
{'end_token': '1'},
id="no keywords",
),
],
)
def test_correct_configs(op, expected_configs):
assert op.configs == expected_configs


@pytest.mark.parametrize(
"end_token, error",
[
pytest.param(
None,
ValueError,
id="The end token can not be None",
),
pytest.param(
1,
TypeError,
id="wrong type of end token",
),
pytest.param(
'',
ValueError,
id="empty string is not allowed.",
),
],
)
def test_invalid_end_token(end_token, error):
with pytest.raises(error):
AddEndToken(end_token)

0 comments on commit 46a7146

Please sign in to comment.