16 #ifndef NLP_GRM2_SFST_COUNT_H_ 17 #define NLP_GRM2_SFST_COUNT_H_ 19 #include <sys/types.h> 26 #include <fst/arcsort.h> 27 #include <fst/compose.h> 28 #include <fst/expanded-fst.h> 29 #include <fst/float-weight.h> 31 #include <fst/matcher.h> 32 #include <fst/mutable-fst.h> 33 #include <fst/properties.h> 34 #include <fst/signed-log-weight.h> 36 #include <fst/weight.h> 53 using Label =
typename Arc::Label;
55 using SLArc = fst::SignedLog64Arc;
67 Counter(
Label phi_label,
float delta, fst::MutableFst<Arc> *ofst);
75 void Count(
const fst::Fst<Arc> &ifst);
88 StateId AddSuperFinal(fst::MutableFst<Arc> *
fst,
bool *have_phi);
91 void BuildComp(
const fst::Fst<Arc> &ifst,
92 fst::MutableFst<SLArc> *ofst)
const;
96 bool CheckNorm(
const fst::Fst<SLArc> &fst)
const {
98 using StateItr = f::StateIterator<f::Fst<SLArc>>;
99 using ArcItr = f::ArcIterator<f::Fst<SLArc>>;
100 for (StateItr siter(fst); !siter.Done(); siter.Next()) {
102 f::Adder<SLWeight> adder(fst.Final(s));
103 for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
104 const SLArc &arc = aiter.Value();
105 adder.Add(arc.weight);
107 if (!ApproxEqual(adder.Sum(), SLWeight::One()))
115 bool ComputeDistances(
bool norm, fst::MutableFst<SLArc> *cfst,
116 std::vector<SLWeight> *distance,
117 std::vector<SLWeight> *rdistance)
const;
120 void CountArcs(
const fst::ExpandedFst<SLArc> &fst,
121 const std::vector<SLWeight> &distance,
122 const std::vector<SLWeight> &rdistance);
132 fst::MutableFst<Arc> *ofst_;
134 std::vector<SLWeight> arc_counts_;
135 std::vector<SLWeight> phi_counts_;
139 fst::WeightConvert<SLWeight, Weight> from_sl_;
147 fst::MutableFst<Arc> *ofst)
148 : phi_label_(phi_label), delta_(delta), ofst_(ofst), sf_label_(-2) {
151 if (ofst_->Start() == f::kNoStateId) {
152 FSTERROR() <<
"Counter: topology FST has no states";
153 ofst_->SetProperties(f::kError, f::kError);
158 const uint64_t props = ofst_->Properties(
159 f::kAcceptor | f::kIDeterministic | f::kILabelSorted | f::kNoEpsilons,
161 if ((props & f::kAcceptor) != f::kAcceptor) {
162 FSTERROR() <<
"Counter: topology FST not an acceptor";
163 ofst_->SetProperties(f::kError, f::kError);
166 if ((props & f::kIDeterministic) != f::kIDeterministic) {
167 FSTERROR() <<
"Counter: topology FST not i-deterministic";
168 ofst_->SetProperties(f::kError, f::kError);
171 if ((props & f::kILabelSorted) != f::kILabelSorted) {
172 FSTERROR() <<
"Counter: topology FST not i-label sorted";
173 ofst_->SetProperties(f::kError, f::kError);
176 if (phi_label != 0 && (props & f::kNoEpsilons) != f::kNoEpsilons) {
177 FSTERROR() <<
"Counter: topology FST not epsilon-free";
178 ofst_->SetProperties(f::kError, f::kError);
188 superfinal_ = AddSuperFinal(ofst_, &have_ophi_);
190 for (
StateId s = 0; s < ofst_->NumStates(); ++s) {
191 for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(ofst_, s);
192 !aiter.Done(); aiter.Next()) {
193 Arc arc = aiter.Value();
194 arc.weight = Weight::One();
195 if (arc.ilabel != phi_label_) {
196 ssize_t arc_id = arc_counts_.size();
200 if (arc.olabel != arc_id) {
201 FSTERROR() <<
"Counter: arc label size needs to be large" 202 <<
" enough to store arc IDs in construction";
203 ofst_->SetProperties(f::kError, f::kError);
206 arc_counts_.push_back(SLWeight::Zero());
215 fst::MutableFst<Arc> *
fst,
bool *have_phi) {
221 StateId sf_state = fst->AddState();
222 fst->SetFinal(sf_state, Weight::One());
223 for (
StateId s = 0; s < fst->NumStates() - 1; ++s) {
224 for (f::ArcIterator<f::Fst<Arc>> aiter(*fst, s); !aiter.Done();
226 const Arc &arc = aiter.Value();
227 if (arc.ilabel == sf_label_) {
228 FSTERROR() <<
"Counter: label " << sf_label_
229 <<
" reserved for the superfinal label";
230 ofst_->SetProperties(f::kError, f::kError);
231 return f::kNoStateId;
233 if (arc.ilabel == phi_label_) *have_phi =
true;
236 if (fst->Final(s) != Weight::Zero()) {
237 fst->AddArc(s, Arc(sf_label_, sf_label_, fst->Final(s), sf_state));
238 fst->SetFinal(s, Weight::Zero());
241 f::ArcSort(fst, f::ILabelCompare<Arc>());
248 if (ifst.Start() == f::kNoStateId) {
249 FSTERROR() <<
"Counter: input FST has no states";
250 ofst_->SetProperties(f::kError, f::kError);
253 if (!ifst.Properties(f::kAcceptor,
true)) {
254 FSTERROR() <<
"Counter: input FST not an acceptor";
255 ofst_->SetProperties(f::kError, f::kError);
259 f::VectorFst<SLArc> cfst;
262 f::VectorFst<Arc> sffst(ifst);
263 AddSuperFinal(&sffst, &have_iphi_);
267 BuildComp(sffst, &cfst);
271 bool cnorm = CheckNorm(cfst);
273 std::vector<SLWeight> distance, rdistance;
274 if (!ComputeDistances(cnorm, &cfst, &distance, &rdistance)) {
275 FSTERROR() <<
"Counter: shortest-distance computation failed.";
276 ofst_->SetProperties(f::kError, f::kError);
281 CountArcs(cfst, distance, rdistance);
286 fst::MutableFst<SLArc> *ofst)
const {
290 f::ComposeFstOptions<Arc, PM, PF> copts;
292 if (have_iphi_ || have_ophi_) {
293 Label iphi_label = have_iphi_ ? phi_label_ : f::kNoLabel;
294 Label ophi_label = have_ophi_ ? phi_label_ : f::kNoLabel;
295 f::MatchType imatch_type = have_iphi_ ? f::MATCH_OUTPUT : f::MATCH_NONE;
296 f::MatchType omatch_type = have_ophi_ ? f::MATCH_INPUT : f::MATCH_NONE;
297 copts.matcher1 =
new PM(ifst, imatch_type, iphi_label);
298 copts.matcher2 =
new PM(*ofst_, omatch_type, ophi_label);
301 f::ComposeFst<Arc> cfst(ifst, *ofst_, copts);
303 const Label rm_phi_label =
304 have_iphi_ && have_ophi_ ? phi_label_ : f::kNoLabel;
310 std::vector<SLWeight> *distance,
311 std::vector<SLWeight> *rdistance)
const {
319 sdist(cfst, phi_label_, delta_);
323 rdistance->resize(distance->size(), SLWeight::One());
328 sdist(cfst, phi_label_, delta_);
336 const std::vector<SLWeight> &distance,
337 const std::vector<SLWeight> &rdistance) {
339 for (
StateId s = 0; s < fst.NumStates(); ++s) {
340 if (s >= distance.size())
continue;
343 for (f::ArcIterator<f::Fst<SLArc>> aiter(fst, s); !aiter.Done();
345 const auto &arc = aiter.Value();
346 if (arc.ilabel && arc.ilabel != phi_label_) {
348 if (arc.nextstate >= rdistance.size())
continue;
349 SLWeight rdist = rdistance[arc.nextstate];
351 count =
Times(count, rdist);
353 ssize_t arc_id = arc.olabel;
354 arc_counts_[arc_id] = Plus(arc_counts_[arc_id], count);
363 using Matr = f::ExplicitMatcher<f::Matcher<f::Fst<Arc>>>;
365 StateId initial = ofst_->Start();
366 phi_counts_.resize(ofst_->NumStates(), SLWeight::Zero());
368 if (phi_label_ == f::kNoLabel)
return;
370 for (
StateId s = 0; s < ofst_->NumStates(); ++s) {
371 for (f::ArcIterator<f::Fst<Arc>> aiter(*ofst_, s); !aiter.Done();
373 const Arc &arc = aiter.Value();
374 if (arc.ilabel == phi_label_)
continue;
375 ssize_t arc_id = arc.olabel;
376 phi_counts_[arc.nextstate] =
377 Plus(phi_counts_[arc.nextstate], arc_counts_[arc_id]);
378 phi_counts_[s] =
Minus(phi_counts_[s], arc_counts_[arc_id]);
384 phi_counts_[initial] = Plus(phi_counts_[initial], phi_counts_[superfinal_]);
386 std::vector<StateId> top_order;
387 bool acyclic =
PhiTopOrder(*ofst_, phi_label_, &top_order);
389 FSTERROR() <<
"Counter: topology FST is not canonical (phi-cyclic)";
390 ofst_->SetProperties(f::kError, f::kError);
394 Matr matcher(*ofst_, f::MATCH_INPUT);
395 for (
StateId i = 0; i < ofst_->NumStates(); ++i) {
398 if (matcher.Find(phi_label_)) {
399 const Arc &arc = matcher.Value();
401 phi_counts_[fails] = Plus(phi_counts_[fails], phi_counts_[s]);
410 for (
StateId s = 0; s < ofst_->NumStates(); ++s) {
411 for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(ofst_, s);
412 !aiter.Done(); aiter.Next()) {
413 Arc arc = aiter.Value();
415 if (arc.ilabel != phi_label_) {
416 arc_count = arc_counts_[arc.olabel];
418 arc_count = phi_counts_[s];
421 if (
Less(arc_count, SLWeight::Zero())) {
422 arc.weight = Weight::Zero();
424 arc.weight = from_sl_(arc_count);
427 if (arc.ilabel != sf_label_) {
428 arc.olabel = arc.ilabel;
431 ofst_->SetFinal(s, arc.weight);
437 std::vector<StateId> dstates(1, superfinal_);
438 ofst_->DeleteStates(dstates);
439 f::ArcSort(ofst_, f::ILabelCompare<Arc>());
448 float delta = fst::kDelta) {
450 using StateId =
typename Arc::StateId;
451 using Weight =
typename Arc::Weight;
452 using ArcItr = f::ArcIterator<f::Fst<Arc>>;
453 using Log64Weight = f::Log64Weight;
455 StateId nstates = CountStates(fst);
457 f::WeightConvert<Weight, Log64Weight> to_log64;
458 std::vector<Log64Weight> in_weight(nstates, Log64Weight::Zero());
459 std::vector<Log64Weight> out_weight(nstates, Log64Weight::Zero());
461 for (
StateId s = 0; s < nstates; ++s) {
462 for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
463 const Arc &arc = aiter.Value();
464 const Log64Weight weight = to_log64(arc.weight);
465 in_weight[arc.nextstate] = Plus(in_weight[arc.nextstate], weight);
466 out_weight[s] = Plus(out_weight[s], weight);
468 if (fst.Final(s) != Weight::Zero()) {
469 const Log64Weight weight = to_log64(fst.Final(s));
470 out_weight[s] = Plus(out_weight[s], weight);
471 in_weight[fst.Start()] = Plus(in_weight[fst.Start()], weight);
474 for (
StateId s = 0; s < nstates; ++s) {
475 if (!ApproxEqual(in_weight[s], out_weight[s], delta)) {
476 VLOG(1) <<
"state: " << s
477 <<
" in_weight: " << in_weight[s]
478 <<
" out_weight: " << out_weight[s];
488 #endif // NLP_GRM2_SFST_COUNT_H_ Counter(Label phi_label, float delta, fst::MutableFst< Arc > *ofst)
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
typename Arc::StateId StateId
typename Arc::Weight Weight
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
bool IsConservative(const fst::Fst< Arc > &fst, float delta=fst::kDelta)
typename Arc::Label Label
fst::SignedLog64Weight SLWeight
void RmPhi(const fst::Fst< IArc > &ifst, fst::MutableFst< OArc > *ofst, typename IArc::Label phi_label=fst::kNoLabel, fst::MatcherRewriteMode rewrite_mode=fst::MATCHER_REWRITE_AUTO, const WC &weight_convert=WC())
fst::SignedLog64Arc::StateId SLStateId
bool ComputeDistance(std::vector< Weight > *distance, bool reverse=false)
fst::SignedLog64Arc SLArc
void Count(const fst::Fst< Arc > &ifst)
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)