GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
normalize.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Algorithms to normalize stochastic FSTs and to check if one is
15 // normalized.
16 
17 #ifndef NLP_GRM2_SFST_NORMALIZE_H_
18 #define NLP_GRM2_SFST_NORMALIZE_H_
19 
20 #include <sys/types.h>
21 
22 #include <algorithm>
23 #include <cmath>
24 #include <cstddef>
25 #include <memory>
26 #include <vector>
27 
28 #include <fst/log.h>
29 #include <fst/float-weight.h>
30 #include <fst/fst.h>
31 #include <fst/matcher.h>
32 #include <fst/mutable-fst.h>
33 #include <fst/push.h>
34 #include <fst/weight.h>
35 #include <sfst/backoff.h>
36 #include <sfst/canonical.h>
37 #include <sfst/sfst.h>
38 #include <sfst/shortest-distance.h>
39 #include <sfst/state-weights.h>
40 #include <sfst/trim.h>
41 #include <fst/compat.h>
42 
43 
44 namespace sfst {
45 
46 // Computes high-order and (if present) low-order sums at a state.
47 // Note that (non-failure) epsilons are treated as regular symbols
48 // where each instance behaves as if it is uniquely labeled (i.e.,
49 // they are not constrained by failure transitions). Assumes (but does
50 // not check) that the input has the canonical topology (see canonical.h).
51 template <class Arc>
52 void StateSums(const fst::Fst<Arc> &fst,
53  typename Arc::StateId s,
54  typename Arc::Label phi_label,
55  fst::Log64Weight *high_sum,
56  fst::Log64Weight *low_sum,
57  fst::Log64Weight *phi_weight,
58  ssize_t *phi_position) {
59  namespace f = fst;
60  typedef typename Arc::StateId StateId;
61  typedef typename Arc::Label Label;
62  typedef typename Arc::Weight Weight;
63  typedef f::ArcIterator<f::Fst<Arc>> ArcItr;
64  typedef f::ExplicitMatcher<f::Matcher<f::Fst<Arc>>> Matr;
65 
66  f::WeightConvert<Weight, f::Log64Weight> to_log64;
67  f::Adder<f::Log64Weight> high_adder(to_log64(fst.Final(s)));
68  f::Adder<f::Log64Weight> low_adder;
69  *phi_weight = f::Log64Weight::Zero();
70  *phi_position = -1;
71 
72  FailurePath<Arc> failpath(fst, phi_label, true);
73  failpath.SetState(s);
74 
75  Weight fail_weight = Weight::One();
76  for (size_t i = 0; i < failpath.Length(); ++i) {
77  if (i == 0) {
78  *phi_weight = to_log64(failpath.GetWeight(i));
79  *phi_position = failpath.GetPosition(i);
80  if (high_adder.Sum() == f::Log64Weight::Zero()) break;
81  } else {
82  fail_weight = f::Times(fail_weight, failpath.GetWeight(i));
83  }
84  Weight final = fst.Final(failpath.GetNextState(i));
85  if (final != Weight::Zero()) {
86  low_adder.Reset(to_log64(f::Times(fail_weight, final)));
87  break;
88  }
89  }
90 
91  Matr matcher(fst, f::MATCH_INPUT);
92  Label prev_label = f::kNoLabel;
93  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
94  const Arc &high_arc = aiter.Value();
95  Label label = high_arc.ilabel;
96  if (label != phi_label) {
97  f::Log64Weight high_weight = to_log64(high_arc.weight);
98  high_adder.Add(high_weight);
99  if (label != 0) {
100  fail_weight = Weight::One();
101  bool matched = label == prev_label;
102  for (size_t i = 0; i < failpath.Length() && !matched; ++i) {
103  matcher.SetState(failpath.GetNextState(i));
104  if (i > 0)
105  fail_weight = Times(fail_weight, failpath.GetWeight(i));
106  for (matcher.Find(label); !matcher.Done(); matcher.Next()) {
107  const Arc &low_arc = matcher.Value();
108  f::Log64Weight low_weight =
109  to_log64(Times(fail_weight, low_arc.weight));
110  low_adder.Add(low_weight);
111  matched = true;
112  }
113  }
114  }
115  }
116  prev_label = label;
117  }
118  *high_sum = high_adder.Sum();
119  *low_sum = low_adder.Sum();
120 }
121 
122 // Tests if a canonical input FST is normalized at state s.
123 template <class Arc>
124 bool IsNormalizedState(const fst::Fst<Arc> &fst,
125  typename Arc::StateId s,
126  typename Arc::Label phi_label,
127  float delta) {
128  namespace f = fst;
129  f::Log64Weight high_sum, low_sum, phi_weight;
130  ssize_t phi_position;
131  StateSums(fst, s, phi_label, &high_sum, &low_sum,
132  &phi_weight, &phi_position);
133  // Checks if high_sum is a proper probability (<=1)
134  bool high_sum_le_one = Less(high_sum, f::Log64Weight::One()) ||
135  ApproxEqual(high_sum, f::Log64Weight::One(), delta);
136  // Checks if high_sum, low_sum, and phi_weight are consistent:
137  // phi_weight = (1 - high_sum)/(1 - low_sum)
138  bool phi_norm = f::ApproxEqual(f::Plus(high_sum, phi_weight),
139  f::Plus(f::Times(phi_weight, low_sum),
140  f::Log64Weight::One()),
141  delta);
142  bool ret = high_sum_le_one && phi_norm;
143  if (!ret) {
144  VLOG(1) << "IsNormalized: State not normalized: " << s
145  << " high_sum: " << high_sum
146  << " low_sum: " << low_sum
147  << " phi_weight: " << phi_weight;
148  }
149  return ret;
150 }
151 
152 // Tests if an input FST has the canonical topology and the weight of
153 // the transitions (plus any final weight) leaving a state sums to
154 // Weight::One(). The summation follows failure transitions through
155 // to an actual transition accumulating the weight. If the input
156 // is trim (see trim.h), then this will correspond to the condition
157 // that the weight of the paths into the future sum to Weight::One().
158 template <class Arc>
159 bool IsNormalized(const fst::Fst<Arc> &fst,
160  typename Arc::Label phi_label = fst::kNoLabel,
161  float delta = fst::kDelta) {
162  namespace f = fst;
163  typedef typename Arc::StateId StateId;
164  typedef f::StateIterator<f::Fst<Arc>> StateItr;
165 
166  if (!IsCanonical(fst, phi_label)) return false;
167  for (StateItr siter(fst); !siter.Done(); siter.Next()) {
168  StateId s = siter.Value();
169  if (!IsNormalizedState(fst, s, phi_label, delta))
170  return false;
171  }
172  return true;
173 }
174 
175 // Globally normalizes a weighted FST, when possible, as a stochastic
176 // FST. This preserves successful path weights up to a global
177 // constant. Normalization is possible when the sum of the weight of
178 // all successful paths from the initial state is finite. Always
179 // possible in the acyclic case. Returns true if the operation is
180 // successful. The 'delta' parameter controls the degree of
181 // convergence.
182 template <class Arc>
183 bool GlobalNormalize(fst::MutableFst<Arc> *fst,
184  typename Arc::Label phi_label = fst::kNoLabel,
185  float delta = fst::kDelta) {
186  namespace f = fst;
187  typedef typename Arc::Weight Weight;
188 
189  if (!IsCanonical(*fst, phi_label)) {
190  LOG(ERROR) << "GlobalNormalize: input is not a canonical stochastic FST";
191  return false;
192  }
193  // Reweights with the weights found above.
194  std::vector<Weight> distance;
195  const auto total_weight =
196  ShortestDistance(*fst, &distance, phi_label, true, delta);
197  if (!total_weight.Member())
198  return false;
199  f::Reweight(fst, distance, f::REWEIGHT_TO_INITIAL);
200  f::RemoveWeight(fst, total_weight, false);
201 
202  return true;
203 }
204 
205 // Locally normalizes a state of a weighted FST, when possible, as a stochastic
206 // FST. This rescales out-going arc weights (including super-final
207 // weight) from each state by a state-dependent constant. Any phi or
208 // epsilon labels are considered as regular symbols. Normalization is
209 // always possible when the sum of the weight of the out-going arcs of
210 // each state is non-Zero(). Returns true if the operation is
211 // successful.
212 template <class Arc>
213 bool LocalNormalizeState(typename Arc::StateId s,
214  fst::MutableFst<Arc> *fst) {
215  namespace f = fst;
216  typedef typename Arc::Weight Weight;
217 
218  f::WeightConvert<Weight, f::Log64Weight> to_log64;
219  f::WeightConvert<f::Log64Weight, Weight> from_log64;
220  if (s < 0 || s >= fst->NumStates()) {
221  return false;
222  } else {
223  Weight final = fst->Final(s);
224 
225  f::Adder<f::Log64Weight> adder(to_log64(final));
226  for (f::ArcIterator<f::MutableFst<Arc>> aiter(*fst, s); !aiter.Done();
227  aiter.Next()) {
228  const Arc &arc = aiter.Value();
229  f::Log64Weight weight = to_log64(arc.weight);
230  adder.Add(weight);
231  }
232  if (ApproxZero(adder.Sum())) return false;
233 
234  Weight sum = from_log64(adder.Sum());
235  if (final != Weight::Zero()) fst->SetFinal(s, Divide(final, sum));
236  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst, s); !aiter.Done();
237  aiter.Next()) {
238  Arc arc = aiter.Value();
239  arc.weight = Divide(arc.weight, sum);
240  aiter.SetValue(arc);
241  }
242  return true;
243  }
244 }
245 
246 // Locally normalizes a weighted FST, when possible, as a stochastic
247 // FST. This rescales out-going arc weights (including super-final
248 // weight) from each state by a state-dependent constant. Any phi or
249 // epsilon labels are considered as regular symbols. Normalization is
250 // always possible when the sum of the weight of the out-going arcs of
251 // each state is non-Zero(). Returns true if the operation is
252 // successful.
253 template <class Arc>
254 bool LocalNormalize(fst::MutableFst<Arc> *fst) {
255  namespace f = fst;
256  typedef typename Arc::StateId StateId;
257 
258  if (!IsCanonical(*fst, f::kNoLabel)) {
259  LOG(ERROR) << "LocalNormalize: input is not a canonical stochastic FST";
260  return false;
261  }
262 
263  for (StateId s = 0; s < fst->NumStates(); ++s) {
264  if (!LocalNormalizeState(s, fst)) return false;
265  }
266  return true;
267 }
268 
269 // Normalizes a state of a weighted FST, when possible, as a stochastic FST by
270 // computing the appropriate failure transition weights. The
271 // non-failure transition weights are assumed correct where possible,
272 // otherwise they are locally normalized. Returns true if the
273 // operation is successful.
274 template <class Arc>
275 bool PhiNormalizeState(typename Arc::StateId s, fst::MutableFst<Arc> *fst,
276  typename Arc::Label phi_label = fst::kNoLabel) {
277  namespace f = fst;
278  typedef typename Arc::StateId StateId;
279  typedef typename Arc::Weight Weight;
280  typedef f::MutableArcIterator<f::MutableFst<Arc>> MArcItr;
281  constexpr float kNormDelta = 1.0e-15;
282  f::WeightConvert<f::Log64Weight, Weight> from_log64;
283  f::WeightConvert<Weight, f::Log64Weight> to_log64;
284  if (s < 0 || s >= fst->NumStates()) {
285  return false;
286  } else {
287  MArcItr aiter(fst, s);
288  f::Log64Weight high_sum, low_sum, phi_weight;
289  ssize_t phi_position;
290  StateSums(*fst, s, phi_label, &high_sum, &low_sum, &phi_weight,
291  &phi_position);
292 
293  // Only case where high_sum can be zero is if
294  // there is a state with only a phi transition.
295  if (ApproxZero(high_sum) && (phi_position == -1 || fst->NumArcs(s) != 1)) {
296  return false;
297  }
298 
299  bool low_sum_ge_one =
300  Less(f::Log64Weight::One(), low_sum) ||
301  ApproxEqual(low_sum, f::Log64Weight::One(), kNormDelta);
302  bool high_sum_eq_one =
303  ApproxEqual(high_sum, f::Log64Weight::One(), kNormDelta);
304  bool high_sum_gt_one = Less(f::Log64Weight::One(), high_sum);
305 
306  // Locally normalizes if necessary
307  if (low_sum_ge_one || high_sum_gt_one ||
308  (high_sum_eq_one && phi_position != -1) ||
309  (!high_sum_eq_one && phi_position == -1)) {
310  for (; !aiter.Done(); aiter.Next()) {
311  Arc arc = aiter.Value();
312  if (arc.ilabel != phi_label) {
313  arc.weight = from_log64(Divide(to_log64(arc.weight), high_sum));
314  aiter.SetValue(arc);
315  }
316  }
317  Weight final = fst->Final(s);
318  if (final != Weight::Zero()) {
319  final = from_log64(Divide(to_log64(final), high_sum));
320  fst->SetFinal(s, final);
321  }
322  high_sum = f::Log64Weight::One();
323  }
324 
325  if (phi_position != -1) {
326  if (high_sum == f::Log64Weight::One()) {
327  phi_weight = kApproxZeroWeight;
328  } else {
329  f::Log64Weight numer = Minus(f::Log64Weight::One(), high_sum);
330  f::Log64Weight denom = Minus(f::Log64Weight::One(), low_sum);
331  phi_weight = f::Divide(numer, denom);
332  }
333  aiter.Seek(phi_position);
334  Arc arc = aiter.Value();
335  arc.weight = from_log64(phi_weight);
336  aiter.SetValue(arc);
337  }
338  }
339  return true;
340 }
341 
342 // Normalizes a weighted FST, when possible, as a stochastic FST by
343 // computing the appropriate failure transition weights. The
344 // non-failure transition weights are assumed correct where possible,
345 // otherwise they are locally normalized. Returns true if the
346 // operation is successful.
347 template <class Arc>
348 bool PhiNormalize(fst::MutableFst<Arc> *fst,
349  typename Arc::Label phi_label = fst::kNoLabel) {
350  namespace f = fst;
351  typedef typename Arc::StateId StateId;
352  typedef typename Arc::Weight Weight;
353 
354  std::vector<StateId> top_order;
355  if (phi_label == f::kNoLabel) return true;
356  if (!IsCanonical(*fst, phi_label, &top_order)) {
357  LOG(ERROR) << "PhiNormalize: input is not a canonical stochastic FST";
358  return false;
359  }
360 
361  for (StateId i = top_order.size() - 1; i >= 0; --i) {
362  StateId s = top_order[i]; // ith state in reverse phi-top order
363  if (!PhiNormalizeState(s, fst, phi_label)) {
364  return false;
365  }
366  }
367  return true;
368 }
369 
370 // See CountNormalize() below for an explantion of these values.
374  NORM_MARGINALLY_CONSTRAINED OPENFST_DEPRECATED("Use `NORM_KL_MIN` instead.") = 1,
376  NORM_MARGINALLY_APPROXIMATED OPENFST_DEPRECATED(
377  "Use `NORM_KL_MIN_APPROXIMATED` instead.") = 2,
378 };
379 
380 namespace internal {
381 
382 // Internal class to normalize a count SFST. See the public interface
383 // below for algorithm and argument description and for a reference.
384 // The formula used in comments are w.r.t the reference.
385 template <class Arc>
387  public:
388  using StateId = typename Arc::StateId;
389  using Label = typename Arc::Label;
390  using Weight = typename Arc::Weight;
391  using SLWeight = fst::SignedLog64Weight;
392  using ArcItr = fst::ArcIterator<fst::Fst<Arc>>;
393  using MArcItr = fst::MutableArcIterator<fst::MutableFst<Arc>>;
394  using Matr = fst::ExplicitMatcher<fst::Matcher<fst::Fst<Arc>>>;
395 
396  explicit CountNormalizer(Label phi_label, bool trim = false,
397  float delta = kNormDelta,
398  double effective_zero = kEffectiveZero,
399  size_t maxiters = kMaxNormIters)
400  : phi_label_(phi_label),
401  trim_(trim),
402  delta_(delta),
403  effective_zero_(1.0, effective_zero),
404  maxiters_(maxiters) { }
405 
406  // Performs the normalization.
407  bool Normalize(CountNormType norm_type, fst::MutableFst<Arc> *fst);
408 
409  // Comparison delta for state normalization.
410  static constexpr float kNormDelta = 1.0e-5;
411  // Within epsilon of zero.
412  static constexpr double kEffectiveZero = 35.0;
413  // Maximum iterations for convergence of normalization.
414  static constexpr size_t kMaxNormIters = 1000;
415 
416  private:
417  // Internal state associated with the constrained marginalization of an
418  // SFST state.
419  struct NormState {
420  NormState()
421  : count(SLWeight::Zero()),
422  fail_count(SLWeight::Zero()),
423  denom(SLWeight::Zero()) { }
424  SLWeight count; // sum of outgoing transitions weights from counts:
425  // C(q)
426  SLWeight fail_count; // failure transition weight from counts:
427  // C(\varphi, q)
428  SLWeight denom; // denominator of failure arc weight from this state:
429  // d(q, q') where q \in B_1(q')
430  std::vector<StateId> hi_states; // states with a failure arc to this state
431  };
432 
433  // Initializes states in the KL minimization.
434  void InitStates(const fst::ExpandedFst<Arc> &fst);
435 
436  // Iteratively calculates the arc weights at a state in the KL minimization.
437  // This uses a DC (difference of convex functions) optimization.
438  bool KLMinimizeState(StateId s, CountNormType norm_type,
439  fst::MutableFst<Arc> *fst);
440 
441  // Calculates the per arc normalization factor f(x,s,y^n)
442  // used to minimize the KL divergence (see NormArcWeights).
443  // Returns true on success.
444  bool ComputeNormFactor(const fst::ExpandedFst<Arc> &fst,
445  StateId s, std::vector<SLWeight> *norm_factor) const;
446 
447  // Normalizes the arc count using the arc normalization factor
448  // and the lambda Lagrange multiplier to compute
449  // y^{n+1} = C(x, s)/(lambda - f(x, s, y^n))
450  // in the minimization of the KL divergence (see LambdaSearch).
451  // Returns the sum of these new weights at the state.
452  SLWeight NormArcWeights(const fst::ExpandedFst<Arc> &fst, StateId s,
453  CountNormType norm_type,
454  const std::vector<SLWeight> &norm_factor,
455  SLWeight lambda,
456  std::vector<SLWeight> *arc_weights) const;
457 
458  // Search for the normalization that makes the arc weights a
459  // prob distribution. Does so by a binary search on the lambda argument
460  // to NormArcWeights(). Returns true on success.
461  bool LambdaSearch(const fst::ExpandedFst<Arc> &fst,
462  StateId s, CountNormType norm_type,
463  const std::vector<SLWeight> &norm_factor,
464  std::vector<SLWeight> *arc_weights) const;
465 
466  // Initializes arc weights from counts. Returns true if computed arc weights
467  // are solely an initialization; returns false if they are the final answer.
468  bool InitArcWeights(const fst::ExpandedFst<Arc> &fst,
469  StateId s, std::vector<SLWeight> *arc_weights) const;
470 
471  // Calculates the denominator of the backoff weights given the arc weights.
472  // Returns true on success.
473  bool ComputeDenom(const fst::ExpandedFst<Arc> &fst, StateId s,
474  const std::vector<SLWeight> &arc_weights);
475 
476  // Returns number of arcs at a state including any super-final but
477  // excluding any failure arcs.
478  size_t NumNonPhiArcs(const fst::Fst<Arc> &fst, StateId s) const {
479  size_t narcs = fst.NumArcs(s);
480  if (fst.Final(s) != Weight::Zero()) ++narcs;
481  if (backoff_->GetBackoffPosition(s) != -1) --narcs;
482  return narcs;
483  }
484 
485  // Tests within epsilon of zero; ensures approximation is on a closed set.
486  bool IsEffectiveZero(SLWeight w) const {
487  return ApproxZero(w, effective_zero_.Value2());
488  }
489 
490  const Label phi_label_;
491  bool trim_;
492  const float delta_;
493  const SLWeight effective_zero_;
494  const size_t maxiters_;
495 
496  std::unique_ptr<Backoff<Arc>> backoff_;
497  std::vector<NormState> norm_states_;
498  fst::WeightConvert<SLWeight, Weight> from_log_;
499  fst::WeightConvert<Weight, SLWeight> to_log_;
500 
501  CountNormalizer(const CountNormalizer &) = delete;
502  CountNormalizer &operator=(const CountNormalizer &) = delete;
503 };
504 
505 template <class Arc>
507  fst::MutableFst<Arc> *fst) {
508  namespace f = fst;
509 
510  if (norm_type == NORM_SUMMED) {
511  // Sums to lower orders.
512  if (!SumBackoff(fst, phi_label_)) {
513  LOG(ERROR) << "CountNormalize: backoff summation failed";
514  return false;
515  }
516  }
517 
518 
519  if (trim_) {
520  // Trims negligible (non-phi) weight transitions while
521  // preserving a backoff-complete input.
522  const Weight trim_weight = from_log_(effective_zero_);
523  Trimmer<Arc> trim(fst, phi_label_, TRIM_NEEDED_TRIM);
524  if (norm_type == NORM_SUMMED) {
525  trim.WeightTrim(false, trim_weight);
526  } else {
527  trim.SumWeightTrim(false, trim_weight);
528  }
529  trim.Finalize();
530  if (fst->Properties(f::kError, false)) return false;
531  }
532 
533 
534  if (phi_label_ == f::kNoLabel || norm_type == NORM_SUMMED) {
535  if (!LocalNormalize(fst)) {
536  LOG(ERROR) << "CountNormalizer: local normalization failed";
537  return false;
538  }
539  } else { // marginally-constrained case
540  InitStates(*fst);
541 
542  for (StateId i = 0; i < fst->NumStates(); ++i) {
543  StateId s = backoff_->GetPhiTopOrder(i); // ith state in phi-top order
544  if (!KLMinimizeState(s, norm_type, fst)) return false;
545  }
546  }
547 
548  if (phi_label_ != f::kNoLabel && !PhiNormalize(fst, phi_label_)) {
549  LOG(ERROR) << "CountNormalizer: phi normalization failed";
550  return false;
551  }
552  return true;
553 }
554 
555 template <class Arc>
557  const fst::ExpandedFst<Arc> &fst) {
558  namespace f = fst;
559  backoff_ = std::make_unique<Backoff<Arc>>(fst, phi_label_);
560 
561  norm_states_.clear();
562  norm_states_.resize(fst.NumStates());
563  for (StateId i = fst.NumStates() - 1; i >= 0; --i) {
564  const StateId s = backoff_->GetPhiTopOrder(i);
565  NormState &state = norm_states_[s];
566  f::Adder<SLWeight> adder(to_log_(fst.Final(s)));
567  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
568  const Arc &arc = aiter.Value();
569  if (arc.ilabel == phi_label_)
570  state.fail_count = to_log_(arc.weight);
571  adder.Add(to_log_(arc.weight));
572  }
573  state.count = adder.Sum();
574  const StateId bos = backoff_->GetBackoffState(s);
575  if (bos != f::kNoStateId) {
576  NormState &bo_state = norm_states_[bos];
577  // Records the immediately higher-order state sets.
578  bo_state.hi_states.push_back(s);
579  }
580  }
581 }
582 
583 // The computations are done in the signed log rather than the log semiring
584 // for numerical stability.
585 template <class Arc>
587  StateId s, CountNormType norm_type, fst::MutableFst<Arc> *fst) {
588  namespace f = fst;
589  // Stores new arc weights; last position is the super-final weight.
590  std::vector<SLWeight> arc_weights(fst->NumArcs(s) + 1, SLWeight::Zero());
591  std::vector<SLWeight> prev_arc_weights;
592  std::vector<SLWeight> norm_factor(fst->NumArcs(s) + 1, SLWeight::Zero());
593  size_t iters = 0;
594  if (InitArcWeights(*fst, s, &arc_weights)) {
595  if (!ComputeDenom(*fst, s, arc_weights))
596  return false;
597 
598  do {
599  prev_arc_weights = arc_weights;
600  if (!ComputeNormFactor(*fst, s, &norm_factor))
601  return false;
602 
603  if (!LambdaSearch(*fst, s, norm_type, norm_factor, &arc_weights)) {
604  LOG(ERROR) << "CountNormalizer: lambda iterations failed";
605  return false;
606  }
607  if (!ComputeDenom(*fst, s, arc_weights))
608  return false;
609  if (++iters > maxiters_) {
610  LOG(WARNING) << "CountNormalizer: max DC iterations exceeded:"
611  << " state: " << s;
612  break; // Allows sub-optimal solution (with warning).
613  }
614  } while (!ApproxEqualWeights(arc_weights, prev_arc_weights));
615  }
616 
617  NormWeights(&arc_weights); // explicitly renormalizes sum to One()
618 
619  // Copies arc weights to FST and validates their values.
620  ssize_t pos = 0;
621  f::Adder<SLWeight> adder;
622  for (MArcItr aiter(fst, s); !aiter.Done(); aiter.Next(), ++pos) {
623  Arc arc = aiter.Value();
624  if (Less(arc_weights[pos], SLWeight::Zero())) {
625  LOG(ERROR) << "CountNormalizer: bad arc weight: " << arc_weights[pos];
626  return false;
627  }
628  arc.weight = from_log_(arc_weights[pos]);
629  aiter.SetValue(arc);
630  adder.Add(arc_weights[pos]);
631  }
632  // ...including the super-final arc.
633  if (Less(arc_weights[pos], SLWeight::Zero())) {
634  LOG(ERROR) << "CountNormalizer: bad final weight: " << arc_weights[pos];
635  return false;
636  }
637  fst->SetFinal(s, from_log_(arc_weights[pos]));
638  adder.Add(arc_weights[pos]);
639 
640  // Validates total mass.
641  if (Less(SLWeight::One(), adder.Sum()) &&
642  !ApproxEqual(SLWeight::One(), adder.Sum())) {
643  LOG(ERROR) << "CountNormalizer: bad state sum: " << adder.Sum();
644  return false;
645  }
646  return true;
647 }
648 
649 template <class Arc>
651  const fst::ExpandedFst<Arc> &fst, StateId s,
652  std::vector<SLWeight> *norm_factor) const {
653  // We assume any immediately higher NormStates are completed except
654  // for the denom members which should have at least tentative values.
655 
656  const NormState &state = norm_states_[s];
657  // Per-arc normalization factor.
658  std::fill(norm_factor->begin(), norm_factor->end(), SLWeight::Zero());
659  // The higher-order state probabilities w/ backoff.
660  for (auto his : state.hi_states) {
661  // If all arcs are backed-off, skip.
662  if (NumNonPhiArcs(fst, s) == NumNonPhiArcs(fst, his))
663  continue;
664  const NormState &hi_state = norm_states_[his];
665  // C(\varphi, his) / d(his, s)
666  const SLWeight hi_weight = Divide(hi_state.fail_count, hi_state.denom);
667  for (size_t hipos = 0; hipos < fst.NumArcs(his); ++hipos) {
668  const ssize_t pos = backoff_->GetBackedOffArc(his, hipos);
669  // 1_{x \in L[s] \cap \Sigma}
670  if (pos != -1) {
671  (*norm_factor)[pos] = Plus((*norm_factor)[pos], hi_weight);
672  }
673  }
674  // 1_{$ \in L[s] \cap \Sigma}
675  if (fst.Final(his) != Weight::Zero()) {
676  const ssize_t pos = fst.NumArcs(s);
677  (*norm_factor)[pos] = Plus((*norm_factor)[pos], hi_weight);
678  }
679  }
680  return true;
681 }
682 
683 template <class Arc>
684 fst::SignedLog64Weight CountNormalizer<Arc>::NormArcWeights(
685  const fst::ExpandedFst<Arc> &fst, StateId s, CountNormType norm_type,
686  const std::vector<SLWeight> &norm_factor, SLWeight lambda,
687  std::vector<SLWeight> *arc_weights) const {
688  namespace f = fst;
689 
690  // Normalizes using arc normalization factors.
691  size_t pos = 0;
692  f::Adder<SLWeight> adder;
693  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next(), ++pos) {
694  const Arc &arc = aiter.Value();
695  // If norm type is NORM_KL_MIN_APPROXIMATED, then the failure
696  // transition weight is not modified from its initial estimate.
697  // C(x, s)
698  if (arc.ilabel != phi_label_ || norm_type != NORM_KL_MIN_APPROXIMATED) {
699  // \lambda - f(x, q, y^n)
700  const SLWeight arc_weight = to_log_(arc.weight);
701  if (IsEffectiveZero(arc_weight)) {
702  // y^{n+1} = max(y^{n+1}, eps)
703  (*arc_weights)[pos] = effective_zero_;
704  } else {
705  const SLWeight norm = Minus(lambda, norm_factor[pos]);
706  // y^{n+1} = C(x, s)/(\lambda - f(x, q, y^n))
707  (*arc_weights)[pos] = Divide(arc_weight, norm);
708  if (IsEffectiveZero((*arc_weights)[pos])) {
709  // y^{n+1} = max(y^{n+1}, eps)
710  (*arc_weights)[pos] = effective_zero_;
711  }
712  }
713  }
714  adder.Add((*arc_weights)[pos]);
715  }
716  // ...including the super-final arc
717  if (fst.Final(s) != Weight::Zero()) {
718  // C(x, s)
719  const SLWeight final_weight = to_log_(fst.Final(s));
720  // \lambda - f(x, q, y^n)
721  const SLWeight norm = Minus(lambda, norm_factor[pos]);
722  if (IsEffectiveZero(final_weight)) {
723  // y^{n+1} = max(y^{n+1}, eps)
724  (*arc_weights)[pos] = effective_zero_;
725  } else {
726  // y^{n+1} = C(x, s)/(\lambda - f(x, q, y^n))
727  (*arc_weights)[pos] = Divide(final_weight, norm);
728  // y^{n+1} = max(y^{n+1}, eps)
729  if (IsEffectiveZero((*arc_weights)[pos]))
730  (*arc_weights)[pos] = effective_zero_;
731  }
732  adder.Add((*arc_weights)[pos]);
733  }
734  return adder.Sum();
735 }
736 
737 template <class Arc>
739  const fst::ExpandedFst<Arc> &fst, StateId s,
740  CountNormType norm_type, const std::vector<SLWeight> &norm_factor,
741  std::vector<SLWeight> *arc_weights) const {
742  // Chooses lambda lower bound as max_x (f(x, s, y^n) + C(x,s))
743  // Chooses lambda higher bound as C(s) + max_x f(x, s, y^n)
744  ssize_t pos = 0;
745  SLWeight maxfx = SLWeight::Zero();
746  SLWeight lambda_low = SLWeight::Zero();
747  SLWeight lambda_hi = SLWeight::Zero();
748  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next(), ++pos) {
749  SLWeight cx = to_log_(aiter.Value().weight);
750  SLWeight fx = norm_factor[pos];
751  SLWeight cfx = Plus(cx, fx);
752  if (Less(maxfx, fx)) maxfx = fx;
753  if (Less(lambda_low, cfx)) lambda_low = cfx;
754  lambda_hi = Plus(lambda_hi, cx);
755  }
756  const SLWeight final_cx = to_log_(fst.Final(s));
757  const SLWeight final_fx = norm_factor[pos];
758  const SLWeight final_cfx = Plus(final_fx, final_cx);
759  if (Less(maxfx, final_fx)) maxfx = final_fx;
760  if (Less(lambda_low, final_cfx)) lambda_low = final_cfx;
761  lambda_hi = Plus(Plus(lambda_hi, final_cx), maxfx);
762 
763  if (Less(lambda_hi, lambda_low)) {
764  if (ApproxEqual(lambda_low, lambda_hi, kNormDelta)) {
765  lambda_low = lambda_hi;
766  } else {
767  LOG(ERROR) << "CountNormalizer: bad lambda parameter limits:"
768  << " state: " << s
769  << " lambda_low: " << lambda_low
770  << " lambda_hi: " << lambda_hi;
771  return false;
772  }
773  }
774 
775  const SLWeight two = Plus(SLWeight::One(), SLWeight::One());
776  size_t iters = 0;
777  while (true) {
778  const SLWeight lambda = Divide(Plus(lambda_hi, lambda_low), two);
779  const SLWeight arc_sum =
780  NormArcWeights(fst, s, norm_type, norm_factor, lambda, arc_weights);
781  if (ApproxEqual(arc_sum, SLWeight::One(), delta_) ||
782  ApproxEqual(lambda_low, lambda_hi, kNormDelta)) {
783  return true;
784  } else if (Less(arc_sum, SLWeight::One())) {
785  lambda_hi = lambda;
786  } else {
787  lambda_low = lambda;
788  }
789  if (++iters > maxiters_) {
790  LOG(ERROR) << "CountNormalizer: max (lambda) iterations exceeded:"
791  << " state: " << s
792  << " lambda_low: " << lambda_low
793  << " lambda_hi: " << lambda_hi
794  << " arc_sum: " << arc_sum;
795  return false;
796  }
797  }
798 }
799 
800 template <class Arc>
802  const fst::ExpandedFst<Arc> &fst,
803  StateId s, std::vector<SLWeight> *arc_weights) const {
804  const NormState &state = norm_states_[s];
805 
806  // Initializes the initialization.
807  std::fill(arc_weights->begin(), arc_weights->end(), SLWeight::Zero());
808 
809  // Finds the number of arcs (including super-final arc)
810  ssize_t num_arcs = fst.NumArcs(s);
811  if (fst.Final(s) != Weight::Zero())
812  ++num_arcs;
813  const SLWeight num_arcs_weight =
814  to_log_(-std::log(static_cast<double>(num_arcs)));
815 
816  if (IsEffectiveZero(state.count)) {
817  // Trivial case: (no count mass at the state)
818  // y^0 = 1/|L(s)| and return false
819  for (size_t pos = 0; pos < num_arcs; ++pos)
820  (*arc_weights)[pos] = Divide(SLWeight::One(), num_arcs_weight);
821  return false;
822  } else {
823  // General case:
824  // y^0 = C(x,s)/C(s) (1 - |L(s)| eps) + eps and return true
825  const SLWeight arc_mult = Minus(SLWeight::One(),
826  Times(num_arcs_weight, effective_zero_));
827  size_t pos = 0;
828  for (ArcItr aiter(fst, s);
829  !aiter.Done();
830  aiter.Next(), ++pos) {
831  const Arc &arc = aiter.Value();
832  const SLWeight arc_weight = Divide(to_log_(arc.weight), state.count);
833  (*arc_weights)[pos] = Plus(Times(arc_weight, arc_mult), effective_zero_);
834  }
835  if (fst.Final(s) != Weight::Zero()) {
836  const SLWeight final_weight = Divide(to_log_(fst.Final(s)), state.count);
837  (*arc_weights)[pos] =
838  Plus(Times(final_weight, arc_mult), effective_zero_);
839  }
840  return true;
841  }
842 }
843 
// Computes the backoff denominator d(his,s) for every higher-order state
// 'his' recorded in norm_states_[s].hi_states and stores it in
// norm_states_[his].denom. 'arc_weights' is indexed by arc position at 's';
// index fst.NumArcs(s) is read as the slot for the final weight. Returns
// false (after logging an error) if a denominator is negative and not
// effectively zero; returns true otherwise.
// NOTE(review): the declarator line (listing line 845, between the template
// header and the parameter list) is missing from this listing.
844 template <class Arc>
846  const fst::ExpandedFst<Arc> &fst, StateId s,
847  const std::vector<SLWeight> &arc_weights) {
848  namespace f = fst;
849  NormState &state = norm_states_[s];
850  for (auto his : state.hi_states) {
851  NormState &hi_state = norm_states_[his];
852  f::Adder<SLWeight> adder;
853  // Computes d(his,s) = 1 - \sum{x \in L[his]\cap\Sigma} y_x.
// For each arc position at 'his', looks up the corresponding position at 's'
// and accumulates that arc's weight; a result of -1 is skipped (presumably
// meaning there is no backed-off arc -- confirm against backoff.h).
854  for (size_t hipos = 0; hipos < fst.NumArcs(his); ++hipos) {
855  const ssize_t pos = backoff_->GetBackedOffArc(his, hipos);
856  if (pos != -1)
857  adder.Add(arc_weights[pos]);
858  }
// When 'his' is final, the final-weight slot of 's' (index NumArcs(s)) is
// included in the sum as well.
859  if (fst.Final(his) != Weight::Zero()) {
860  const ssize_t pos = fst.NumArcs(s);
861  adder.Add(arc_weights[pos]);
862  }
863  hi_state.denom = Minus(SLWeight::One(), adder.Sum());
864 
// Fallback path: taken only when the first estimate is below
// effective_zero_.
865  if (Less(hi_state.denom, effective_zero_)) {
866  // Computes d(his,s) = \sum{x \notin L[his]\cap\Sigma} y_x. Equal
867  // mathematically to above but possibly different numerically. This
868  // ensures the low probability is not due to that on the few states that
869  // get here.
870  adder.Reset();
// Matcher positioned at 'his'; used to test whether a label leaving 's' also
// leaves 'his'.
871  Matr matcher(fst, f::MATCH_INPUT);
872  matcher.SetState(his);
873  ssize_t pos = 0;
// Accumulates the weight of the phi-labeled arc and of every arc at 's' whose
// input label has no match at 'his'.
874  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next(), ++pos) {
875  const Arc &arc = aiter.Value();
876  if (arc.ilabel == phi_label_ || !matcher.Find(arc.ilabel))
877  adder.Add(arc_weights[pos]);
878  }
// After the loop pos == NumArcs(s), i.e. the final-weight slot; it is added
// when 'his' is not final (the complement of the case handled above).
// NOTE(review): this reads arc_weights[NumArcs(s)] whenever 'his' is
// non-final; it assumes 'arc_weights' always has that slot -- confirm.
879  if (fst.Final(his) == Weight::Zero())
880  adder.Add(arc_weights[pos]);
881  hi_state.denom = adder.Sum();
882  }
883 
// The effective-zero test runs first, so any value it accepts is clamped to
// effective_zero_ instead of being reported as an error below.
884  if (IsEffectiveZero(hi_state.denom)) {
885  // Probability mass at the lower order for arcs at this state is nearly
886  // one, so that the denominator in backoff calculation is effectively
887  // zero. This can happen for states with a large number of arcs and
888  // nearly the same arcs at each state.
889  hi_state.denom = effective_zero_;
890  } else if (Less(hi_state.denom, SLWeight::Zero())) {
// A denominator below SLWeight::Zero() is reported and aborts the whole
// computation for state 's'.
891  LOG(ERROR) << "CountNormalizer: bad backoff denominator: "
892  << hi_state.denom
893  << " state: " << s << " high state: " << his;
894  return false;
895  }
896  }
897  return true;
898 }
899 
900 } // namespace internal
901 
902 // Public interface to algorithm to normalize a count SFST. The input
903 // should be (smoothed) count SFST e.g., as returned by
904 // sfst::Count. It should be a 'backoff-complete' SFST (see backoff.h). It
905 // should not be 'phi-summed' before input (see below). It is
906 // transformed into a normalized SFST. Returns true on success.
907 //
908 // If norm_type == NORM_SUMMED, the algorithm 'phi-sums' the input (by
909 // adding higher-order counts to lower counts), locally normalizes and then
910 // determines the failure weights. In this way, the output at a state
911 // is marginalized over the higher orders from that state. This default
912 // choice results in simple reliable count normalization.
913 //
914 // If norm_type == NORM_KL_MIN, the Kullback–Leibler divergence minimum
915 // distribution w.r.t. the counts is computed as described in
916 // Suresh, Roark, Riley, Schogol, "Approximating probabilistic models
917 // as weighted automata" (experimental/fst/papers/approx/approx.pdf)
918 //
919 // If norm_type == NORM_KL_MIN_APPROXIMATED, is similar to the
920 // NORM_KL_MIN, but may be more numerically stable and can empirically
921 // give better results. It does so by fixing the failure weight to its
922 // initial estimate in the above optimization.
923 //
924 // The 'delta' (convergence threshold), (-log) 'effective_zero' (underflow)
925 // and 'maxiters' (iteration count threshold) parameters control
926 // the iterative computation used in the last two cases.
927 template <class Arc>
929  fst::MutableFst<Arc> *fst,
930  typename Arc::Label phi_label,
931  CountNormType norm_type = NORM_SUMMED,
932  bool trim = false,
934  double effective_zero = internal::CountNormalizer<Arc>::kEffectiveZero,
936  internal::CountNormalizer<Arc> normalizer(phi_label, trim, delta,
937  effective_zero, maxiters);
938  return normalizer.Normalize(norm_type, fst);
939 }
940 
941 
942 // Modifies input FST to move it toward a globally normalizable
943 // stochastic FST. Degree of modification controlled by non-negative
944 // "delta" with "delta = 0.0" meaning no modification. Returns true
945 // if the operation is successful.
946 template <class Arc>
947 bool Condition(fst::MutableFst<Arc> *fst,
948  typename Arc::Label phi_label = fst::kNoLabel,
949  float delta = fst::kDelta) {
950  namespace f = fst;
951  typedef typename Arc::StateId StateId;
952  typedef typename Arc::Weight Weight;
953 
954  if (delta < 0.0) {
955  LOG(ERROR) << "Condition: conditioning delta is negative";
956  return false;
957  }
958 
959  f::WeightConvert<f::Log64Weight, Weight> from_log64;
960  Weight weight = from_log64(delta);
961 
962  // Multiplies the delta on to every non-special arc weight
963  for (StateId s = 0; s < fst->NumStates(); ++s) {
964  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst, s);
965  !aiter.Done();
966  aiter.Next()) {
967  Arc arc = aiter.Value();
968  // Skips special labels
969  if (arc.ilabel == phi_label || arc.ilabel == 0)
970  continue;
971  arc.weight = Times(arc.weight, weight);
972  aiter.SetValue(arc);
973  }
974  }
975  return true;
976 }
977 
978 } // namespace sfst
979 
980 #endif // NLP_GRM2_SFST_NORMALIZE_H_
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
Definition: sfst.h:38
bool LocalNormalizeState(typename Arc::StateId s, fst::MutableFst< Arc > *fst)
Definition: normalize.h:213
static constexpr size_t kMaxNormIters
Definition: normalize.h:414
void StateSums(const fst::Fst< Arc > &fst, typename Arc::StateId s, typename Arc::Label phi_label, fst::Log64Weight *high_sum, fst::Log64Weight *low_sum, fst::Log64Weight *phi_weight, ssize_t *phi_position)
Definition: normalize.h:52
size_t Length() const
Definition: sfst.h:149
Definition: perplexity.h:32
Weight GetWeight(size_t i) const
Definition: sfst.h:154
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
Definition: perplexity.h:129
bool CountNormalize(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label, CountNormType norm_type=NORM_SUMMED, bool trim=false, float delta=internal::CountNormalizer< Arc >::kNormDelta, double effective_zero=internal::CountNormalizer< Arc >::kEffectiveZero, size_t maxiters=internal::CountNormalizer< Arc >::kMaxNormIters)
Definition: normalize.h:928
Definition: sfstinfo.cc:39
fst::SignedLog64Weight SLWeight
Definition: normalize.h:391
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
Definition: sfst.h:84
bool SumBackoff(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label)
Definition: backoff.h:250
bool Condition(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel, float delta=fst::kDelta)
Definition: normalize.h:947
bool Normalize(CountNormType norm_type, fst::MutableFst< Arc > *fst)
Definition: normalize.h:506
bool LocalNormalize(fst::MutableFst< Arc > *fst)
Definition: normalize.h:254
void Finalize()
Definition: trim.h:648
bool IsCanonical(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Definition: canonical.h:71
Arc::Weight ShortestDistance(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, typename Arc::Label phi_label=fst::kNoLabel, bool reverse=false, float delta=fst::kShortestDelta)
StateId GetNextState(size_t i) const
Definition: sfst.h:152
bool IsNormalizedState(const fst::Fst< Arc > &fst, typename Arc::StateId s, typename Arc::Label phi_label, float delta)
Definition: normalize.h:124
void WeightTrim(bool include_phi, Weight approx_zero=ApproxZeroWeight())
Definition: trim.h:596
fst::ArcIterator< fst::Fst< Arc >> ArcItr
Definition: normalize.h:392
size_t GetPosition(size_t i) const
Definition: sfst.h:156
bool IsNormalized(const fst::Fst< Arc > &fst, typename Arc::Label phi_label=fst::kNoLabel, float delta=fst::kDelta)
Definition: normalize.h:159
CountNormalizer(Label phi_label, bool trim=false, float delta=kNormDelta, double effective_zero=kEffectiveZero, size_t maxiters=kMaxNormIters)
Definition: normalize.h:396
bool PhiNormalize(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel)
Definition: normalize.h:348
const fst::Log64Weight kApproxZeroWeight
Definition: sfst.h:34
fst::MutableArcIterator< fst::MutableFst< Arc >> MArcItr
Definition: normalize.h:393
typename Arc::StateId StateId
Definition: normalize.h:388
void SumWeightTrim(bool include_phi, Weight approx_zero=ApproxZeroWeight())
Definition: trim.h:620
void SetState(StateId s)
Definition: sfst.h:179
CountNormType
Definition: normalize.h:371
fst::ExplicitMatcher< fst::Matcher< fst::Fst< Arc >>> Matr
Definition: normalize.h:394
typename Arc::Weight Weight
Definition: normalize.h:390
constexpr double kNormDelta
void NormWeights(std::vector< Weight > *weights)
Definition: state-weights.h:89
bool GlobalNormalize(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel, float delta=fst::kDelta)
Definition: normalize.h:183
static constexpr float kNormDelta
Definition: normalize.h:410
bool ApproxEqualWeights(const std::vector< Weight > &weights1, const std::vector< Weight > &weights2, float delta=fst::kDelta, Weight approx_zero=Weight::NoWeight())
static constexpr double kEffectiveZero
Definition: normalize.h:412
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
Definition: perplexity.h:41
bool PhiNormalizeState(typename Arc::StateId s, fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel)
Definition: normalize.h:275
typename Arc::Label Label
Definition: normalize.h:389