GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
stationary-distrib.cc
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Computes the stationary distribution of a stochastic FST
15 
17 
18 #include <cmath>
19 #include <vector>
20 
21 #include <fst/log.h>
22 #include <fst/arc.h>
23 #include <fst/float-weight.h>
24 #include <fst/fst.h>
25 #include <fst/signed-log-weight.h>
26 #include <sfst/canonical.h>
27 #include <sfst/sfst.h>
28 
29 namespace sfst {
30 
31 namespace internal {
32 
33 // At a given state, calculate one step of the power method
34 // for the stationary distribution of the closure of the
35 // input stochastic FST with re-entry weight 'alpha'.
36 // TODO(riley): handle epsilons - pushed?
38  const fst::Fst<fst::SignedLog64Arc> &fst,
39  int st,
40  std::vector<fst::SignedLog64Weight> *prev_weight,
41  std::vector<fst::SignedLog64Weight> *weight,
42  fst::SignedLog64Weight alpha) {
43  namespace f = fst;
44  using SLArc = f::SignedLog64Arc;
45  using SLWeight = SLArc::Weight;
46 
47  SLWeight prevw = (*prev_weight)[st];
48  for (f::ArcIterator<f::Fst<SLArc>> aiter(fst, st);
49  !aiter.Done();
50  aiter.Next()) {
51  const SLArc &arc = aiter.Value();
52  SLWeight nextw = Times(prevw, arc.weight);
53  if (arc.ilabel != 0) {
54  (*weight)[arc.nextstate] =
55  Plus((*weight)[arc.nextstate], nextw);
56  } else {
57  (*prev_weight)[arc.nextstate] =
58  Plus((*prev_weight)[arc.nextstate], nextw);
59  }
60  }
61 
62  if (fst.Final(st) != SLWeight::Zero()) {
63  (*weight)[fst.Start()] = Plus((*weight)[fst.Start()],
64  Times(prevw, Times(fst.Final(st), alpha)));
65  }
66 }
67 
69  const fst::Fst<fst::SignedLog64Arc> &fst,
70  std::vector<fst::SignedLog64Weight> *weight,
71  fst::SignedLog64Weight alpha /* kReEntryWeight */,
72  float delta /* = fst::kDelta */,
73  size_t maxiters /* kMaxSDIters */) {
74  namespace f = fst;
75  using SLArc = f::SignedLog64Arc;
76  using StateId = SLArc::StateId;
77  using SLWeight = SLArc::Weight;
78  using LWeight = f::Log64Weight;
79 
80  std::vector<StateId> top_order;
81  if (!PhiTopOrder(fst, f::kNoLabel, &top_order)) { // epsilon top order
82  LOG(ERROR) << "SignedStationaryDistrib: "
83  << "FST has (input) epsilon cycles";
84  return false;
85  }
86 
87  size_t nstates = top_order.size();
88  LWeight ldelta(-std::log(delta));
89 
90  std::vector<SLWeight> prev_weight, tmp_weight;
91  // Initializes to the uniform distribution
92  prev_weight.resize(nstates,
93  SLWeight(1.0, std::log(static_cast<float>(nstates))));
94 
95  size_t changed;
96  size_t niter = 0;
97  do {
98  weight->clear();
99  weight->resize(nstates, SLWeight::Zero());
100  tmp_weight = prev_weight;
101  for (size_t i = 0; i < nstates; ++i) {
102  StateId st = top_order[i]; // ith state in epsilon top order
103  SignedStationaryDistribState(fst, st, &tmp_weight, weight, alpha);
104  }
105 
106  changed = 0;
107  for (size_t st = 0; st < nstates; ++st) {
108  LWeight dw = Minus((*weight)[st], prev_weight[st]).Value2();
109  LWeight dp = Times(ldelta, prev_weight[st].Value2());
110  if (Less(dp, dw)) ++changed;
111  prev_weight[st] = (*weight)[st];
112  }
113  ++niter;
114 
115  // Second clause tests for periodic states
116  if (niter > maxiters ||
117  (niter > nstates && ApproxZero((*weight)[fst.Start()])))
118  return false;
119 
120  VLOG(2) << "SignedStationaryDistrib: state weights changed: "
121  << changed;
122  } while (changed > 0);
123 
124  return true;
125 }
126 
127 } // namespace internal
128 
129 
130 } // namespace sfst
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
Definition: sfst.h:38
Definition: perplexity.h:32
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Definition: canonical.h:37
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
Definition: perplexity.h:129
Definition: sfstinfo.cc:39
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
Definition: sfst.h:84
void SignedStationaryDistribState(const fst::Fst< fst::SignedLog64Arc > &fst, int st, std::vector< fst::SignedLog64Weight > *prev_weight, std::vector< fst::SignedLog64Weight > *weight, fst::SignedLog64Weight alpha)
bool SignedStationaryDistrib(const fst::Fst< fst::SignedLog64Arc > &fst, std::vector< fst::SignedLog64Weight > *weight, fst::SignedLog64Weight alpha=internal::kReEntryWeight, float delta=fst::kDelta, size_t maxiters=internal::kMaxSDIters)
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
Definition: perplexity.h:41