16 #ifndef NLP_GRM2_SFST_BACKOFF_H_ 17 #define NLP_GRM2_SFST_BACKOFF_H_ 19 #include <sys/types.h> 26 #include <fst/float-weight.h> 28 #include <fst/matcher.h> 52 using Label =
typename Arc::Label;
// Matcher used for input-label lookups (constructed with MATCH_INPUT).
// SortedMatcher requires the FST to be sorted on the matched side, which is
// why the constructor checks the kILabelSorted property.
54 using Matr = fst::SortedMatcher<fst::Fst<Arc>>;
59 bool require_backoff_complete =
true);
79 return states_[s].backoff_position;
86 auto iter = arc_map_.find(std::make_pair(s, pos));
87 return iter != arc_map_.end() ? iter->second : -1;
94 bool Error()
const {
return error_; }
// For each non-phi arc leaving state s, looks up (via a SortedMatcher on
// s's backoff state) the arc with the same input label and records its
// position in arc_map_, keyed by (s, arc position). When
// require_backoff_complete_ is set, an input-epsilon arc or an arc with no
// backed-off match is reported via FSTERROR. Called from the constructor for
// states that have a phi (failure) arc.
98 void FindBackedOffArcs(
StateId s);
100 void SetError()
const { error_ =
true; }
103 struct BackoffState {
105 ssize_t backoff_position;
108 BackoffState(
StateId s, ssize_t p,
int o)
109 : backoff_state(s), backoff_position(p), order(o) {}
112 : backoff_state(fst::kNoStateId), backoff_position(-1), order(1) {}
// Key for arc_map_: (state id, position of an arc within that state).
116 using Pair = std::pair<StateId, ssize_t>;
119 size_t operator()(
const Pair &p)
const {
120 return (static_cast<size_t>(p.first) * 55697) ^
121 (
static_cast<size_t>(p.second) * 54631);
// Maps (state, arc position) to the position of the arc with the same input
// label at that state's backoff state. Presumably this is the type of
// arc_map_ (its declaration is not shown here). The accessor that looks up
// (s, pos) returns -1 when no entry exists.
125 using PairArcMap = std::unordered_map<Pair, ssize_t, PairHash>;
// FST whose backoff structure is analyzed. It is held by reference, so the
// caller's FST must outlive this object.
127 const fst::Fst<Arc> &fst_;
// Phi-topological order of the states, filled by PhiTopOrder() in the
// constructor.
130 std::vector<StateId> top_order_;
// Per-state backoff info (backoff state, phi-arc position, order), sized to
// top_order_.size().
131 std::vector<BackoffState> states_;
// When true, incomplete backoff (no states, non-failure input epsilons, or
// missing backed-off arcs) is reported via FSTERROR.
134 bool require_backoff_complete_;
// Mutable, so it can be updated from const methods (presumably a cached
// result for IsBackoffComplete(); TODO confirm, since usage is not visible
// here).
135 mutable bool is_backoff_;
144 bool require_backoff_complete)
146 phi_label_(phi_label),
148 require_backoff_complete_(require_backoff_complete),
153 if (fst_.Start() == f::kNoStateId) {
155 if (require_backoff_complete_) {
156 FSTERROR() <<
"Backoff: FST has no states";
162 if (!fst.Properties(f::kILabelSorted,
true)) {
163 FSTERROR() <<
"Backoff: FST is not canonical (ilabel-sorted)";
168 bool acyclic =
PhiTopOrder(fst_, phi_label_, &top_order_);
170 FSTERROR() <<
"Backoff: FST is not canonical (phi-cyclic)";
175 states_.resize(top_order_.size());
177 if (phi_label_ == f::kNoStateId)
return;
179 for (
StateId i = top_order_.size() - 1; i >= 0; --i) {
181 BackoffState &bo_state = states_[s];
182 Matr matcher(fst_, f::MATCH_INPUT);
184 if (matcher.Find(phi_label_)) {
185 for (; !matcher.Done(); matcher.Next()) {
186 const Arc &arc = matcher.Value();
188 if (arc.ilabel == f::kNoLabel)
continue;
190 if (bo_state.backoff_state != f::kNoStateId) {
191 FSTERROR() <<
"Backoff: FST is not canonical (phi-non-determinism)";
195 bo_state.backoff_state = arc.nextstate;
196 bo_state.backoff_position = matcher.Position();
197 bo_state.order = states_[bo_state.backoff_state].order + 1;
198 if (bo_state.order > order_) order_ = bo_state.order;
199 FindBackedOffArcs(s);
208 Matr matcher(fst_, f::MATCH_INPUT);
209 matcher.SetState(states_[s].backoff_state);
210 for (f::ArcIterator<f::Fst<Arc>> aiter(fst_, s); !aiter.Done();
212 const Arc &arc = aiter.Value();
213 if (arc.ilabel == phi_label_)
continue;
214 if (arc.ilabel == 0) {
216 if (require_backoff_complete_) {
217 FSTERROR() <<
"Backoff: non-failure epsilons disallowed";
223 if (matcher.Find(arc.ilabel)) {
224 Pair p(s, aiter.Position());
225 arc_map_[p] = matcher.Position();
228 if (require_backoff_complete_) {
229 FSTERROR() <<
"Backoff: no backed-off arc with label " << arc.ilabel
230 <<
" from state " << s;
241 typename Arc::Label phi_label) {
249 template <
class Arc,
class CompWeight = fst::Log64Weight>
252 using Weight =
typename Arc::Weight;
253 using StateId =
typename Arc::StateId;
254 fst::WeightConvert<CompWeight, Weight> from_compute_weight;
255 fst::WeightConvert<Weight, CompWeight> to_compute_weight;
258 for (
StateId i = 0; i < fst->NumStates(); ++i) {
262 if (bos == f::kNoStateId)
continue;
263 Weight final = fst->Final(s);
264 if (
final != Weight::Zero()) {
265 Weight bo_final = fst->Final(bos);
266 CompWeight w1 = to_compute_weight(bo_final);
267 CompWeight w2 = to_compute_weight(
final);
268 fst->SetFinal(bos, from_compute_weight(Plus(w1, w2)));
270 f::MutableArcIterator<f::MutableFst<Arc>> miter(fst, bos);
271 for (f::ArcIterator<f::Fst<Arc>> aiter(*fst, s); !aiter.Done();
273 const Arc &arc = aiter.Value();
274 ssize_t pos = aiter.Position();
275 if (arc.ilabel == phi_label)
continue;
277 CHECK_NE(bo_pos, -1);
279 Arc bo_arc = miter.Value();
280 CompWeight w1 = to_compute_weight(bo_arc.weight);
281 CompWeight w2 = to_compute_weight(arc.weight);
282 bo_arc.weight = from_compute_weight(Plus(w1, w2));
283 miter.SetValue(bo_arc);
287 return !backoff.
Error();
298 typename Arc::Weight bo_zero = Arc::Weight::Zero()) {
300 using Weight =
typename Arc::Weight;
301 using StateId =
typename Arc::StateId;
302 f::WeightConvert<f::Log64Weight, Weight> from_log;
303 f::WeightConvert<Weight, f::Log64Weight> to_log;
306 for (
StateId i = fst->NumStates() - 1; i >= 0; --i) {
310 if (bos == f::kNoStateId)
continue;
311 Weight final = fst->Final(s);
312 if (
final != Weight::Zero()) {
313 Weight bo_final = fst->Final(bos);
314 f::Log64Weight w1 = to_log(bo_final);
315 f::Log64Weight w2 = to_log(
final);
316 fst->SetFinal(bos,
Less(w2, w1) ? from_log(
Minus(w1, w2)) : bo_zero);
318 f::MutableArcIterator<f::MutableFst<Arc>> miter(fst, bos);
319 for (f::ArcIterator<f::Fst<Arc>> aiter(*fst, s); !aiter.Done();
321 const Arc &arc = aiter.Value();
322 ssize_t pos = aiter.Position();
323 if (arc.ilabel == phi_label)
continue;
325 CHECK_NE(bo_pos, -1);
327 Arc bo_arc = miter.Value();
328 f::Log64Weight w1 = to_log(bo_arc.weight);
329 f::Log64Weight w2 = to_log(arc.weight);
330 bo_arc.weight =
Less(w2, w1) ? from_log(
Minus(w1, w2)) : bo_zero;
331 miter.SetValue(bo_arc);
335 return !backoff.
Error();
340 #endif // NLP_GRM2_SFST_BACKOFF_H_ bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
typename Arc::Label Label
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Backoff(const fst::Fst< Arc > &fst, Label phi_label, bool require_backoff_complete=true)
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
bool SumBackoff(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label)
typename Arc::Weight Weight
StateId GetBackoffState(StateId s) const
bool IsBackoffComplete() const
StateId GetBackoffPosition(StateId s) const
typename Arc::StateId StateId
ssize_t GetBackedOffArc(StateId s, ssize_t pos) const
StateId GetPhiTopOrder(StateId i) const
int Order(StateId s) const
bool DiffBackoff(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label, typename Arc::Weight bo_zero=Arc::Weight::Zero())
fst::SortedMatcher< fst::Fst< Arc >> Matr