GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
state-weights.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Routines to manipulate weights associated with stochastic FST states.
15 
16 #ifndef NLP_GRM2_SFST_STATE_WEIGHTS_H_
17 #define NLP_GRM2_SFST_STATE_WEIGHTS_H_
18 
#include <algorithm>
#include <fstream>
#include <iostream>
#include <ostream>
#include <string>
#include <vector>

#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/matcher.h>
#include <fst/signed-log-weight.h>
#include <fst/weight.h>
#include <sfst/sfst.h>
30 
31 namespace sfst {
32 
33 namespace internal {
34 
35 //
36 // Weight vector utilities
37 //
38 
39 // Useful with STL. 'delta' is the FST weight approximation delta
40 // and approx_zero a value considered approximately Zero(). If
41 // the latter is unset, a default is used.
42 template <class Weight>
44  public:
45  explicit ApproxEqualPred(float delta = fst::kDelta,
46  Weight approx_zero = Weight::NoWeight())
47  : delta_(delta),
48  approx_zero_(approx_zero.Member() ? to_log_(approx_zero) :
50 
51  bool operator()(const Weight &w1, const Weight &w2) const {
52  return fst::ApproxEqual(w1, w2, delta_) ||
53  (Less(to_log_(w1), approx_zero_) &&
54  Less(to_log_(w2), approx_zero_));
55  }
56 
57  private:
58  fst::WeightConvert<Weight, fst::Log64Weight> to_log_;
59  float delta_;
60  fst::Log64Weight approx_zero_;
61 };
62 
63 // Specialization for signed log weight.
64 template <typename T>
65 class ApproxEqualPred<fst::SignedLogWeightTpl<T>> {
66  public:
67  using Weight = fst::SignedLogWeightTpl<T>;
68  explicit ApproxEqualPred(float delta = fst::kDelta,
69  Weight approx_zero = Weight::NoWeight())
70  : delta_(delta),
71  approx_zero_(approx_zero.Member() ? approx_zero :
72  Weight(1.0, kApproxZeroWeight.Value())) {}
73 
74  bool operator()(const Weight &w1, const Weight &w2) const {
75  return fst::ApproxEqual(w1, w2, delta_) ||
76  (Less(w1.Value2(), approx_zero_.Value2()) &&
77  Less(w2.Value2(), approx_zero_.Value2()));
78  }
79 
80  private:
81  float delta_;
82  Weight approx_zero_;
83 };
84 
85 } // namespace internal
86 
87 // Normalizes a weight vector.
88 template <class Weight>
89 void NormWeights(std::vector<Weight> *weights) {
90  namespace f = fst;
91  f::WeightConvert<f::Log64Weight, Weight> from_log;
92  f::WeightConvert<Weight, f::Log64Weight> to_log;
93 
94  f::Adder<f::Log64Weight> sum;
95  for (size_t i = 0; i < weights->size(); ++i)
96  sum.Add(to_log((*weights)[i]));
97  for (size_t i = 0; i < weights->size(); ++i)
98  (*weights)[i] = Divide((*weights)[i], from_log(sum.Sum()));
99 }
100 
101 // Checks two vectors contain approximately the same weights.
102 // See ApproxEqualPred for the approximation argument explanation.
103 template <class Weight>
104 bool ApproxEqualWeights(const std::vector<Weight> &weights1,
105  const std::vector<Weight> &weights2,
106  float delta = fst::kDelta,
107  Weight approx_zero = Weight::NoWeight()) {
108  internal::ApproxEqualPred<Weight> aepred(delta, approx_zero);
109  if (weights1.size() != weights2.size()) return false;
110  return std::equal(weights1.begin(), weights1.end(),
111  weights2.begin(), aepred);
112 }
113 
// Writes weights to stream, one "<index>\t<weight>" line per entry, using
// 9 significant digits. The stream's previous precision is restored
// afterwards so callers (e.g. std::cout users) are not left with it changed.
template <class Weight>
void WriteWeights(std::ostream &strm,
                  const std::vector<Weight> &weights) {
  const std::streamsize saved_precision = strm.precision(9);
  for (std::size_t s = 0; s < weights.size(); ++s)
    strm << s << '\t' << weights[s] << '\n';
  strm.precision(saved_precision);
}
122 
123 // Writes weights to a file. Returns true on success.
124 template <class Weight>
125  bool WriteWeights(const std::string &filename,
126  const std::vector<Weight> &weights) {
127  std::ofstream ostrm;
128  if (!filename.empty()) {
129  ostrm.open(filename);
130  if (!ostrm.good()) {
131  LOG(ERROR) << "WriteWeights: Can't open file: " << filename;
132  return false;
133  }
134  }
135  std::ostream &strm = ostrm.is_open() ? ostrm : std::cout;
136  WriteWeights(strm, weights);
137  if (strm.fail()) {
138  LOG(ERROR) << "WritePotentials: Write failed: "
139  << (filename.empty() ? "standard output" : filename);
140  return false;
141  }
142  return true;
143 }
144 
145 //
146 // State weight utilities
147 //
148 
149 
150 // Converts between state weights that exclude the incoming failure mass
151 // (e.g. as produced by StationaryDistrib) to those that include it
152 // (e.g. as produced by ShortestDistance). The 'fail_arc' bool
153 // determines in the incoming failure arc weight is included. Assumes
154 // (but does not fully check) that the input has the canonical topology
155 // (see canonical.h).
156 template <class Arc>
157 void SumStateWeights(const fst::Fst<Arc> &fst,
158  std::vector<typename Arc::Weight> *weights,
159  typename Arc::Label phi_label, bool fail_arc) {
160  namespace f = fst;
161  using StateId = typename Arc::StateId;
162  using Weight = typename Arc::Weight;
163  using Matr = f::ExplicitMatcher<f::Matcher<f::Fst<Arc>>>;
164 
165  f::WeightConvert<f::Log64Weight, Weight> from_log;
166  f::WeightConvert<Weight, f::Log64Weight> to_log;
167 
168  if (phi_label == f::kNoLabel)
169  return;
170 
171  std::vector<StateId> top_order;
172  bool acyclic = PhiTopOrder(fst, phi_label, &top_order);
173  if (!acyclic) {
174  FSTERROR() << "SumStateWeights: FST not canonical (phi-cyclic)";
175  return;
176  }
177  Matr matcher(fst, f::MATCH_INPUT);
178  for (StateId i = 0; i < top_order.size(); ++i) {
179  StateId s = top_order[i]; // ith state in phi-top order
180  matcher.SetState(s);
181  if (matcher.Find(phi_label)) {
182  const Arc &arc = matcher.Value();
183  f::Log64Weight w1 = to_log((*weights)[arc.nextstate]);
184  f::Log64Weight w2 = to_log(fail_arc ? Times((*weights)[s], arc.weight) :
185  (*weights)[s]);
186  (*weights)[arc.nextstate] = from_log(Plus(w1, w2));
187  }
188  }
189 }
190 
191 // Converts between state weights that include the incoming failure mass
192 // (e.g. as produced by ShortestDistance) to those that exclude it
193 // (e.g. as produced by StationaryDistrib). The 'fail_arc' bool
194 // determines in the incoming failure arc weight is included. Assumes
195 // (but does not fully check) that the input has the canonical topology
196 // (see canonical.h).
197 template <class Arc>
198 void DiffStateWeights(const fst::Fst<Arc> &fst,
199  std::vector<typename Arc::Weight> *weights,
200  typename Arc::Label phi_label, bool fail_arc) {
201  namespace f = fst;
202  using StateId = typename Arc::StateId;
203  using Weight = typename Arc::Weight;
204  using Matr = f::ExplicitMatcher<f::Matcher<f::Fst<Arc>>>;
205 
206  f::WeightConvert<f::Log64Weight, Weight> from_log;
207  f::WeightConvert<Weight, f::Log64Weight> to_log;
208 
209  if (phi_label == f::kNoLabel)
210  return;
211 
212  std::vector<StateId> top_order;
213  bool acyclic = PhiTopOrder(fst, phi_label, &top_order);
214  if (!acyclic) {
215  FSTERROR() << "DiffStateWeights: FST not canonical (phi-cyclic)";
216  return;
217  }
218  Matr matcher(fst, f::MATCH_INPUT);
219  for (StateId i = top_order.size() - 1 ; i >= 0; --i) {
220  StateId s = top_order[i]; // ith state in reverse phi-top order
221  matcher.SetState(s);
222  if (matcher.Find(phi_label)) {
223  const Arc &arc = matcher.Value();
224  f::Log64Weight w1 = to_log((*weights)[arc.nextstate]);
225  f::Log64Weight w2 = to_log(fail_arc ? Times((*weights)[s], arc.weight) :
226  (*weights)[s]);
227  (*weights)[arc.nextstate] = Less(w2, w1) ?
228  from_log(Minus(w1, w2)) : Weight::Zero();
229  }
230  }
231 }
232 
233 // Sums arcs (including superfinal) leaving state. Assumes
234 // (but does not fully check) that the input has the canonical topology
235 // (see canonical.h).
236 template <class Arc>
237 void SumStates(const fst::Fst<Arc> &fst, typename Arc::Label phi_label,
238  std::vector<typename Arc::Weight> *weights) {
239  namespace f = fst;
240  using ArcItr = f::ArcIterator<f::Fst<Arc>>;
241  using StateId = typename Arc::StateId;
242  using Weight = typename Arc::Weight;
243  f::WeightConvert<f::Log64Weight, Weight> from_log;
244  f::WeightConvert<Weight, f::Log64Weight> to_log;
245  weights->clear();
246 
247  std::vector<StateId> top_order;
248  bool acyclic = PhiTopOrder(fst, phi_label, &top_order);
249  if (!acyclic) {
250  FSTERROR() << "SumStates: FST not canonical (phi-cyclic)";
251  return;
252  }
253  weights->resize(top_order.size(), Weight::Zero());
254 
255  for (StateId i = top_order.size() - 1 ; i >= 0; --i) {
256  StateId s = top_order[i]; // ith state in reverse phi-top order
257  f::Adder<f::Log64Weight> ssum(to_log(fst.Final(s)));
258  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
259  const Arc &arc = aiter.Value();
260  ssum.Add(to_log(arc.weight));
261  }
262  (*weights)[s] = from_log(ssum.Sum());
263  }
264 }
265 
266 
267 } // namespace sfst
268 
269 #endif // NLP_GRM2_SFST_STATE_WEIGHTS_H_
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
Definition: sfst.h:38
bool operator()(const Weight &w1, const Weight &w2) const
Definition: state-weights.h:74
void DiffStateWeights(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *weights, typename Arc::Label phi_label, bool fail_arc)
Definition: perplexity.h:32
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Definition: canonical.h:37
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
Definition: perplexity.h:129
Definition: sfstinfo.cc:39
void WriteWeights(std::ostream &strm, const std::vector< Weight > &weights)
void SumStateWeights(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *weights, typename Arc::Label phi_label, bool fail_arc)
const fst::Log64Weight kApproxZeroWeight
Definition: sfst.h:34
ApproxEqualPred(float delta=fst::kDelta, Weight approx_zero=Weight::NoWeight())
Definition: state-weights.h:45
void NormWeights(std::vector< Weight > *weights)
Definition: state-weights.h:89
ApproxEqualPred(float delta=fst::kDelta, Weight approx_zero=Weight::NoWeight())
Definition: state-weights.h:68
bool ApproxEqualWeights(const std::vector< Weight > &weights1, const std::vector< Weight > &weights2, float delta=fst::kDelta, Weight approx_zero=Weight::NoWeight())
bool operator()(const Weight &w1, const Weight &w2) const
Definition: state-weights.h:51
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
Definition: perplexity.h:41
void SumStates(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::Weight > *weights)