Skip to content
This repository has been archived by the owner on May 3, 2024. It is now read-only.

Commit

Permalink
Fix determinize-star
Browse files Browse the repository at this point in the history
  • Loading branch information
MBkkt committed Aug 21, 2023
1 parent c87ad21 commit 7ca8bef
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 80 deletions.
40 changes: 18 additions & 22 deletions core/utils/fstext/fst_draw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

namespace fst {

template<typename Label>
template<typename Arc>
struct LabelToString {
std::string operator()(Label label) const { return std::to_string(label); }
std::string operator()(const Arc&, typename Arc::Label label,
std::string_view) const {
return std::to_string(label);
}
};

// Print a binary FST in GraphViz textual format (helper class for fstdraw.cc).
Expand Down Expand Up @@ -135,12 +138,12 @@ class FstDrawer {

void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); }

void PrintILabel(Label label) const {
PrintLabel(label, isyms_, "arc input label");
void PrintILabel(const Arc& arc) const {
PrintLabel(arc, arc.ilabel, isyms_, "arc input label");
}

void PrintOLabel(Label label) const {
PrintLabel(label, osyms_, "arc output label");
void PrintOLabel(const Arc& arc) const {
PrintLabel(arc, arc.olabel, osyms_, "arc output label");
}

void PrintWeight(Weight w) const {
Expand All @@ -153,22 +156,15 @@ class FstDrawer {
*ostrm_ << t;
}

void PrintLabel(int32_t id, const SymbolTable* syms, const char* name) const {
template<class T>
void PrintLabel(const Arc& arc, T label, const SymbolTable* syms,
std::string_view name) const {
if (syms) {
auto symbol = syms->Find(id);
if (!symbol.empty()) {
PrintString(Escape(symbol));
} else {
PrintString(label_to_string_(id));
if (auto symbol = syms->Find(label); !symbol.empty()) {
return PrintString(Escape(symbol));
}
} else {
PrintString(label_to_string_(id));
}
}

template<class T>
void PrintLabel(const T& label, const SymbolTable*, const char*) const {
*ostrm_ << label_to_string_(label);
PrintString(label_to_string_(arc, label, name));
}

template<class T>
Expand Down Expand Up @@ -208,10 +204,10 @@ class FstDrawer {
PrintString(" -> ");
Print(arc.nextstate);
PrintString(" [label = \"");
PrintILabel(arc.ilabel);
PrintILabel(arc);
if (!accep_) {
PrintString(":");
PrintOLabel(arc.olabel);
PrintOLabel(arc);
}
if (show_weight_one_ || (arc.weight != Weight::One())) {
PrintString("/");
Expand Down Expand Up @@ -268,7 +264,7 @@ inline void drawFst(
}

template<typename Fst,
typename LabelToString = fst::LabelToString<typename Fst::Arc::Label>>
typename LabelToString = fst::LabelToString<typename Fst::Arc>>
inline bool drawFst(const Fst& fst, const std::string& dest,
const LabelToString& label_to_string = {},
const SymbolTable* isyms = nullptr,
Expand Down
137 changes: 80 additions & 57 deletions external/kaldi/src/fstext/determinize-star-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ template<class F> class DeterminizerStar {
assert(cur_id == 0 && "Do not call Determinize twice.");
}
std::vector<Element> closed_subset;
std::vector<std::pair<std::pair<Label, bool>, Element>> all_elems;
std::vector<RangeElement> all_elems;
while (!Q_.empty()) {
std::pair<std::vector<Element>*, OutputStateId> cur_pair = Q_.front();
Q_.pop_front();
Expand Down Expand Up @@ -269,6 +269,15 @@ template<class F> class DeterminizerStar {
}
};

// One endpoint of an arc's input-label range, paired with the destination
// Element reached through that arc. ProcessTransitions() pushes two of these
// per non-epsilon arc: an "opening" entry {arc.min, arc.max, false, element}
// and a "closing" entry {arc.max, arc.max, true, element}, then sorts and
// sweeps over them.
struct RangeElement {
// Bound this entry is keyed on (primary sort key). Despite the name, for a
// closing entry (is_max == true) this holds arc.max, not arc.min.
Label min;
// Upper bound of the originating arc's range (arc.max for both entries).
// Used to order opening entries that share the same `min`: larger `max` first.
Label max;
// TODO(MBkkt) unify max and is_max: is_max == true <=> max == NoLabel?
// false for the opening entry of a range, true for the closing entry.
bool is_max;
// TODO(MBkkt) store element in separate array, to avoid unnecessary copies
// Destination subset element (state, weight, output string id) for the arc.
Element element;
};

// Arcs in the format we temporarily create in this class (a representation, essentially of
// a Gallic Fst).
struct TempArc {
Expand Down Expand Up @@ -507,87 +516,101 @@ template<class F> class DeterminizerStar {
// and output_arcs_.
void ProcessTransitions(
std::vector<Element>& closed_subset,
std::vector<std::pair<std::pair<Label, bool>, Element>>& all_elems,
std::vector<RangeElement>& all_elems,
OutputStateId state) {
all_elems.clear();

{ // Push back into "all_elems", elements corresponding to all non-epsilon-input transitions
{
std::vector<Label> seq;
// Push back into "all_elems", elements corresponding to all non-epsilon-input transitions
// out of all states in "closed_subset".
for (const Element& elem : closed_subset) {
for (ArcIterator<Fst<Arc> > aiter(*ifst_, elem.state); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon transition -- ignore epsilons here.
std::pair<std::pair<Label, bool>, Element> this_pr;
Element &next_elem(this_pr.second);
next_elem.state = arc.nextstate;
next_elem.weight = Times(elem.weight, arc.weight);
if (arc.olabel == 0) // output epsilon-- this is simple case so
// handle separately for efficiency
next_elem.string = elem.string;
else {
std::vector<Label> seq;
Element element{
.state = arc.nextstate,
.weight = Times(elem.weight, arc.weight),
};
if (arc.olabel == 0) {
// output epsilon-- this is simple case so handle separately for efficiency
element.string = elem.string;
} else {
seq.clear();
repository_.SeqOfId(elem.string, &seq);
seq.push_back(arc.olabel);
next_elem.string = repository_.IdOfSeq(seq);
element.string = repository_.IdOfSeq(seq);
}

this_pr.first = { arc.min, false };
all_elems.emplace_back(this_pr);
this_pr.first = { arc.max, true };
all_elems.emplace_back(this_pr);
all_elems.emplace_back(RangeElement{arc.min, arc.max, false, element});
all_elems.emplace_back(RangeElement{arc.max, arc.max, true, std::move(element)});
}
}
}
}
// now sorted first on input label bound, bound type, then on state.
if (!epsilon_closure_.FstSorted() || closed_subset.size() > 1) {
std::sort(all_elems.begin(), all_elems.end(), [](const auto &p1, const auto &p2) noexcept {
if (p1.first.first < p2.first.first) {
std::sort(all_elems.begin(), all_elems.end(), [](const auto& lhs, const auto& rhs) {
if (lhs.min < rhs.min) {
return true;
} else if (p1.first.first > p2.first.first) {
} else if (lhs.min > rhs.min) {
return false;
} else if (p1.first.second < p2.first.second) {
}
if (!lhs.is_max && rhs.is_max) {
return true;
} else if (p1.first.second > p2.first.second) {
} else if (lhs.is_max && !rhs.is_max) {
return false;
} else {
return p1.second.state < p2.second.state;
}
if (lhs.is_max) {
// same max elements just sorted by state
IRS_ASSERT(lhs.element.state != rhs.element.state && &lhs != &rhs);
return lhs.element.state < rhs.element.state;
}
// same min elements sorted opposite to their max elements
if (lhs.max < rhs.max) {
return false;
} else if (lhs.max > rhs.max) {
return true;
}
IRS_ASSERT(lhs.element.state != rhs.element.state && &lhs != &rhs);
return lhs.element.state > rhs.element.state;
});
}

// reuse memory as we don't need data anymore
std::vector<Element>& subset = closed_subset;
subset.clear();
fsa::RangeLabel label;

for (auto& e : all_elems) {
const auto [bound, is_max] = e.first;

if (!is_max) {
if (label.ilabel != fst::kNoLabel && label.min != bound) {
label.max = bound - 1;
assert(!subset.empty() && label.min <= label.max);
ProcessTransition(state, label.ilabel, &subset);
}

subset.emplace_back(e.second);
label.min = bound;
} else {
if (label.max != bound) {
label.max = bound;
assert(!subset.empty() && label.min <= label.max);
ProcessTransition(state, label.ilabel, &subset);
label.min = bound + 1;
}

assert(!subset.empty());
subset.pop_back();
if (subset.empty()) {
label.ilabel = fst::kNoLabel;
}
}
}
// reuse memory as we don't need data anymore
std::vector<Element>& subset = closed_subset;
subset.clear();
fsa::RangeLabel label;

for (auto& e : all_elems) {
const auto bound = e.min;
const bool is_max = e.is_max;

if (!is_max) {
if (label.ilabel != fst::kNoLabel && label.min != bound) {
label.max = bound - 1;
assert(!subset.empty() && label.min <= label.max);
ProcessTransition(state, label.ilabel, &subset);
}

// TODO(MBkkt) move?
subset.emplace_back(e.element);
label.min = bound;
} else {
if (label.max != bound) {
label.max = bound;
assert(!subset.empty() && label.min <= label.max);
ProcessTransition(state, label.ilabel, &subset);
label.min = bound + 1;
}

assert(!subset.empty());
subset.pop_back();
if (subset.empty()) {
label.ilabel = fst::kNoLabel;
}
}
}

assert(subset.empty());
}
Expand Down Expand Up @@ -630,7 +653,7 @@ template<class F> class DeterminizerStar {
// to the queue).
void ProcessSubset(
const std::pair<std::vector<Element>*, OutputStateId>& pair,
std::vector<std::pair<std::pair<Label, bool>, Element>>* all_elems,
std::vector<RangeElement>* all_elems,
std::vector<Element>* closed_subset) { // subset after epsilon closure.
OutputStateId state = pair.second;

Expand Down
64 changes: 63 additions & 1 deletion tests/utils/wildcard_utils_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,72 @@ class wildcard_utils_test : public test_base {
fst::kILabelSorted | fst::kOLabelSorted | fst::kIDeterministic |
fst::kAcceptor | fst::kUnweighted;

ASSERT_EQ(EXPECTED_PROPERTIES, a.Properties(EXPECTED_PROPERTIES, true));
EXPECT_EQ(EXPECTED_PROPERTIES, a.Properties(EXPECTED_PROPERTIES, true));
}
};

// Checks "%<literal>%" wildcard automata built from Cyrillic literals: the
// literal itself must be accepted, while inputs that differ in one character
// must be rejected.
// NOTE(review): the name presumably refers to UTF-8 label ranges that share
// the same lower bound in determinize-star -- confirm against the fix.
TEST_F(wildcard_utils_test, same_start) {
  // Feeds the raw bytes of `text` to `automaton` and reports acceptance.
  const auto accepts = [](const auto& automaton, std::string_view text) {
    return irs::accept<irs::byte_type>(
      automaton, irs::ViewCast<irs::byte_type>(text));
  };

  // Single-character literal.
  {
    auto single = irs::from_wildcard("%р%");
    assert_properties(single);

    EXPECT_TRUE(accepts(single, "р"));
    EXPECT_FALSE(accepts(single, "с"));
    EXPECT_FALSE(accepts(single, "ё"));
  }

  // Three-character literal; the negative cases differ only in the middle.
  {
    auto triple = irs::from_wildcard("%ара%");
    assert_properties(triple);

    EXPECT_TRUE(accepts(triple, "ара"));
    EXPECT_FALSE(accepts(triple, "аса"));
    EXPECT_FALSE(accepts(triple, "аёа"));
  }
}

// Companion to same_start: the same accept/reject checks for "%<literal>%"
// automata, this time built around the character "ѿ".
// NOTE(review): the name presumably refers to UTF-8 label ranges that share
// the same upper bound in determinize-star -- confirm against the fix.
TEST_F(wildcard_utils_test, same_end) {
  // Feeds the raw bytes of `text` to `automaton` and reports acceptance.
  const auto accepts = [](const auto& automaton, std::string_view text) {
    return irs::accept<irs::byte_type>(
      automaton, irs::ViewCast<irs::byte_type>(text));
  };

  // Single-character literal.
  {
    auto single = irs::from_wildcard("%ѿ%");
    assert_properties(single);

    EXPECT_TRUE(accepts(single, "ѿ"));
    EXPECT_FALSE(accepts(single, "с"));
    EXPECT_FALSE(accepts(single, "ё"));
  }

  // Three-character literal; the negative cases differ only in the middle.
  {
    auto triple = irs::from_wildcard("%аѿа%");
    assert_properties(triple);

    EXPECT_TRUE(accepts(triple, "аѿа"));
    EXPECT_FALSE(accepts(triple, "аса"));
    EXPECT_FALSE(accepts(triple, "аёа"));
  }
}

TEST_F(wildcard_utils_test, match_wildcard) {
{
auto a = irs::from_wildcard("%rc%");
Expand Down

0 comments on commit 7ca8bef

Please sign in to comment.