ngram-model.h
// Copyright 2005-2013 Brian Roark
// Copyright 2005-2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// NGram model class.

#ifndef NGRAM_NGRAM_MODEL_H_
#define NGRAM_NGRAM_MODEL_H_

#include <array>
#include <cmath>
#include <cstdint>
#include <deque>
#include <iostream>
#include <ostream>
#include <vector>

#include <fst/fst.h>
#include <fst/matcher.h>
#include <ngram/hist-arc.h>
#include <ngram/util.h>

namespace ngram {

// Default normalization constant (e.g., for checks)
const double kNormEps = 0.001;
const double kFloatEps = 0.000001;
const double kInfBackoff = 99.00;

// Calculates -log(exp(a - b) + 1) for use in high-precision NegLogSum.
static double NegLogDeltaValue(double a, double b, double *c) {
  double x = exp(a - b), delta = -log(x + 1);
  if (x < kNormEps) {  // for small x, use the Mercator series to calculate
    delta = -x;
    for (int j = 2; j <= 4; ++j) delta += pow(-x, j) / j;
  }
  if (c) delta -= *c;  // Sum correction from Kahan formula (if using)
  return delta;
}
51 
52 // Precision method for summing reals and saving negative logs
53 // -log( exp(-a) + exp(-b) ) = a - log( exp(a - b) + 1 )
54 // Uses Mercator series and Kahan formula for additional numerical stability
55 static double NegLogSum(double a, double b, double *c) {
56  if (a == fst::StdArc::Weight::Zero().Value()) return b;
57  if (b == fst::StdArc::Weight::Zero().Value()) return a;
58  if (a > b) return NegLogSum(b, a, c);
59  double delta = NegLogDeltaValue(a, b, c), val = a + delta;
60  if (c) *c = (val - a) - delta; // update sum correction for Kahan formula
61  return val;
62 }
63 
64 // Summing reals and saving negative logs, no Kahan formula (backwards compat)
65 static double NegLogSum(double a, double b) { return NegLogSum(a, b, nullptr); }
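
// Worked example (illustrative, not part of the library): for a = 1.0 and
// b = 2.0, i.e., probabilities e^-1 ~= 0.368 and e^-2 ~= 0.135,
// NegLogSum(1.0, 2.0) returns -log(e^-1 + e^-2) ~= 0.687, the negative log
// of the summed probability ~= 0.503.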

// Negative log of difference: -log(exp(-a) - exp(-b)).
// FRAGILE: assumes exp(-a) >= exp(-b), i.e., a <= b.
// Sets (but does not clear) optional error field.
static double NegLogDiff(double a, double b, bool *error = nullptr) {
  if (b == fst::StdArc::Weight::Zero().Value()) return a;
  if (a >= b) {
    if (a - b >= kNormEps) {  // not equal within fp error
      NGRAMERROR() << "NegLogDiff: undefined " << a << " " << b;
      if (error) *error = true;
    }
    return fst::StdArc::Weight::Zero().Value();
  }
  return b - log(exp(b - a) - 1);
}

template <class Arc>
class NGramModel {
 public:
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Label Label;
  typedef typename Arc::Weight Weight;

  // Constructs an NGramModel object, consisting of the FST and some
  // information about the states, under the assumption that the FST is
  // a model. The 'backoff_label' is what is followed when there is no
  // word match at a given order. The 'norm_eps' is the epsilon used
  // in checking weight normalization. If 'state_ngrams' is true,
  // this class explicitly finds, checks the consistency of, and stores
  // the ngram that must be read to reach each state (normally false
  // to save some time and space).
  NGramModel(const fst::Fst<Arc> &infst, Label backoff_label, double norm_eps,
             bool state_ngrams)
      : fst_(infst),
        backoff_label_(backoff_label),
        norm_eps_(norm_eps),
        have_state_ngrams_(state_ngrams),
        error_(false) {
    InitModel();
  }

  // Same as above, but requires only the FST and the backoff label, using
  // defaults for the remaining parameters.
  NGramModel(const fst::Fst<Arc> &infst, Label backoff_label)
      : fst_(infst),
        backoff_label_(backoff_label),
        norm_eps_(kNormEps),
        have_state_ngrams_(false),
        error_(false) {
    InitModel();
  }

  // Same as above, but uses defaults for most of the parameters.
  explicit NGramModel(const fst::Fst<Arc> &infst)
      : fst_(infst),
        backoff_label_(0),
        norm_eps_(kNormEps),
        have_state_ngrams_(false),
        error_(false) {
    InitModel();
  }

  virtual ~NGramModel() = default;
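
  // Usage sketch (illustrative; the file name "lm.fst" is an assumption,
  // and the FST must satisfy the topology checks performed by InitModel):
  //
  //   std::unique_ptr<fst::StdFst> f(fst::StdFst::Read("lm.fst"));
  //   ngram::NGramModel<fst::StdArc> model(*f);  // backoff label 0
  //   if (!model.Error() && model.CheckNormalization())
  //     std::cout << "model order: " << model.HiOrder() << std::endl;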

  // Number of states in the LM fst.
  StateId NumStates() const { return nstates_; }

  // Size of the ngram model: the sum of the number of states and the number
  // of arcs.
  int64_t GetSize() const {
    int64_t size = 0;
    for (StateId st = 0; st < nstates_; ++st)
      size += fst_.NumArcs(st) + 1;  // number of arcs + 1 state
    return size;
  }

  // Returns the highest order.
  int HiOrder() const { return hi_order_; }

  // Returns the order of a given state.
  // Order 1 is unigram, order 2 is bigram, and so on.
  int StateOrder(StateId state) const {
    if (state >= 0 && state < nstates_)
      return state_orders_[state];
    else
      return -1;
  }

  // Gets the state present in the model corresponding to the longest prefix
  // of the specified n-gram context. If the longest prefix is the empty
  // string, then this method returns the unigram state. '0' signifies the
  // super-initial 'word' in the ngram.
  StateId NGramState(const std::vector<Label> &ngram) const {
    StateId state = UnigramState();
    if (state < 0) state = GetFst().Start();
    fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
    for (auto it = ngram.begin(); it != ngram.end(); ++it) {
      if (*it == 0) {
        state = fst_.Start();
        continue;
      }
      matcher.SetState(state);
      if (!matcher.Find(*it)) break;
      const Arc &arc = matcher.Value();
      state = arc.nextstate;
    }
    return state;
  }
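
  // For example (hypothetical label ids, with 'model' as in the sketch
  // above), if 7 = "the" and 9 = "dog":
  //
  //   std::vector<fst::StdArc::Label> context = {7, 9};
  //   fst::StdArc::StateId st = model.NGramState(context);
  //
  // returns the state reached by reading as much of "the dog" from the
  // unigram state as the model actually contains.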

  // Returns the n-gram that must be read to reach 'state'. '0' signifies the
  // super-initial 'word' in the ngram. Constructor argument 'state_ngrams'
  // must be true.
  const std::vector<Label> &StateNGram(StateId state) const {
    if (!have_state_ngrams_) {
      NGRAMERROR() << "NGramModel: state ngrams not available";
      return empty_label_vector_;
    }
    return state_ngrams_[state];
  }

  // Unigram state.
  StateId UnigramState() const { return unigram_; }

  // Returns the unigram cost of the requested symbol if found (inf
  // otherwise).
  double GetSymbolUnigramCost(Label symbol) const {
    fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
    StateId st = unigram_;
    if (st < 0) st = fst_.Start();
    matcher.SetState(st);
    if (matcher.Find(symbol)) {
      Arc arc = matcher.Value();
      return ScalarValue(arc.weight);
    } else {
      return ScalarValue(Arc::Weight::Zero());
    }
  }

  // Label of backoff transitions.
  Label BackoffLabel() const { return backoff_label_; }

  // Finds the backoff state for a given state st and, if requested, provides
  // the backoff cost.
  StateId GetBackoff(StateId st, Weight *bocost) const {
    StateId backoff = -1;
    fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
    matcher.SetState(st);
    if (matcher.Find(backoff_label_)) {
      for (; !matcher.Done(); matcher.Next()) {
        Arc arc = matcher.Value();
        if (arc.ilabel == fst::kNoLabel) continue;  // non-consuming symbol
        backoff = arc.nextstate;
        if (bocost != nullptr) *bocost = arc.weight;
      }
    }
    return backoff;
  }
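
  // For example (sketch, with 'model' as above):
  //
  //   fst::StdArc::Weight bocost;
  //   fst::StdArc::StateId bo = model.GetBackoff(st, &bocost);
  //   // bo < 0 at the unigram state; otherwise st backs off to bo with
  //   // cost bocost.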

  // Verifies that the LM topology is sane.
  bool CheckTopology() const {
    ascending_ngrams_ = 0;
    // Checks state topology.
    for (StateId st = 0; st < nstates_; ++st)
      if (!CheckTopologyState(st)) return false;
    // All but the start and unigram states should have a unique ascending
    // ngram arc.
    if (unigram_ != -1 && ascending_ngrams_ != nstates_ - 2) {
      VLOG(1) << "Incomplete # of ascending n-grams: " << ascending_ngrams_;
      return false;
    }
    return true;
  }

  // Iterates through all states and validates that they are fully
  // normalized.
  bool CheckNormalization() const {
    if (Error()) return false;
    for (StateId st = 0; st < nstates_; ++st) {
      if (!CheckNormalizationState(st)) {
        VLOG(1) << "Failed normalization check at " << st;
        return false;
      }
    }
    return true;
  }

  // Calculates the backoff cost from the neglog sums of the high- and
  // low-order arcs.
  double CalculateBackoffCost(double hi_neglog_sum, double low_neglog_sum,
                              bool infinite_backoff = false) const {
    double nlog_backoff_num, nlog_backoff_denom;  // backoff cost and factors
    bool return_inf = CalculateBackoffFactors(
        hi_neglog_sum, low_neglog_sum, &nlog_backoff_num, &nlog_backoff_denom,
        infinite_backoff);
    if (return_inf) return kInfBackoff;  // backoff cost is 'infinite'
    return nlog_backoff_num - nlog_backoff_denom;
  }
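
  // In probability space this is the usual backoff weight
  //   alpha(h) = (1 - sum_w p(w | h)) / (1 - sum_w p(w | h')),
  // where h' is the backoff context and both sums run over the words with
  // explicit arcs out of the state for h; the numerator and denominator
  // computed above are -log of those two leftover masses.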

  // Calculates the numerator and denominator for assigning backoff cost;
  // returns true if the backoff cost should be treated as infinite.
  bool CalculateBackoffFactors(double hi_neglog_sum, double low_neglog_sum,
                               double *nlog_backoff_num,
                               double *nlog_backoff_denom,
                               bool infinite_backoff = false) const {
    double effective_zero = kNormEps * kFloatEps, effective_nlog_zero = 99.0;
    if (infinite_backoff && hi_neglog_sum <= kFloatEps)  // unsmoothed and p=1
      return true;
    if (hi_neglog_sum < effective_zero) hi_neglog_sum = effective_zero;
    if (low_neglog_sum < effective_zero) low_neglog_sum = effective_zero;
    if (low_neglog_sum <= 0 || hi_neglog_sum <= 0) return true;
    if (hi_neglog_sum > effective_nlog_zero) {
      *nlog_backoff_num = 0.0;
    } else {
      *nlog_backoff_num = NegLogDiff(0.0, hi_neglog_sum);
    }
    if (low_neglog_sum > effective_nlog_zero) {
      *nlog_backoff_denom = 0.0;
    } else {
      *nlog_backoff_denom = NegLogDiff(0.0, low_neglog_sum);
    }
    return false;
  }

  // Calculates marginal state probs. By default, uses the product of
  // the order-ascending ngram transition probabilities. If 'stationary'
  // is true, instead computes the stationary distribution of the Markov
  // chain. Returns true on success.
  bool CalculateStateProbs(std::vector<double> *probs, bool stationary = false,
                           size_t maxiters = 10000) const {
    bool ret = true;
    if (stationary) {
      ret = StationaryStateProbs(probs, .999999, norm_eps_, maxiters);
    } else {
      NGramStateProbs(probs);
    }
    if (FST_FLAGS_v > 1) {
      for (size_t st = 0; st < probs->size(); ++st)
        std::cerr << "st: " << st << " log_prob: " << log((*probs)[st])
                  << std::endl;
    }
    return ret;
  }
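
  // For example (sketch, with 'model' as above):
  //
  //   std::vector<double> probs;
  //   model.CalculateStateProbs(&probs, /*stationary=*/true);
  //   // On success, probs[st] holds the marginal probability of state st.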

  // fst::Fst const reference.
  const fst::Fst<Arc> &GetFst() const { return fst_; }

  // Called at construction. If the model topology is mutated, this should
  // be re-called prior to any member function that depends on it.
  void InitModel() {
    using fst::kAcceptor;
    using fst::kIDeterministic;
    using fst::kILabelSorted;
    using fst::kNoLabel;
    using fst::kNoStateId;
    // The unigram state is set to -1 for unigram models (in which case the
    // start state is the unigram state, so there is no need to store it
    // here).
    if (fst_.Start() == kNoStateId) {
      NGRAMERROR() << "NGramModel: Empty automaton";
      SetError();
      return;
    }
    uint64_t need_props = kAcceptor | kIDeterministic | kILabelSorted;
    uint64_t have_props = fst_.Properties(need_props, true);
    if (!(have_props & kAcceptor)) {
      NGRAMERROR() << "NGramModel: input not an acceptor";
      SetError();
      return;
    }
    if (!(have_props & kIDeterministic)) {
      NGRAMERROR() << "NGramModel: input not deterministic";
      SetError();
      return;
    }
    if (!(have_props & kILabelSorted)) {
      NGRAMERROR() << "NGramModel: input not label sorted";
      SetError();
      return;
    }

    if (!fst::CompatSymbols(fst_.InputSymbols(), fst_.OutputSymbols())) {
      NGRAMERROR() << "NGramModel: input and output symbol tables do not "
                      "match";
      SetError();
      return;
    }

    nstates_ = CountStates(fst_);
    unigram_ = GetBackoff(fst_.Start(), nullptr);  // sets the unigram state
    ComputeStateOrders();
    if (!CheckTopology()) {
      NGRAMERROR() << "NGramModel: bad ngram model topology";
      SetError();
      return;
    }
  }

  // Accessor function for the norm_eps_ parameter.
  double NormEps() const { return norm_eps_; }

  // Calculates the number of n-grams at a state.
  int NumNGrams(StateId st) {
    int num_ngrams = fst_.NumArcs(st);  // arcs are n-grams
    if (GetBackoff(st, nullptr) >= 0)   // except one arc, the backoff arc
      num_ngrams--;
    if (ScalarValue(fst_.Final(st)) !=
        ScalarValue(Arc::Weight::Zero()))  // </s> n-gram
      num_ngrams++;
    return num_ngrams;
  }

  // Returns the cost assigned by the model to an n-gram. '0' signifies
  // the super-initial and super-final 'words'. If the n-gram begins with
  // '0', the computation begins at the start state and the initial
  // weight is applied; otherwise the computation begins at the unigram
  // state. If the n-gram ends with '0' (distinct from an initial
  // '0'), the final weight is applied.
  Weight GetNGramCost(const std::vector<Label> &ngram) const {
    if (ngram.empty()) return Weight::One();

    StateId st = ngram.front() == 0 || unigram_ < 0 ? fst_.Start() : unigram_;

    // p(<s>) = p(</s>)
    Weight cost = ngram.front() == 0 && unigram_ >= 0 ? fst_.Final(unigram_)
                                                      : Weight::One();

    fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);

    for (int n = 0; n < ngram.size(); ++n) {
      Label label = ngram[n];
      if (label == 0) {
        if (n == 0) continue;  // super-initial word
        if (n != ngram.size() - 1) {
          NGRAMERROR() << "end-of-string is not the super-final word";
          return Weight::Zero();
        }
        while (fst_.Final(st) == Weight::Zero()) {
          Weight bocost;
          st = GetBackoff(st, &bocost);
          if (st < 0) {
            return Weight::Zero();
          }
          cost = Times(cost, bocost);
        }
        cost = Times(cost, fst_.Final(st));
      } else {
        while (true) {
          matcher.SetState(st);
          if (matcher.Find(label)) {
            Arc arc = matcher.Value();
            st = arc.nextstate;
            cost = Times(cost, arc.weight);
            break;
          } else {
            Weight bocost;
            st = GetBackoff(st, &bocost);
            if (st < 0) {
              return Weight::Zero();
            }
            cost = Times(cost, bocost);
          }
        }
      }
    }

    return cost;
  }
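
  // For example (hypothetical label ids 7 = "the", 9 = "dog", with 'model'
  // as above), the cost of the full sentence "<s> the dog </s>" is:
  //
  //   std::vector<fst::StdArc::Label> sent = {0, 7, 9, 0};
  //   double neglogprob = model.GetNGramCost(sent).Value();
  //
  // i.e., -log p(the | <s>) - log p(dog | <s> the) - log p(</s> | ...),
  // with backoff costs folded in wherever an n-gram is absent.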

  // Mimics a phi matcher: follows backoff links until a final state is
  // found.
  Weight FinalCostInModel(StateId mst, int *order) const {
    Weight cost = Arc::Weight::One();
    while (fst_.Final(mst) == Arc::Weight::Zero()) {
      fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
      matcher.SetState(mst);
      if (matcher.Find(backoff_label_)) {
        for (; !matcher.Done(); matcher.Next()) {
          Arc arc = matcher.Value();
          if (arc.ilabel == backoff_label_) {
            mst = arc.nextstate;             // make backoff state current
            cost = Times(cost, arc.weight);  // add in backoff cost
          }
        }
      } else {
        NGRAMERROR() << "NGramModel: No final cost in model: " << mst;
        return Arc::Weight::Zero();
      }
    }
    *order = state_orders_[mst];
    // TODO(vitalyk): take care of value call
    cost = Times(cost, fst_.Final(mst));
    return cost;
  }

  // Changes data for a state that would normally be computed
  // by InitModel; this allows incremental updates.
  void UpdateState(StateId st, int order, bool unigram_state,
                   const std::vector<Label> *ngram = nullptr) {
    if (have_state_ngrams_ && !ngram) {
      NGRAMERROR() << "NGramModel::UpdateState: no ngram provided";
      SetError();
      return;
    }
    if (state_orders_.size() < st) {
      NGRAMERROR() << "NGramModel::UpdateState: bad state: " << st;
      SetError();
      return;
    }
    if (order > hi_order_) hi_order_ = order;

    if (state_orders_.size() == st) {  // adds state info
      state_orders_.push_back(order);
      if (ngram) state_ngrams_.push_back(*ngram);
      ++nstates_;
    } else {  // modifies state info
      state_orders_[st] = order;
      if (ngram) state_ngrams_[st] = *ngram;
    }

    if (unigram_state) unigram_ = nstates_;
  }

  // Returns a scalar value associated with a weight.
  static double ScalarValue(Weight w);

  // Returns a weight that represents a unit count for this model.
  static Weight UnitCount();

  // Returns a factor used to scale backoff mass in interpolated models.
  static double FactorValue(Weight w);

  // Returns the final weight for state st.
  Weight GetFinalWeight(StateId st) const { return fst_.Final(st); }

  // Returns the backoff cost for state st.
  Weight GetBackoffCost(StateId st) const {
    Weight bocost;
    StateId bo = GetBackoff(st, &bocost);
    if (bo < 0)  // if no backoff arc found
      bocost = Arc::Weight::Zero();
    return bocost;
  }

  // Estimates the total unigram count based on the probabilities in the
  // unigram state. The difference between the two smallest probs should be
  // 1/N; returns the reciprocal.
  double EstimateTotalUnigramCount() const {
    StateId st = UnigramState();
    bool first = true;
    double max = fst::LogArc::Weight::Zero().Value(), nextmax = max;
    if (st < 0) st = GetFst().Start();  // if model is unigram, use Start()
    for (fst::ArcIterator<fst::Fst<Arc>> aiter(GetFst(), st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel == BackoffLabel()) continue;
      if (first || ScalarValue(arc.weight) > max) {
        // maximum negative log prob case
        nextmax = max;  // keep both max and nextmax (to calculate diff)
        max = ScalarValue(arc.weight);
        first = false;
      } else if (ScalarValue(arc.weight) < max &&
                 ScalarValue(arc.weight) > nextmax) {
        nextmax = ScalarValue(arc.weight);
      }
    }
    if (nextmax == fst::LogArc::Weight::Zero().Value()) return exp(max);
    return exp(NegLogDiff(nextmax, max));
  }
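
  // Worked example (illustrative): if the two smallest unigram probs are
  // 1/N and 2/N, their costs are max = log(N) and nextmax = -log(2/N), so
  // NegLogDiff(nextmax, max) = -log(2/N - 1/N) = log(N), and exponentiating
  // recovers the total count N.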

  // Returns true if the model is in a bad state / not a proper LM.
  bool Error() const { return error_; }

 protected:
  void SetError() { error_ = true; }

  // Shadows the free-function NegLogDiff in this class to catch errors.
  double NegLogDiff(double a, double b) const {
    return ngram::NegLogDiff(a, b, &error_);
  }

  // Fills a vector with the counts of each state, based on prefix counts.
  void FillStateCounts(std::vector<double> *state_counts) {
    for (int i = 0; i < nstates_; i++)
      state_counts->push_back(ScalarValue(Arc::Weight::Zero()));
    WalkStatesForCount(state_counts);
  }

  // Collects backoff arc weights in a vector.
  bool FillBackoffArcWeights(StateId st, StateId bo,
                             std::vector<double> *bo_arc_weight) const {
    fst::Matcher<fst::Fst<Arc>> matcher(
        fst_, fst::MATCH_INPUT);  // for querying backoff
    matcher.SetState(bo);
    for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel == backoff_label_) continue;
      if (matcher.Find(arc.ilabel)) {
        Arc barc = matcher.Value();
        // Note that we allow scaling the backoff weight by a value that
        // depends on the weight of the ngram. For instance, the fractional
        // count model mixes in a fraction of lower order mass proportional
        // to the frequency of the event that the ngram occurs zero times.
        // So for this model we scale backoff weights by these frequencies.
        // For all other models this scaling factor defaults to 0.0
        // (unity in the log semiring).
        bo_arc_weight->push_back(ScalarValue(barc.weight) +
                                 FactorValue(arc.weight));
      } else {
        NGRAMERROR() << "NGramModel: lower order arc missing: " << st;
        return false;
      }
    }
    return true;
  }

  // Uses an iterator in place of a matcher for arc iterators; allows
  // getting Position(). NB: begins the search from the current position.
  bool FindArc(fst::ArcIterator<fst::Fst<Arc>> *biter, Label label) const {
    while (!biter->Done()) {  // scans through arcs
      Arc barc = biter->Value();
      if (barc.ilabel == label)
        return true;  // if the label matches, true
      else if (barc.ilabel < label)  // if less than value, go to next
        biter->Next();
      else
        return false;  // otherwise no match
    }
    return false;  // no match found
  }

  // Finds the arc weight associated with a label at a state.
  Weight FindArcWeight(StateId st, Label label) const {
    Weight cost = Arc::Weight::Zero();
    fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
    matcher.SetState(st);
    if (matcher.Find(label)) {
      Arc arc = matcher.Value();
      cost = arc.weight;
    }
    return cost;
  }

  // Mimics a phi matcher: follows backoff arcs until the label is found or
  // there is no backoff.
  bool FindNGramInModel(StateId *mst, int *order, Label label,
                        double *cost) const {
    if (label < 0) return false;
    StateId currstate = *mst;
    *cost = 0;
    *mst = -1;
    while (*mst < 0) {
      fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
      matcher.SetState(currstate);
      if (matcher.Find(label)) {  // arc found out of current state
        Arc arc = matcher.Value();
        *order = state_orders_[currstate];
        *mst = arc.nextstate;              // assign destination as new state
        *cost += ScalarValue(arc.weight);  // add cost to total
      } else if (matcher.Find(backoff_label_)) {  // follow backoff arc
        currstate = -1;
        for (; !matcher.Done(); matcher.Next()) {
          Arc arc = matcher.Value();
          if (arc.ilabel == backoff_label_) {
            currstate = arc.nextstate;         // make backoff state current
            *cost += ScalarValue(arc.weight);  // add in backoff cost
          }
        }
        if (currstate < 0) return false;
      } else {
        return false;  // found label in symbol list, but not in model
      }
    }
    return true;
  }

  // Sums final + arc probs out of a state, and the same transitions out of
  // the backoff state.
  bool CalcBONegLogSums(StateId st, double *hi_neglog_sum,
                        double *low_neglog_sum, bool infinite_backoff = false,
                        bool unigram = false) const {
    StateId bo = GetBackoff(st, nullptr);
    if (bo < 0 && !unigram) return false;  // only for states that back off
    *low_neglog_sum = *hi_neglog_sum =  // final costs initialize the sum
        ScalarValue(fst_.Final(st));
    // if st is final
    if (bo >= 0 && *hi_neglog_sum != ScalarValue(Arc::Weight::Zero()))
      // re-initialize lower sum
      *low_neglog_sum = ScalarValue(fst_.Final(bo));
    CalcArcNegLogSums(st, bo, hi_neglog_sum, low_neglog_sum, infinite_backoff);
    return true;
  }

  // Prints a state's ngram to a stream.
  bool PrintStateNGram(StateId st, std::ostream &ostrm = std::cerr) const {
    ostrm << "state: " << st << " order: " << state_orders_[st] << " ngram: ";
    for (int i = 0; i < state_ngrams_[st].size(); ++i)
      ostrm << state_ngrams_[st][i] << " ";
    ostrm << "\n";
    return true;
  }

  // Modifies n-gram weights according to printing parameters.
  static double WeightRep(double wt, bool neglogs, bool intcnts) {
    if (!neglogs || intcnts) wt = exp(-wt);
    if (intcnts) wt = round(wt);
    return wt;
  }
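
  // Worked example (illustrative): a count of 7 stored as the negative log
  // wt = -log(7) ~= -1.946 prints as WeightRep(-1.946, true, true) =
  // round(exp(1.946)) = 7, while WeightRep(-1.946, true, false) leaves the
  // negative log unchanged.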

 private:
  // Iterates through arcs, accumulating neglog probs from arcs and their
  // backoffs.
  bool CalcArcNegLogSums(StateId st, StateId bo, double *hi_sum,
                         double *low_sum,
                         bool infinite_backoff = false) const {
    // correction values for Kahan summation
    double KahanVal1 = 0, KahanVal2 = 0;
    double init_low = *low_sum;
    fst::Matcher<fst::Fst<Arc>> matcher(
        fst_, fst::MATCH_INPUT);  // for querying backoff
    if (bo >= 0) matcher.SetState(bo);
    for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel == backoff_label_) continue;
      if (bo < 0 || matcher.Find(arc.ilabel)) {
        if (bo >= 0) {
          Arc barc = matcher.Value();
          *low_sum =  // sum of lower order probs of the same labels
              NegLogSum(*low_sum, ScalarValue(barc.weight), &KahanVal2);
        }
        *hi_sum =  // sum of higher order probs
            NegLogSum(*hi_sum, ScalarValue(arc.weight), &KahanVal1);
      } else {
        NGRAMERROR() << "NGramModel: No arc label match in backoff state: "
                     << st;
        return false;
      }
    }
    if (bo >= 0 && infinite_backoff && *low_sum == 0.0)  // ok for unsmoothed
      return true;
    if (bo >= 0 && *low_sum <= 0.0) {
      VLOG(2) << "lower order sum less than zero: " << st << " " << *low_sum;
      double start_low = ScalarValue(Arc::Weight::Zero());
      if (init_low == start_low) start_low = ScalarValue(fst_.Final(bo));
      *low_sum = CalcBruteLowSum(st, bo, start_low);
      VLOG(2) << "new lower order sum: " << st << " " << *low_sum;
    }
    return true;
  }

  // Iterates through arcs, accumulating neglog probs from arcs and their
  // backoffs. Used in case the more efficient method fails to produce a
  // sane value.
  double CalcBruteLowSum(StateId st, StateId bo, double start_low) const {
    double low_sum = start_low, KahanVal = 0;
    fst::Matcher<fst::Fst<Arc>> matcher(
        fst_, fst::MATCH_INPUT);  // for querying backoff
    matcher.SetState(bo);
    fst::ArcIterator<fst::Fst<Arc>> biter(fst_, bo);
    Arc barc;
    for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel == backoff_label_) continue;
      barc = biter.Value();
      while (!biter.Done() && barc.ilabel < arc.ilabel) {  // linear scan
        if (barc.ilabel != backoff_label_)
          low_sum =  // sum of lower order probs of different labels
              NegLogSum(low_sum, ScalarValue(barc.weight), &KahanVal);
        biter.Next();
        if (!biter.Done()) barc = biter.Value();
      }
      if (!biter.Done() && barc.ilabel == arc.ilabel) {
        biter.Next();
        if (!biter.Done()) barc = biter.Value();
      }
      if (biter.Done()) break;
    }
    while (!biter.Done()) {  // linear scan
      if (barc.ilabel != backoff_label_)
        low_sum =  // sum of lower order probs of different labels
            NegLogSum(low_sum, ScalarValue(barc.weight), &KahanVal);
      biter.Next();
      if (!biter.Done()) barc = biter.Value();
    }
    return NegLogDiff(0.0, low_sum);
  }

  // Traverses the n-gram fst and records each state's n-gram order; finds
  // the highest order.
  void ComputeStateOrders() {
    state_orders_.clear();
    state_orders_.resize(nstates_, -1);

    if (have_state_ngrams_) {
      state_ngrams_.clear();
      state_ngrams_.resize(nstates_);
    }

    hi_order_ = 1;  // calculate highest order in the model
    std::deque<StateId> state_queue;
    if (unigram_ != fst::kNoStateId) {
      state_orders_[unigram_] = 1;
      state_queue.push_back(unigram_);
      state_orders_[fst_.Start()] = hi_order_ = 2;
      state_queue.push_back(fst_.Start());
      if (have_state_ngrams_)
        state_ngrams_[fst_.Start()].push_back(0);  // initial context
    } else {
      state_orders_[fst_.Start()] = 1;
      state_queue.push_back(fst_.Start());
    }

    while (!state_queue.empty()) {
      StateId state = state_queue.front();
      state_queue.pop_front();
      for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, state); !aiter.Done();
           aiter.Next()) {
        const Arc &arc = aiter.Value();
        if (state_orders_[arc.nextstate] == -1) {
          state_orders_[arc.nextstate] = state_orders_[state] + 1;
          if (have_state_ngrams_) {
            state_ngrams_[arc.nextstate] = state_ngrams_[state];
            state_ngrams_[arc.nextstate].push_back(arc.ilabel);
          }
          if (state_orders_[state] >= hi_order_)
            hi_order_ = state_orders_[state] + 1;
          state_queue.push_back(arc.nextstate);
        }
      }
    }
  }

  // Ensures correct n-gram topology for a given state.
  bool CheckTopologyState(StateId st) const {
    if (unigram_ == -1) {  // unigram model
      if (fst_.Final(fst_.Start()) == Arc::Weight::Zero()) {
        VLOG(1) << "CheckTopology: bad final weight for start state";
        return false;
      } else {
        return true;
      }
    }

    StateId bos = GetBackoff(st, nullptr);
    fst::Matcher<fst::Fst<Arc>> matcher(
        fst_, fst::MATCH_INPUT);  // for querying backoff

    if (st == unigram_) {  // unigram state
      if (fst_.Final(unigram_) == Arc::Weight::Zero()) {
        VLOG(1) << "CheckTopology: bad final weight for unigram state: "
                << unigram_;
        return false;
      } else if (have_state_ngrams_ && !state_ngrams_[unigram_].empty()) {
        VLOG(1) << "CheckTopology: bad unigram state: " << unigram_;
        return false;
      }
    } else {  // non-unigram state
      if (bos == -1) {
        VLOG(1) << "CheckTopology: no backoff state: " << st;
        return false;
      }

      if (fst_.Final(st) != Arc::Weight::Zero() &&
          fst_.Final(bos) == Arc::Weight::Zero()) {
        VLOG(1) << "CheckTopology: bad final weight for backoff state: " << st;
        return false;
      }

      if (StateOrder(st) != StateOrder(bos) + 1) {
        VLOG(1) << "CheckTopology: bad backoff arc from: " << st
                << " with order: " << StateOrder(st) << " to state: " << bos
                << " with order: " << StateOrder(bos);
        return false;
      }
      matcher.SetState(bos);
    }

    for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();

      if (StateOrder(st) < StateOrder(arc.nextstate)) ++ascending_ngrams_;

      if (have_state_ngrams_ && !CheckStateNGrams(st, arc)) {
        VLOG(1) << "CheckTopology: inconsistent n-gram states: " << st
                << " -- " << arc.ilabel << "/" << arc.weight << " -> "
                << arc.nextstate;
        return false;
      }

      if (st != unigram_) {
        if (arc.ilabel == backoff_label_) continue;
        if (!matcher.Find(arc.ilabel)) {
          VLOG(1) << "CheckTopology: unmatched arc at backoff state: "
                  << arc.ilabel << "/" << arc.weight << " for state: " << st;
          return false;
        }
      }
    }
    return true;
  }

  // Checks state ngrams for consistency.
  bool CheckStateNGrams(StateId st, const Arc &arc) const {
    std::vector<Label> state_ngram;
    bool boa = arc.ilabel == backoff_label_;

    int j = state_orders_[st] - state_orders_[arc.nextstate] + (boa ? 0 : 1);
    if (j < 0) return false;

    for (int i = j; i < state_ngrams_[st].size(); ++i)
      state_ngram.push_back(state_ngrams_[st][i]);
    if (!boa && j <= state_ngrams_[st].size())
      state_ngram.push_back(arc.ilabel);

    return state_ngram == state_ngrams_[arc.nextstate];
  }

  // Ensures normalization for a given state to error epsilon:
  // sum of state probs + exp(-backoff_cost) - sum of arc backoff probs = 1.
  bool CheckNormalizationState(StateId st) const {
    double Norm, Norm1;
    Weight bocost = Weight::NoWeight();
    StateId bo = GetBackoff(st, &bocost);
    // final costs initialize the sum
    Norm = Norm1 = ScalarValue(fst_.Final(st));
    if (bo >= 0 && Norm != ScalarValue(Arc::Weight::Zero()))  // st is final
      Norm1 = ScalarValue(fst_.Final(bo));  // re-initialize lower sum
    if (!CalcArcNegLogSums(st, bo, &Norm, &Norm1,
                           (ScalarValue(bocost) == kInfBackoff))) {
      return false;
    }
    return EvaluateNormalization(st, bo, ScalarValue(bocost), Norm, Norm1);
  }

  // For accumulated negative log probabilities, tests for normalization.
  bool EvaluateNormalization(StateId st, StateId bo, double bocost,
                             double norm, double norm1) const {
    double newnorm = norm;
    if (bo >= 0) {
      newnorm = NegLogSum(norm, bocost);
      if (newnorm < norm1 + bocost)
        newnorm = NegLogDiff(newnorm, norm1 + bocost);
      else
        newnorm = NegLogDiff(norm1 + bocost, newnorm);
    }
    // NOTE: can we automatically derive an appropriate epsilon?
    if (fabs(newnorm) > norm_eps_ &&  // not normalized
        (bo < 0 || !ReevaluateNormalization(st, bocost, norm, norm1))) {
      VLOG(2) << "State ID: " << st << "; " << fst_.NumArcs(st) << " arcs;"
              << " -log(sum(P)) = " << newnorm << ", should be 0";
      VLOG(2) << norm << " " << norm1;
      return false;
    }
    return true;
  }

  // For accumulated negative log probabilities, a second test for
  // normalization. Intended for states with a very high magnitude backoff
  // cost, which makes the previous test unreliable.
  bool ReevaluateNormalization(StateId st, double bocost, double norm,
                               double norm1) const {
    double newalpha = CalculateBackoffCost(norm, norm1);
    // NOTE: can we automatically derive an appropriate epsilon?
    VLOG(2) << "Required re-evaluation of normalization: state " << st << " "
            << norm << " " << norm1 << " " << newalpha << " " << norm_eps_;
    if (fabs(newalpha - bocost) > norm_eps_) return false;
    return true;
  }

  // Collects prefix counts for arcs out of a specific state.
  void CollectPrefixCounts(std::vector<double> *state_counts,
                           StateId st) const {
    for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel != backoff_label_ &&  // only counts non-backoff arcs
          state_orders_[st] < state_orders_[arc.nextstate]) {  // ascending
        (*state_counts)[arc.nextstate] = ScalarValue(arc.weight);
        CollectPrefixCounts(state_counts, arc.nextstate);
      }
    }
  }

  // Walks the model automaton to collect prefix counts for each state.
  void WalkStatesForCount(std::vector<double> *state_counts) const {
    if (unigram_ != -1) {
      (*state_counts)[fst_.Start()] = ScalarValue(fst_.Final(unigram_));
      CollectPrefixCounts(state_counts, unigram_);
    }
    CollectPrefixCounts(state_counts, fst_.Start());
  }

  // Tests to see if the model came from pre-summing a mixture.
  // Should have: backoff weights > 0; higher order always higher prob
  // (summed).
  bool MixtureConsistent() const {
    fst::Matcher<fst::Fst<Arc>> matcher(
        fst_, fst::MATCH_INPUT);  // for querying backoff
    for (StateId st = 0; st < nstates_; ++st) {
      Weight bocost;
      StateId bo = GetBackoff(st, &bocost);
      if (bo >= 0) {  // if bigram or higher order
        if (ScalarValue(bocost) < 0)  // negative backoff cost can't happen
          return false;               // with a mixture
        matcher.SetState(bo);
        for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
             aiter.Next()) {
          Arc arc = aiter.Value();
          if (arc.ilabel == backoff_label_) {
            continue;
          }
          if (matcher.Find(arc.ilabel)) {
            Arc barc = matcher.Value();
            if (ScalarValue(arc.weight) >
                ScalarValue(barc.weight) + ScalarValue(bocost)) {
              return false;  // L P + (1-L) P' < (1-L) P' (can't happen w/mix)
            }
          } else {
            NGRAMERROR() << "NGramModel: lower order arc missing: " << st;
            SetError();
            return false;
          }
        }
        if (ScalarValue(fst_.Final(st)) != ScalarValue(Arc::Weight::Zero()) &&
            ScalarValue(fst_.Final(st)) >
                ScalarValue(fst_.Final(bo)) + ScalarValue(bocost))
          return false;  // final cost doesn't sum
      }
    }
    return true;
  }

  // At a given state, calculates the marginal prob p(h) based on
  // the smoothed, order-ascending n-gram transition probabilities.
  void NGramStateProb(StateId st, std::vector<double> *probs) const {
    for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel == backoff_label_) continue;
      if (state_orders_[arc.nextstate] > state_orders_[st]) {
        (*probs)[arc.nextstate] = (*probs)[st] * exp(-ScalarValue(arc.weight));
        NGramStateProb(arc.nextstate, probs);
      }
    }
  }

  // Calculates marginal state probs as the product of the smoothed,
  // order-ascending ngram transition probabilities: p(abc) =
  // p(a)p(b|a)p(c|ba) (odd w/KN).
  void NGramStateProbs(std::vector<double> *probs, bool norm = false) const {
    probs->clear();
    probs->resize(nstates_, 0.0);
    if (unigram_ < 0) {
      // p(unigram state) = 1
      (*probs)[fst_.Start()] = 1.0;
    } else {
      // p(unigram state) = 1
      (*probs)[unigram_] = 1.0;
      NGramStateProb(unigram_, probs);
      // p(<s>) = p(</s>)
      (*probs)[fst_.Start()] = exp(-ScalarValue(fst_.Final(unigram_)));
    }
    NGramStateProb(fst_.Start(), probs);

    if (norm) {  // normalizes result, as a start point for the power method
      double sum = 0.0;
      for (size_t st = 0; st < probs->size(); ++st) sum += (*probs)[st];
      for (size_t st = 0; st < probs->size(); ++st) (*probs)[st] /= sum;
    }
  }

  // At a given state, calculates one step of the power method
  // for the stationary distribution of the closure of the
  // LM with re-entry probability 'alpha'. Exponentiates the weights.
  void StationaryStateProb(StateId st, std::vector<double> *init_probs,
                           std::vector<double> *probs, double alpha) const {
    fst::Matcher<fst::Fst<Arc>> matcher(
        fst_, fst::MATCH_INPUT);  // for querying backoff
    Weight bocost;
    StateId bo = GetBackoff(st, &bocost);
    if (bo != -1) {
      // Treats backoff like an epsilon transition
      matcher.SetState(bo);
      (*init_probs)[bo] += (*init_probs)[st] * exp(-ScalarValue(bocost));
    }

    for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
         aiter.Next()) {
      Arc arc = aiter.Value();
      if (arc.ilabel == backoff_label_) continue;
      (*probs)[arc.nextstate] +=
          (*init_probs)[st] * exp(-ScalarValue(arc.weight));
      if (bo != -1 && matcher.Find(arc.ilabel)) {
        // Subtracts corrective weight for backed-off arc
        const Arc &barc = matcher.Value();
        (*probs)[barc.nextstate] -=
            (*init_probs)[st] *
            exp(-ScalarValue(barc.weight) - ScalarValue(bocost));
      }
    }

    if (ScalarValue(fst_.Final(st)) != ScalarValue(Weight::Zero())) {
      (*probs)[fst_.Start()] +=
          (*init_probs)[st] * exp(-ScalarValue(fst_.Final(st))) * alpha;
      if (bo != -1) {
        // Subtracts corrective weight for backed-off superfinal arc
        (*probs)[fst_.Start()] -=
            (*init_probs)[st] *
            exp(-ScalarValue(fst_.Final(bo)) - ScalarValue(bocost)) * alpha;
      }
    }
  }

  // Calculates marginal state probs as the stationary distribution
  // of the Markov chain consisting of the closure of the LM
  // with re-entry probability 'alpha'. The convergence is controlled
  // by 'converge_eps' and the number of iterations by 'maxiters'.
  // Returns true on convergence.
  bool StationaryStateProbs(std::vector<double> *probs, double alpha,
                            double converge_eps, size_t maxiters) const {
    std::vector<double> init_probs, last_probs;
    // Initializes based on ngram transition probabilities
    NGramStateProbs(&init_probs, true);
    last_probs = init_probs;

    size_t changed;
    size_t iters = 0;
    do {
      probs->clear();
      probs->resize(nstates_, 0.0);
      for (int order = hi_order_; order > 0; --order) {
        for (size_t st = 0; st < nstates_; ++st) {
          if (state_orders_[st] == order)
            StationaryStateProb(st, &init_probs, probs, alpha);
        }
      }

      changed = 0;
      for (size_t st = 0; st < nstates_; ++st) {
        if (fabs((*probs)[st] - last_probs[st]) >
            converge_eps * last_probs[st])
          ++changed;
        last_probs[st] = init_probs[st] = (*probs)[st];
      }
      VLOG(2) << "NGramModel::StationaryStateProbs: state probs changed: "
              << changed;
      if (++iters > maxiters) return false;
    } while (changed > 0);
    return true;
  }
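
  // Each pass of the do-loop above is one power-method update
  // x_{t+1} = x_t P against the transition matrix P of the closed chain,
  // with backoff arcs expanded as epsilon transitions and corrective
  // negative terms for backed-off arcs; iteration stops once no state's
  // probability changes by more than 'converge_eps' relative to its
  // previous value.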

  const fst::Fst<Arc> &fst_;
  StateId unigram_;      // unigram state
  Label backoff_label_;  // label of backoff transitions
  StateId nstates_;      // number of states in LM
  int hi_order_;         // highest order in the model
  double norm_eps_;      // epsilon diff allowed to ensure normalized
  std::vector<int> state_orders_;    // order of each state
  bool have_state_ngrams_;           // compute and store state n-gram info
  mutable size_t ascending_ngrams_;  // # of n-gram arcs that increase order
  std::vector<std::vector<Label>>
      state_ngrams_;  // n-gram always read to reach state
  const std::vector<Label> empty_label_vector_;
  mutable bool error_;

  NGramModel(const NGramModel &) = delete;
  NGramModel &operator=(const NGramModel &) = delete;
};

template <typename T>
double NGramModel<T>::ScalarValue(typename T::Weight w) {
  return w.Value();
}

template <>
inline double NGramModel<HistogramArc>::ScalarValue(HistogramArc::Weight w) {
  return w.Value(0).Value();
}

template <typename Arc>
typename Arc::Weight NGramModel<Arc>::UnitCount() {
  return Arc::Weight::One();
}

template <>
inline typename HistogramArc::Weight NGramModel<HistogramArc>::UnitCount() {
  std::array<fst::StdArc::Weight, kHistogramBins> weights;
  weights.fill(fst::StdArc::Weight::Zero());
  if (kHistogramBins > 0) weights[0] = fst::StdArc::Weight::One();
  if (kHistogramBins > 2) weights[2] = fst::StdArc::Weight::One();
  static const fst::PowerWeight<fst::StdArc::Weight, kHistogramBins> one(
      weights.begin(), weights.end());
  return one;
}

template <typename T>
double NGramModel<T>::FactorValue(typename T::Weight w) {
  return 0.0;
}

template <>
inline double NGramModel<HistogramArc>::FactorValue(HistogramArc::Weight w) {
  return w.Value(1).Value();
}

}  // namespace ngram

#endif  // NGRAM_NGRAM_MODEL_H_