16 #ifndef NLP_GRM2_SFST_TOPOLOGY_H_ 17 #define NLP_GRM2_SFST_TOPOLOGY_H_ 19 #include <sys/types.h> 26 #include <fst/arcsort.h> 27 #include <fst/float-weight.h> 29 #include <fst/mutable-fst.h> 30 #include <fst/weight.h> 41 using Label =
typename Arc::Label;
58 NGramState(
StateId s,
int o) : backoff_state(s), order(o) { }
61 using Pair = std::pair<ssize_t, ssize_t>;
64 size_t operator()(
const Pair &p)
const {
65 return (static_cast<size_t>(p.first) * 55697) ^
66 (
static_cast<size_t>(p.second) * 54631);
70 using DestMap = std::unordered_map<Pair, ssize_t, PairHash>;
78 void UpdateFinalNGram(
StateId state_id) {
80 while (state_id != f::kNoStateId &&
81 ofst_->Final(state_id) == Weight::Zero()) {
82 ofst_->SetFinal(state_id, Weight::One());
83 state_id = states_[state_id].backoff_state;
89 fst::MutableFst<Arc> *ofst_;
90 std::vector<NGramState> states_;
93 fst::WeightConvert<fst::Log64Weight, Weight> from_log_;
94 fst::WeightConvert<Weight, fst::Log64Weight> to_log_;
102 int order,
Label phi_label, fst::MutableFst<Arc> *ofst)
104 phi_label_(phi_label),
107 ofst->DeleteStates();
110 FSTERROR() <<
"NGramTopology: order must be greater than 0";
111 ofst->SetProperties(f::kError, f::kError);
114 if (phi_label_ == f::kNoLabel) {
115 FSTERROR() <<
"NGramTopology: bad phi label: " << phi_label_;
116 ofst->SetProperties(f::kError, f::kError);
120 backoff_ = ofst->AddState();
121 states_.push_back(NGramState(f::kNoStateId, 1));
125 initial_ = ofst->AddState();
126 states_.push_back(NGramState(backoff_, 2));
128 ofst->SetStart(initial_);
135 if (ifst.Start() == f::kNoStateId) {
136 FSTERROR() <<
"NGramTopology: input FST has no states";
137 ofst_->SetProperties(f::kError, f::kError);
140 if (!ifst.Properties(f::kAcceptor,
true)) {
141 FSTERROR() <<
"NGramTopology: input FST not an acceptor";
142 ofst_->SetProperties(f::kError, f::kError);
147 std::vector<Pair> queue;
148 std::unordered_set<Pair, PairHash> pairset;
150 Pair start_pair(ifst.Start(), initial_);
151 pairset.insert(start_pair);
152 queue.push_back(start_pair);
153 bool non_trivial_label =
false;
154 while (!queue.empty()) {
155 Pair current_pair = queue.back();
156 auto fst_state = current_pair.first;
157 StateId ngram_state = current_pair.second;
159 for (f::ArcIterator<f::Fst<Arc>> aiter(ifst, fst_state);
160 !aiter.Done(); aiter.Next()) {
161 const auto &arc = aiter.Value();
163 Pair next_pair(arc.nextstate, backoff_);
164 if (arc.ilabel && arc.ilabel != phi_label_) {
166 next_pair.second = UpdateNGram(ngram_state, arc.ilabel, &dest_map);
167 non_trivial_label =
true;
168 }
else if (phi_label_ != 0 && arc.ilabel == 0) {
170 next_pair.second = ngram_state;
172 auto iter = pairset.find(next_pair);
173 if (iter == pairset.end()) {
174 pairset.insert(next_pair);
175 queue.push_back(next_pair);
178 if (ifst.Final(fst_state) != Weight::Zero())
179 UpdateFinalNGram(ngram_state);
184 if (!non_trivial_label) {
185 FSTERROR() <<
"NGramTopology: input FST has no non-trivial path";
186 ofst_->SetProperties(f::kError, f::kError);
191 for (
StateId s = 0; s < states_.size(); ++s) {
192 if (states_[s].backoff_state != f::kNoStateId) {
193 ofst_->AddArc(s, Arc(phi_label_, phi_label_, Weight::One(),
194 states_[s].backoff_state));
198 f::ArcSort(ofst_, f::ILabelCompare<Arc>());
199 ofst_->SetInputSymbols(ifst.InputSymbols());
200 ofst_->SetOutputSymbols(ifst.OutputSymbols());
209 Pair p(state_id, label);
210 auto iter = dest_map->find(p);
211 if (iter != dest_map->end())
215 NGramState ngram_state = states_[state_id];
217 Arc arc(label, label, Weight::One(), initial_);
221 StateId backoff_dest = ngram_state.backoff_state == f::kNoStateId ?
223 UpdateNGram(ngram_state.backoff_state, label, dest_map);
226 if (ngram_state.order == order_) {
228 arc.nextstate = backoff_dest;
231 arc.nextstate = ofst_->AddState();
232 NGramState next_ngram_state(
233 backoff_dest == f::kNoStateId ? backoff_ : backoff_dest,
234 ngram_state.order + 1);
235 states_.push_back(next_ngram_state);
238 ofst_->AddArc(state_id, arc);
239 (*dest_map)[p] = arc.nextstate;
240 return arc.nextstate;
245 #endif // NLP_GRM2_SFST_TOPOLOGY_H_
typename Arc::Label Label
void FindNGrams(const fst::Fst< Arc > &ifst)
NGramTopology(int order, Label phi_label, fst::MutableFst< Arc > *ofst)
typename Arc::Weight Weight
typename Arc::StateId StateId