From 570d9b341cbd0a38a283e7d08861b80c439e2c11 Mon Sep 17 00:00:00 2001 From: Valery Mironov <32071355+MBkkt@users.noreply.github.com> Date: Thu, 7 Dec 2023 14:05:46 +0100 Subject: [PATCH] Fix --- .../analysis/multi_delimited_token_stream.cpp | 15 ++--- .../kaldi/src/fstext/determinize-star-inl.h | 65 +++++++++++-------- .../multi_delimited_token_stream_tests.cpp | 4 +- 3 files changed, 48 insertions(+), 36 deletions(-) diff --git a/core/analysis/multi_delimited_token_stream.cpp b/core/analysis/multi_delimited_token_stream.cpp index 6f4851cd0..22c1813da 100644 --- a/core/analysis/multi_delimited_token_stream.cpp +++ b/core/analysis/multi_delimited_token_stream.cpp @@ -180,8 +180,7 @@ void make_string(automaton& a, bytes_view str) { // add reset edges if (last_no_match != -1) { a.EmplaceArc(current_state, - range_label::fromRange(last_no_match, c - 1), - 0); + range_label::fromRange(last_no_match, c - 1), 0); last_no_match = -1; } // add forward edge @@ -192,8 +191,7 @@ void make_string(automaton& a, bytes_view str) { // add reset edges if (last_no_match != -1) { a.EmplaceArc(current_state, - range_label::fromRange(last_no_match, c - 1), - 0); + range_label::fromRange(last_no_match, c - 1), 0); last_no_match = -1; } @@ -223,13 +221,12 @@ void make_string(automaton& a, bytes_view str) { if (last_no_match != -1) { a.EmplaceArc(current_state, - range_label::fromRange(last_no_match, UCHAR_MAX), - 0); + range_label::fromRange(last_no_match, UCHAR_MAX), 0); last_no_match = -1; } } - //a.EmplaceArc(first_state + str.length(), range_label::fromRange(0), 1); + // a.EmplaceArc(first_state + str.length(), range_label::fromRange(0), 1); a.EmplaceArc(0, range_label::fromRange(0), first_state); } @@ -270,8 +267,10 @@ class multi_delimited_token_stream_generic final fst::DeterminizeStar(nfa, &dfa); std::cout << "number of states (dfa) = " << dfa.NumStates() << std::endl; - //fst::Minimize(&dfa); + // fst::Minimize(&dfa); + std::cout << "HUI\n"; fst::drawFst(dfa, std::cout); + std::cout << "HUI\n"; std::cout << "number of states = " << dfa.NumStates() << std::endl; diff --git a/external/kaldi/src/fstext/determinize-star-inl.h b/external/kaldi/src/fstext/determinize-star-inl.h index 2268af896..b0a14b4ff 100644 --- a/external/kaldi/src/fstext/determinize-star-inl.h +++ b/external/kaldi/src/fstext/determinize-star-inl.h @@ -269,23 +269,24 @@ template class DeterminizerStar { }; struct RangeElement { - explicit RangeElement(Label min, Label max, const Element& element, std::size_t unique_id) - : bound{min}, max{max}, unique_id(unique_id), element{element} {} - explicit RangeElement(Label max, Element&& element, std::size_t unique_id) - : bound{max}, max{fst::kNoLabel}, unique_id(unique_id), element{std::move(element)} {} + explicit RangeElement(Label min, Label max, const Element& to_element, InputStateId from_id, size_t unique_id) + : bound{min}, max{max}, from_id{from_id}, unique_id{unique_id}, to_element{to_element} {} + explicit RangeElement(Label max, Element&& to_element, InputStateId from_id, size_t unique_id) + : bound{max}, max{fst::kNoLabel}, from_id{from_id}, unique_id{unique_id}, to_element{std::move(to_element)} {} Label bound; Label max; - std::size_t unique_id; + InputStateId from_id; + size_t unique_id; bool IsMax() const noexcept { return max == fst::kNoLabel; } - Element& Get() noexcept { return element; } - const Element& Get() const noexcept { return element; } + Element& Get() noexcept { return to_element; } + const Element& Get() const noexcept { return to_element; } private: // TODO(MBkkt) store element in separate array to avoid unnecessary copies - Element element; + Element to_element; }; // Arcs in the format we temporarily create in this class (a representation, essentially of @@ -548,8 +549,8 @@ template class DeterminizerStar { seq_.push_back(arc.olabel); element.string = repository_.IdOfSeq(seq_); } - all_elems_.emplace_back(arc.min, arc.max, element, unique_id); - all_elems_.emplace_back(arc.max, std::move(element), unique_id++); + all_elems_.emplace_back(arc.min, arc.max, element, elem.state, unique_id); + all_elems_.emplace_back(arc.max, std::move(element), elem.state, unique_id++); } } } @@ -557,6 +558,10 @@ template class DeterminizerStar { // now sorted first on input label bound, bound type, then read comparator if (!epsilon_closure_.FstSorted() || closed_subset_.size() > 1) { std::sort(all_elems_.begin(), all_elems_.end(), [](const auto& lhs, const auto& rhs) { + if (lhs.from_id != rhs.from_id) { + return lhs.from_id < rhs.from_id; + } + // TODO(MBkkt) We maybe want to move to_id here if (lhs.bound != rhs.bound) { return lhs.bound < rhs.bound; } @@ -565,19 +570,24 @@ template class DeterminizerStar { if (lhs_max != rhs_max) { return lhs_max < rhs_max; } - const auto& lhs_e = lhs.Get(); - const auto& rhs_e = rhs.Get(); + const auto& lhs_to_id = lhs.Get().state; + const auto& rhs_to_id = rhs.Get().state; + if (lhs_max) { - // same max elements just sorted by state - IRS_ASSERT((&lhs != &rhs) == (lhs_e.state != rhs_e.state)); - return lhs_e.state < rhs_e.state; + // same max elements just sorted by to_id and unique_id + if (lhs_to_id != rhs_to_id) { + return lhs_to_id > rhs_to_id; + } + return lhs.unique_id > rhs.unique_id; } // same min elements sorted opposite to their max elements if (lhs.max != rhs.max) { return lhs.max > rhs.max; } - IRS_ASSERT((&lhs != &rhs) == (lhs_e.state != rhs_e.state)); - return lhs_e.state > rhs_e.state; + if (lhs_to_id != rhs_to_id) { + return lhs_to_id < rhs_to_id; + } + return lhs.unique_id < rhs.unique_id; }); } @@ -585,25 +595,28 @@ template class DeterminizerStar { closed_subset_.clear(); fsa::RangeLabel label; - std::cout << "ALL ELEMS SIZE = " << all_elems_.size() << std::endl; +#ifdef IRESEARCH_DEBUG std::vector brackets; +#endif for (auto& e : all_elems_) { const auto& bound = e.bound; const bool is_max = e.IsMax(); +#ifdef IRESEARCH_DEBUG if (!is_max) { - brackets.push_back(true); + brackets.push_back(e.unique_id); } else { - assert(!brackets.empty()); - assert(brackets.back() == e.unique_id); + IRS_ASSERT(!brackets.empty()); + IRS_ASSERT(brackets.back() == e.unique_id); brackets.pop_back(); } +#endif if (!is_max) { if (label.ilabel != fst::kNoLabel && label.min != bound) { label.max = bound - 1; - assert(!closed_subset_.empty()); - assert(label.min <= label.max); + IRS_ASSERT(!closed_subset_.empty()); + IRS_ASSERT(label.min <= label.max); ProcessTransition(state, label.ilabel, &closed_subset_); } @@ -612,13 +625,13 @@ template class DeterminizerStar { } else { if (label.max != bound) { label.max = bound; - assert(!closed_subset_.empty()); - assert(label.min <= label.max); + IRS_ASSERT(!closed_subset_.empty()); + IRS_ASSERT(label.min <= label.max); ProcessTransition(state, label.ilabel, &closed_subset_); label.min = bound + 1; } - assert(!closed_subset_.empty()); + IRS_ASSERT(!closed_subset_.empty()); closed_subset_.pop_back(); if (closed_subset_.empty()) { label.ilabel = fst::kNoLabel; diff --git a/tests/analysis/multi_delimited_token_stream_tests.cpp b/tests/analysis/multi_delimited_token_stream_tests.cpp index 4c5b99247..15634720d 100644 --- a/tests/analysis/multi_delimited_token_stream_tests.cpp +++ b/tests/analysis/multi_delimited_token_stream_tests.cpp @@ -154,8 +154,8 @@ TEST_F(multi_delimited_token_stream_tests, no_delimiter) { TEST_F(multi_delimited_token_stream_tests, multi_words) { auto stream = irs::analysis::multi_delimited_token_stream::make( - {.delimiters = {"f"_b, "b"_b}}); - //{.delimiters = {"f"_b, "g"_b, "h"_b, "j"_b}}); + {.delimiters = {"foo"_b, "bar"_b, "bas"_b}}); + //{.delimiters = {"f"_b, "g"_b, "h"_b, "j"_b}}); ASSERT_EQ(irs::type::id(), stream->type());