From 49df315411b7797b87e967c1298eb49982b9ea6b Mon Sep 17 00:00:00 2001 From: David Huggins-Daines Date: Thu, 15 Aug 2024 14:32:19 -0400 Subject: [PATCH 1/4] fix: make algorithms consistently return a VectorFst --- rustfst-python/rustfst/algorithms/optimize.py | 6 ++++-- rustfst-python/rustfst/algorithms/tr_sort.py | 3 ++- .../rustfst/algorithms/tr_unique.py | 3 ++- rustfst-python/rustfst/fst/vector_fst.py | 19 +++++++++++-------- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/rustfst-python/rustfst/algorithms/optimize.py b/rustfst-python/rustfst/algorithms/optimize.py index 13b8bd168..0e657f8b5 100644 --- a/rustfst-python/rustfst/algorithms/optimize.py +++ b/rustfst-python/rustfst/algorithms/optimize.py @@ -10,7 +10,7 @@ from rustfst.fst.vector_fst import VectorFst -def optimize(fst: VectorFst): +def optimize(fst: VectorFst) -> VectorFst: """ Optimize an fst in-place Args: @@ -19,9 +19,10 @@ def optimize(fst: VectorFst): ret_code = lib.fst_optimize(fst.ptr) err_msg = "Error during optimize" check_ffi_error(ret_code, err_msg) + return fst -def optimize_in_log(fst: VectorFst): +def optimize_in_log(fst: VectorFst) -> VectorFst: """ Optimize an fst in-place in the log semiring. Args: @@ -30,3 +31,4 @@ def optimize_in_log(fst: VectorFst): ret_code = lib.fst_optimize_in_log(ctypes.byref(fst.ptr)) err_msg = "Error during optimize_in_log" check_ffi_error(ret_code, err_msg) + return fst diff --git a/rustfst-python/rustfst/algorithms/tr_sort.py b/rustfst-python/rustfst/algorithms/tr_sort.py index 5958e4e0e..ab17deda5 100644 --- a/rustfst-python/rustfst/algorithms/tr_sort.py +++ b/rustfst-python/rustfst/algorithms/tr_sort.py @@ -8,7 +8,7 @@ from rustfst.fst.vector_fst import VectorFst -def tr_sort(fst: VectorFst, ilabel_cmp: bool): +def tr_sort(fst: VectorFst, ilabel_cmp: bool) -> VectorFst: """ tr_sort(fst) sort fst trs according to their ilabel or olabel @@ -19,3 +19,4 @@ def tr_sort(fst: VectorFst, ilabel_cmp: bool): ret_code = lib.fst_tr_sort(fst.ptr, ctypes.c_bool(ilabel_cmp)) err_msg = "Error during tr_sort" check_ffi_error(ret_code, err_msg) + return fst diff --git a/rustfst-python/rustfst/algorithms/tr_unique.py b/rustfst-python/rustfst/algorithms/tr_unique.py index a7d3c210d..1e1ffb71b 100644 --- a/rustfst-python/rustfst/algorithms/tr_unique.py +++ b/rustfst-python/rustfst/algorithms/tr_unique.py @@ -7,7 +7,7 @@ from rustfst.fst.vector_fst import VectorFst -def tr_unique(fst: VectorFst): +def tr_unique(fst: VectorFst) -> VectorFst: """ Keep a single instance of trs leaving the same state, going to the same state and with the same input labels, output labels and weight. @@ -18,3 +18,4 @@ def tr_unique(fst: VectorFst): ret_code = lib.fst_tr_unique(fst.ptr) err_msg = "Error during tr_unique" check_ffi_error(ret_code, err_msg) + return fst diff --git a/rustfst-python/rustfst/fst/vector_fst.py b/rustfst-python/rustfst/fst/vector_fst.py index 06f5b5919..4194a082b 100644 --- a/rustfst-python/rustfst/fst/vector_fst.py +++ b/rustfst-python/rustfst/fst/vector_fst.py @@ -673,8 +673,7 @@ def optimize(self) -> VectorFst: """ from rustfst.algorithms.optimize import optimize - optimize(self) - return self + return optimize(self) def optimize_in_log(self) -> VectorFst: """ @@ -684,10 +683,9 @@ def optimize_in_log(self) -> VectorFst: """ from rustfst.algorithms.optimize import optimize_in_log - optimize_in_log(self) - return self + return optimize_in_log(self) - def tr_sort(self, ilabel_cmp: bool = True): + def tr_sort(self, ilabel_cmp: bool = True) -> VectorFst: """Sort trs for an FST in-place according to their input or output label. @@ -697,19 +695,24 @@ def tr_sort(self, ilabel_cmp: bool = True): Args: ilabel_cmp: Sort on input labels if `True`, output labels if `False`. + Returns: + self """ from rustfst.algorithms.tr_sort import tr_sort - tr_sort(self, ilabel_cmp) + return tr_sort(self, ilabel_cmp) - def tr_unique(self): + def tr_unique(self) -> VectorFst: """Modify an FST in-place, keeping a single instance of trs leaving the same state, going to the same state and with the same input labels, output labels and weight. + + Returns: + self """ from rustfst.algorithms.tr_unique import tr_unique - tr_unique(self) + return tr_unique(self) def isomorphic(self, other: VectorFst) -> bool: """ From 7968300a1e4f068fa5317c7e73d3dd664ce8d5ed Mon Sep 17 00:00:00 2001 From: David Huggins-Daines Date: Thu, 15 Aug 2024 16:46:34 -0400 Subject: [PATCH 2/4] fix: rm_epsilon **is** an in-place operation!!! --- rustfst-python/rustfst/algorithms/rm_epsilon.py | 12 +++++------- rustfst-python/rustfst/algorithms/tr_sort.py | 17 ++++++++++++----- rustfst-python/rustfst/algorithms/tr_unique.py | 8 ++++++-- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/rustfst-python/rustfst/algorithms/rm_epsilon.py b/rustfst-python/rustfst/algorithms/rm_epsilon.py index f619ab185..240f0e339 100644 --- a/rustfst-python/rustfst/algorithms/rm_epsilon.py +++ b/rustfst-python/rustfst/algorithms/rm_epsilon.py @@ -1,5 +1,4 @@ from __future__ import annotations -import ctypes from rustfst.ffi_utils import ( lib, check_ffi_error, @@ -10,16 +9,15 @@ def rm_epsilon(fst: VectorFst) -> VectorFst: """ - Return an equivalent FST with epsilon transitions removed. + Remove epsilon transitions in-place Args: - fst: Fst + fst: Fst to remove epsilons from Returns: - Newly created FST with epsilon transitions removed. + fst: Same FST, modified in place """ - rm_epsilon_fst = ctypes.c_void_p() - ret_code = lib.fst_rm_epsilon(fst.ptr, ctypes.byref(rm_epsilon_fst)) + ret_code = lib.fst_rm_epsilon(fst.ptr) err_msg = "Error during rm_epsilon" check_ffi_error(ret_code, err_msg) - return VectorFst(ptr=rm_epsilon_fst) + return fst diff --git a/rustfst-python/rustfst/algorithms/tr_sort.py b/rustfst-python/rustfst/algorithms/tr_sort.py index ab17deda5..896ba1679 100644 --- a/rustfst-python/rustfst/algorithms/tr_sort.py +++ b/rustfst-python/rustfst/algorithms/tr_sort.py @@ -9,11 +9,18 @@ def tr_sort(fst: VectorFst, ilabel_cmp: bool) -> VectorFst: - """ - tr_sort(fst) - sort fst trs according to their ilabel or olabel - :param fst: Fst - :param ilabel_cmp: bool + """Sort trs for an FST in-place according to their input or + output label. + + This is often necessary for composition to work properly. It + corresponds to `ArcSort` in OpenFST. + + Args: + fst: FST to be tr-sorted in-place. + ilabel_cmp: Sort on input labels if `True`, output labels + if `False`. + Returns: + fst: Same FST that was modified in-place. """ ret_code = lib.fst_tr_sort(fst.ptr, ctypes.c_bool(ilabel_cmp)) diff --git a/rustfst-python/rustfst/algorithms/tr_unique.py b/rustfst-python/rustfst/algorithms/tr_unique.py index 1e1ffb71b..e9bef8733 100644 --- a/rustfst-python/rustfst/algorithms/tr_unique.py +++ b/rustfst-python/rustfst/algorithms/tr_unique.py @@ -9,10 +9,14 @@ def tr_unique(fst: VectorFst) -> VectorFst: """ - Keep a single instance of trs leaving the same state, going to the same state and - with the same input labels, output labels and weight. + Modify an FST in-place, keeping a single instance of trs + leaving the same state, going to the same state and with the same + input labels, output labels and weight. + Args: fst: Fst to modify + Returns: + fst: Same FST, modified in-place """ ret_code = lib.fst_tr_unique(fst.ptr) From b2728d228630299bc839258f988061b0ad7d0e2e Mon Sep 17 00:00:00 2001 From: David Huggins-Daines Date: Thu, 15 Aug 2024 16:50:24 -0400 Subject: [PATCH 3/4] tests: test that rm_epsilon returns fst --- rustfst-python/tests/algorithms/test_rm_epsilon.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rustfst-python/tests/algorithms/test_rm_epsilon.py b/rustfst-python/tests/algorithms/test_rm_epsilon.py index 2f6d6f943..9796d922d 100644 --- a/rustfst-python/tests/algorithms/test_rm_epsilon.py +++ b/rustfst-python/tests/algorithms/test_rm_epsilon.py @@ -47,6 +47,7 @@ def test_rm_epsilon(): tr1_2 = Tr(1, 0, 3.0, s2) expected_fst.add_tr(s1, tr1_2) - fst1.rm_epsilon() + rv = fst1.rm_epsilon() + assert rv == fst1 assert expected_fst == fst1 From e42c8aff4e21812b6bf36b88808b4e588e16aa64 Mon Sep 17 00:00:00 2001 From: David Huggins-Daines Date: Fri, 16 Aug 2024 15:45:45 -0400 Subject: [PATCH 4/4] fix: correct documentation for rm_epsilon (again) --- rustfst-python/rustfst/fst/vector_fst.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rustfst-python/rustfst/fst/vector_fst.py b/rustfst-python/rustfst/fst/vector_fst.py index 4194a082b..353f9f351 100644 --- a/rustfst-python/rustfst/fst/vector_fst.py +++ b/rustfst-python/rustfst/fst/vector_fst.py @@ -609,9 +609,9 @@ def reverse(self) -> VectorFst: def rm_epsilon(self) -> VectorFst: """ - Return an equivalent FST with epsilon transitions removed. + Remove epsilon transitions in-place. Returns: - Newly created FST with epsilon transitions removed. + self: Same FST, modified in place """ from rustfst.algorithms.rm_epsilon import rm_epsilon