19 #ifndef NGRAM_NGRAM_COMPLETE_H_ 20 #define NGRAM_NGRAM_COMPLETE_H_ 25 #include <fst/arcsort.h> 27 #include <fst/matcher.h> 28 #include <fst/mutable-fst.h> 37 const fst::Fst<Arc> &
fst,
int order,
typename Arc::Label backoff_label,
38 std::vector<std::vector<typename Arc::StateId>> *order_states,
39 std::vector<int> *state_orders,
40 std::vector<typename Arc::StateId> *backoff_states) {
41 if (order >= order_states->size())
return false;
42 for (
int i = 0; i < (*order_states)[
order].size(); ++i) {
43 auto s = (*order_states)[
order][i];
44 if ((*state_orders)[s] !=
order) {
45 NGRAMERROR() <<
"State " << s <<
" included in vector of states with " 46 <<
"order " << order <<
", but that is not the case";
49 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst, s); !aiter.Done();
51 const Arc &arc = aiter.Value();
52 if (arc.ilabel == backoff_label) {
53 (*backoff_states)[s] = arc.nextstate;
54 }
else if ((*state_orders)[arc.nextstate] <= 0) {
55 if (order_states->size() <= order + 1) order_states->resize(order + 2);
56 (*state_orders)[arc.nextstate] = order + 1;
57 (*order_states)[order + 1].push_back(arc.nextstate);
60 if (order > 1 && (*backoff_states)[s] == fst::kNoStateId) {
61 NGRAMERROR() <<
"No backoff state for higher order state " << s;
70 typename Arc::Label backoff_label = 0) {
71 typedef typename Arc::StateId StateId;
72 typedef typename Arc::Weight Weight;
73 typedef typename Arc::Label Label;
74 if (fst->NumStates() < 2)
return true;
75 fst::ArcIterator<fst::Fst<Arc>> aiter(*fst, fst->Start());
76 const Arc &arc = aiter.Value();
77 if (arc.ilabel != backoff_label) {
78 NGRAMERROR() <<
"First arc out of start state is not backoff arc";
81 StateId unigram_state = arc.nextstate;
82 std::vector<int> state_orders(fst->NumStates());
83 std::vector<StateId> backoff_states(fst->NumStates(), fst::kNoStateId);
84 std::vector<std::vector<StateId>> order_states(3);
85 order_states[1].push_back(unigram_state);
86 state_orders[unigram_state] = 1;
87 order_states[2].push_back(fst->Start());
88 state_orders[fst->Start()] = 2;
89 backoff_states[fst->Start()] = unigram_state;
91 while (st_order < order_states.size() && !order_states[st_order].empty()) {
93 &order_states, &state_orders,
99 for (StateId s = 0; s < fst->NumStates(); ++s) {
100 if (s != unigram_state && backoff_states[s] == fst::kNoStateId) {
101 NGRAMERROR() <<
"state with no backoff state: " << s;
104 if (state_orders[s] <= 0) {
105 NGRAMERROR() <<
"state not on ascending path: " << s;
110 std::vector<std::set<Label>> label_sets(fst->NumStates());
111 std::set<StateId> new_final_states;
113 std::unique_ptr<fst::Matcher<fst::Fst<Arc>>> matcher(
114 new fst::Matcher<fst::Fst<Arc>>(*fst, fst::MATCH_INPUT));
117 for (
int idx = 0; idx < order_states[
order].size(); ++idx) {
118 StateId s = order_states[
order][idx];
119 StateId bs = backoff_states[s];
120 if (state_orders[s] != state_orders[bs] + 1) {
121 NGRAMERROR() <<
"State " << s <<
" backs off more than one order";
124 matcher->SetState(bs);
125 for (fst::ArcIterator<fst::Fst<Arc>> aiter(*fst, s);
126 !aiter.Done(); aiter.Next()) {
127 const Arc &arc = aiter.Value();
128 if (arc.ilabel == backoff_label || matcher->Find(arc.ilabel))
continue;
129 label_sets[bs].insert(arc.ilabel);
132 for (
auto iter = label_sets[s].begin(); iter != label_sets[s].end();
134 if (matcher->Find(*iter))
continue;
135 label_sets[bs].insert(*iter);
140 new_final_states.count(s) != 0) &&
143 new_final_states.insert(bs);
149 for (
int idx = 0; idx < order_states[
order].size(); ++idx) {
150 StateId s = order_states[
order][idx];
151 if (label_sets[s].empty())
continue;
152 StateId bs = backoff_states[s];
153 std::unique_ptr<fst::Matcher<fst::Fst<Arc>>> updated_matcher(
154 new fst::Matcher<fst::Fst<Arc>>(*fst, fst::MATCH_INPUT));
156 updated_matcher->SetState(s);
158 updated_matcher->SetState(bs);
161 std::vector<Arc> arcs;
162 arcs.reserve(fst->NumArcs(s) + label_sets[s].size());
164 for (
auto iter = label_sets[s].begin(); iter != label_sets[s].end();
166 StateId nextstate = unigram_state;
167 if (bs >= 0 && updated_matcher->Find(*iter))
168 nextstate = updated_matcher->Value().nextstate;
178 for (fst::ArcIterator<fst::Fst<Arc>> aiter(*fst, s);
179 !aiter.Done(); aiter.Next())
180 arcs.push_back(aiter.Value());
183 std::sort(arcs.begin(), arcs.end(), fst::ILabelCompare<Arc>());
185 for (
size_t i = 0; i < arcs.size(); ++i) fst->AddArc(s, arcs[i]);
189 for (
auto iter = new_final_states.begin(); iter != new_final_states.end();
202 #endif // NGRAM_NGRAM_COMPLETE_H_
bool NGramComplete(fst::MutableFst< Arc > *fst, typename Arc::Label backoff_label=0)
bool AscendAndCollectStateInfo(const fst::Fst< Arc > &fst, int order, typename Arc::Label backoff_label, std::vector< std::vector< typename Arc::StateId >> *order_states, std::vector< int > *state_orders, std::vector< typename Arc::StateId > *backoff_states)