Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make Python algorithms consistently return a VectorFst #280

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions rustfst-python/rustfst/algorithms/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from rustfst.fst.vector_fst import VectorFst


def optimize(fst: VectorFst):
def optimize(fst: VectorFst) -> VectorFst:
"""
Optimize an fst in-place
Args:
Expand All @@ -19,9 +19,10 @@ def optimize(fst: VectorFst):
ret_code = lib.fst_optimize(fst.ptr)
err_msg = "Error during optimize"
check_ffi_error(ret_code, err_msg)
return fst


def optimize_in_log(fst: VectorFst):
def optimize_in_log(fst: VectorFst) -> VectorFst:
"""
Optimize an fst in-place in the log semiring.
Args:
Expand All @@ -30,3 +31,4 @@ def optimize_in_log(fst: VectorFst):
ret_code = lib.fst_optimize_in_log(ctypes.byref(fst.ptr))
err_msg = "Error during optimize_in_log"
check_ffi_error(ret_code, err_msg)
return fst
12 changes: 5 additions & 7 deletions rustfst-python/rustfst/algorithms/rm_epsilon.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations
import ctypes
from rustfst.ffi_utils import (
lib,
check_ffi_error,
Expand All @@ -10,16 +9,15 @@

def rm_epsilon(fst: VectorFst) -> VectorFst:
"""
Return an equivalent FST with epsilon transitions removed.
Remove epsilon transitions in-place
Args:
fst: Fst
fst: Fst to remove epsilons from
Returns:
Newly created FST with epsilon transitions removed.
fst: Same FST, modified in place
"""

rm_epsilon_fst = ctypes.c_void_p()
ret_code = lib.fst_rm_epsilon(fst.ptr, ctypes.byref(rm_epsilon_fst))
ret_code = lib.fst_rm_epsilon(fst.ptr)
err_msg = "Error during rm_epsilon"
check_ffi_error(ret_code, err_msg)

return VectorFst(ptr=rm_epsilon_fst)
return fst
20 changes: 14 additions & 6 deletions rustfst-python/rustfst/algorithms/tr_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,22 @@
from rustfst.fst.vector_fst import VectorFst


def tr_sort(fst: VectorFst, ilabel_cmp: bool):
"""
tr_sort(fst)
sort fst trs according to their ilabel or olabel
:param fst: Fst
:param ilabel_cmp: bool
def tr_sort(fst: VectorFst, ilabel_cmp: bool) -> VectorFst:
"""Sort trs for an FST in-place according to their input or
output label.

This is often necessary for composition to work properly. It
corresponds to `ArcSort` in OpenFST.

Args:
fst: FST to be tr-sorted in-place.
ilabel_cmp: Sort on input labels if `True`, output labels
if `False`.
Returns:
fst: Same FST that was modified in-place.
"""

ret_code = lib.fst_tr_sort(fst.ptr, ctypes.c_bool(ilabel_cmp))
err_msg = "Error during tr_sort"
check_ffi_error(ret_code, err_msg)
return fst
11 changes: 8 additions & 3 deletions rustfst-python/rustfst/algorithms/tr_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
from rustfst.fst.vector_fst import VectorFst


def tr_unique(fst: VectorFst):
def tr_unique(fst: VectorFst) -> VectorFst:
"""
Keep a single instance of trs leaving the same state, going to the same state and
with the same input labels, output labels and weight.
Modify an FST in-place, keeping a single instance of trs
leaving the same state, going to the same state and with the same
input labels, output labels and weight.

Args:
fst: Fst to modify
Returns:
fst: Same FST, modified in-place
"""

ret_code = lib.fst_tr_unique(fst.ptr)
err_msg = "Error during tr_unique"
check_ffi_error(ret_code, err_msg)
return fst
23 changes: 13 additions & 10 deletions rustfst-python/rustfst/fst/vector_fst.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,9 @@ def reverse(self) -> VectorFst:

def rm_epsilon(self) -> VectorFst:
"""
Return an equivalent FST with epsilon transitions removed.
Remove epsilon transitions in-place.
Returns:
Newly created FST with epsilon transitions removed.
self: Same FST, modified in place
"""
from rustfst.algorithms.rm_epsilon import rm_epsilon

Expand Down Expand Up @@ -673,8 +673,7 @@ def optimize(self) -> VectorFst:
"""
from rustfst.algorithms.optimize import optimize

optimize(self)
return self
return optimize(self)

def optimize_in_log(self) -> VectorFst:
"""
Expand All @@ -684,10 +683,9 @@ def optimize_in_log(self) -> VectorFst:
"""
from rustfst.algorithms.optimize import optimize_in_log

optimize_in_log(self)
return self
return optimize_in_log(self)

def tr_sort(self, ilabel_cmp: bool = True):
def tr_sort(self, ilabel_cmp: bool = True) -> VectorFst:
"""Sort trs for an FST in-place according to their input or
output label.

Expand All @@ -697,19 +695,24 @@ def tr_sort(self, ilabel_cmp: bool = True):
Args:
ilabel_cmp: Sort on input labels if `True`, output labels
if `False`.
Returns:
self
"""
from rustfst.algorithms.tr_sort import tr_sort

tr_sort(self, ilabel_cmp)
return tr_sort(self, ilabel_cmp)

def tr_unique(self):
def tr_unique(self) -> VectorFst:
"""Modify an FST in-place, keeping a single instance of trs
leaving the same state, going to the same state and with the
same input labels, output labels and weight.

Returns:
self
"""
from rustfst.algorithms.tr_unique import tr_unique

tr_unique(self)
return tr_unique(self)

def isomorphic(self, other: VectorFst) -> bool:
"""
Expand Down
3 changes: 2 additions & 1 deletion rustfst-python/tests/algorithms/test_rm_epsilon.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_rm_epsilon():
tr1_2 = Tr(1, 0, 3.0, s2)
expected_fst.add_tr(s1, tr1_2)

fst1.rm_epsilon()
rv = fst1.rm_epsilon()
assert rv == fst1

assert expected_fst == fst1