23 #include <fst/float-weight.h> 25 #include <fst/signed-log-weight.h> 38 const fst::Fst<fst::SignedLog64Arc> &
fst,
40 std::vector<fst::SignedLog64Weight> *prev_weight,
41 std::vector<fst::SignedLog64Weight> *weight,
42 fst::SignedLog64Weight alpha) {
44 using SLArc = f::SignedLog64Arc;
45 using SLWeight = SLArc::Weight;
47 SLWeight prevw = (*prev_weight)[st];
48 for (f::ArcIterator<f::Fst<SLArc>> aiter(fst, st);
51 const SLArc &arc = aiter.Value();
52 SLWeight nextw =
Times(prevw, arc.weight);
53 if (arc.ilabel != 0) {
54 (*weight)[arc.nextstate] =
55 Plus((*weight)[arc.nextstate], nextw);
57 (*prev_weight)[arc.nextstate] =
58 Plus((*prev_weight)[arc.nextstate], nextw);
62 if (fst.Final(st) != SLWeight::Zero()) {
63 (*weight)[fst.Start()] = Plus((*weight)[fst.Start()],
69 const fst::Fst<fst::SignedLog64Arc> &
fst,
70 std::vector<fst::SignedLog64Weight> *weight,
71 fst::SignedLog64Weight alpha ,
75 using SLArc = f::SignedLog64Arc;
76 using StateId = SLArc::StateId;
77 using SLWeight = SLArc::Weight;
78 using LWeight = f::Log64Weight;
80 std::vector<StateId> top_order;
82 LOG(ERROR) <<
"SignedStationaryDistrib: " 83 <<
"FST has (input) epsilon cycles";
87 size_t nstates = top_order.size();
88 LWeight ldelta(-std::log(delta));
90 std::vector<SLWeight> prev_weight, tmp_weight;
92 prev_weight.resize(nstates,
93 SLWeight(1.0, std::log(static_cast<float>(nstates))));
99 weight->resize(nstates, SLWeight::Zero());
100 tmp_weight = prev_weight;
101 for (
size_t i = 0; i < nstates; ++i) {
102 StateId st = top_order[i];
107 for (
size_t st = 0; st < nstates; ++st) {
108 LWeight dw =
Minus((*weight)[st], prev_weight[st]).Value2();
109 LWeight dp =
Times(ldelta, prev_weight[st].Value2());
110 if (
Less(dp, dw)) ++changed;
111 prev_weight[st] = (*weight)[st];
116 if (niter > maxiters ||
117 (niter > nstates &&
ApproxZero((*weight)[fst.Start()])))
120 VLOG(2) <<
"SignedStationaryDistrib: state weights changed: " 122 }
while (changed > 0);
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
void SignedStationaryDistribState(const fst::Fst< fst::SignedLog64Arc > &fst, int st, std::vector< fst::SignedLog64Weight > *prev_weight, std::vector< fst::SignedLog64Weight > *weight, fst::SignedLog64Weight alpha)
bool SignedStationaryDistrib(const fst::Fst< fst::SignedLog64Arc > &fst, std::vector< fst::SignedLog64Weight > *weight, fst::SignedLog64Weight alpha=internal::kReEntryWeight, float delta=fst::kDelta, size_t maxiters=internal::kMaxSDIters)
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)