18 #ifndef NGRAM_NGRAM_MODEL_H_ 19 #define NGRAM_NGRAM_MODEL_H_ 30 #include <fst/matcher.h> 42 static double NegLogDeltaValue(
double a,
double b,
double *c) {
43 double x = exp(a - b), delta = -log(x + 1);
46 for (
int j = 2; j <= 4; ++j) delta += pow(-x, j) / j;
55 static double NegLogSum(
double a,
double b,
double *c) {
56 if (a == fst::StdArc::Weight::Zero().Value())
return b;
57 if (b == fst::StdArc::Weight::Zero().Value())
return a;
58 if (a > b)
return NegLogSum(b, a, c);
59 double delta = NegLogDeltaValue(a, b, c), val = a + delta;
60 if (c) *c = (val -
a) - delta;
65 static double NegLogSum(
double a,
double b) {
return NegLogSum(a, b,
nullptr); }
70 static double NegLogDiff(
double a,
double b,
bool *error =
nullptr) {
71 if (b == fst::StdArc::Weight::Zero().Value())
return a;
73 if (a - b >= kNormEps) {
74 NGRAMERROR() <<
"NegLogDiff: undefined " << a <<
" " <<
b;
75 if (error) *error =
true;
77 return fst::StdArc::Weight::Zero().Value();
79 return b - log(exp(b - a) - 1);
86 typedef typename Arc::Label
Label;
97 NGramModel(
const fst::Fst<Arc> &infst, Label backoff_label,
98 double norm_eps,
bool state_ngrams)
100 backoff_label_(backoff_label),
102 have_state_ngrams_(state_ngrams),
110 backoff_label_(backoff_label),
112 have_state_ngrams_(false),
122 have_state_ngrams_(false),
135 for (StateId st = 0; st < nstates_; ++st)
136 size += fst_.NumArcs(st) + 1;
146 if (state >= 0 && state < nstates_)
147 return state_orders_[state];
158 if (state < 0) state =
GetFst().Start();
159 fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
160 for (
auto it = ngram.begin(); it != ngram.end(); ++it) {
162 state = fst_.Start();
165 matcher.SetState(state);
166 if (!matcher.Find(*it))
break;
167 const Arc &arc = matcher.Value();
168 state = arc.nextstate;
177 if (!have_state_ngrams_) {
178 NGRAMERROR() <<
"NGramModel: state ngrams not available";
179 return empty_label_vector_;
181 return state_ngrams_[state];
189 fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
190 StateId st = unigram_;
191 if (st < 0) st = fst_.Start();
192 matcher.SetState(st);
193 if (matcher.Find(symbol)) {
194 Arc arc = matcher.Value();
206 StateId backoff = -1;
207 fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
208 matcher.SetState(st);
209 if (matcher.Find(backoff_label_)) {
210 for (; !matcher.Done(); matcher.Next()) {
211 Arc arc = matcher.Value();
212 if (arc.ilabel == fst::kNoLabel)
continue;
213 backoff = arc.nextstate;
214 if (bocost !=
nullptr) bocost[0] = arc.weight;
222 ascending_ngrams_ = 0;
224 for (StateId st = 0; st < nstates_; ++st)
225 if (!CheckTopologyState(st))
return false;
227 if (unigram_ != -1 && ascending_ngrams_ != nstates_ - 2) {
228 VLOG(1) <<
"Incomplete # of ascending n-grams: " << ascending_ngrams_;
236 if (
Error())
return false;
237 for (StateId st = 0; st < nstates_; ++st) {
238 if (!CheckNormalizationState(st)) {
239 VLOG(1) <<
"Failed normalization check at " << st;
248 bool infinite_backoff =
false)
const {
249 double nlog_backoff_num, nlog_backoff_denom;
251 hi_neglog_sum, low_neglog_sum, &nlog_backoff_num, &nlog_backoff_denom,
254 return nlog_backoff_num - nlog_backoff_denom;
259 double *nlog_backoff_num,
260 double *nlog_backoff_denom,
261 bool infinite_backoff =
false)
const {
262 double effective_zero = kNormEps *
kFloatEps, effective_nlog_zero = 99.0;
263 if (infinite_backoff && hi_neglog_sum <= kFloatEps)
265 if (hi_neglog_sum < effective_zero) hi_neglog_sum = effective_zero;
266 if (low_neglog_sum < effective_zero) low_neglog_sum = effective_zero;
267 if (low_neglog_sum <= 0 || hi_neglog_sum <= 0)
return true;
268 if (hi_neglog_sum > effective_nlog_zero) {
269 *nlog_backoff_num = 0.0;
271 *nlog_backoff_num =
NegLogDiff(0.0, hi_neglog_sum);
273 if (low_neglog_sum > effective_nlog_zero) {
274 *nlog_backoff_denom = 0.0;
276 *nlog_backoff_denom =
NegLogDiff(0.0, low_neglog_sum);
286 size_t maxiters = 10000)
const {
289 ret = StationaryStateProbs(probs, .999999, norm_eps_, maxiters);
291 NGramStateProbs(probs);
293 if (FST_FLAGS_v > 1) {
294 for (
size_t st = 0; st < probs->size(); ++st)
295 std::cerr <<
"st: " << st <<
" log_prob: " << log((*probs)[st])
302 const fst::Fst<Arc> &
GetFst()
const {
return fst_; }
307 using fst::kAcceptor;
308 using fst::kIDeterministic;
309 using fst::kILabelSorted;
311 using fst::kNoStateId;
314 if (fst_.Start() == kNoLabel) {
315 NGRAMERROR() <<
"NGramModel: Empty automaton";
319 uint64_t need_props = kAcceptor | kIDeterministic | kILabelSorted;
320 uint64_t have_props = fst_.Properties(need_props,
true);
321 if (!(have_props & kAcceptor)) {
322 NGRAMERROR() <<
"NGramModel: input not an acceptor";
326 if (!(have_props & kIDeterministic)) {
327 NGRAMERROR() <<
"NGramModel: input not deterministic";
331 if (!(have_props & kILabelSorted)) {
332 NGRAMERROR() <<
"NGramModel: input not label sorted";
337 if (!fst::CompatSymbols(fst_.InputSymbols(), fst_.OutputSymbols())) {
338 NGRAMERROR() <<
"NGramModel: input and output symbol tables do not match";
343 nstates_ = CountStates(fst_);
345 ComputeStateOrders();
347 NGRAMERROR() <<
"NGramModel: bad ngram model topology";
358 int num_ngrams = fst_.NumArcs(st);
374 if (ngram.empty())
return Weight::One();
376 StateId st = ngram.front() == 0 || unigram_ < 0 ? fst_.Start() : unigram_;
379 Weight cost = ngram.front() == 0 && unigram_ >= 0 ? fst_.Final(unigram_)
382 fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
384 for (
int n = 0; n < ngram.size(); ++n) {
385 Label label = ngram[n];
387 if (n == 0)
continue;
388 if (n != ngram.size() - 1) {
389 NGRAMERROR() <<
"end-of-string is not the super-final word";
390 return Weight::Zero();
392 while (fst_.Final(st) == Weight::Zero()) {
396 return Weight::Zero();
398 cost = Times(cost, bocost);
400 cost = Times(cost, fst_.Final(st));
403 matcher.SetState(st);
404 if (matcher.Find(label)) {
405 Arc arc = matcher.Value();
407 cost = Times(cost, arc.weight);
413 return Weight::Zero();
415 cost = Times(cost, bocost);
426 Weight cost = Arc::Weight::One();
427 while (fst_.Final(mst) == Arc::Weight::Zero()) {
428 fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
429 matcher.SetState(mst);
430 if (matcher.Find(backoff_label_)) {
431 for (; !matcher.Done(); matcher.Next()) {
432 Arc arc = matcher.Value();
433 if (arc.ilabel == backoff_label_) {
435 cost = Times(cost, arc.weight);
439 NGRAMERROR() <<
"NGramModel: No final cost in model: " << mst;
440 return Arc::Weight::Zero();
443 *order = state_orders_[mst];
445 cost = Times(cost, fst_.Final(mst));
452 const std::vector<Label> *
ngram = 0) {
453 if (have_state_ngrams_ && !
ngram) {
454 NGRAMERROR() <<
"NGramModel::UpdateState: no ngram provided";
458 if (state_orders_.size() < st) {
459 NGRAMERROR() <<
"NGramModel::UpdateState: bad state: " << st;
463 if (order > hi_order_) hi_order_ =
order;
465 if (state_orders_.size() == st) {
466 state_orders_.push_back(order);
470 state_orders_[st] =
order;
474 if (unigram_state) unigram_ = nstates_;
494 bocost = Arc::Weight::Zero();
503 double max = fst::LogArc::Weight::Zero().Value(), nextmax = max;
504 if (st < 0) st =
GetFst().Start();
505 for (fst::ArcIterator<fst::Fst<Arc>> aiter(
GetFst(), st);
506 !aiter.Done(); aiter.Next()) {
507 Arc arc = aiter.Value();
515 ScalarValue(arc.weight) > nextmax) {
519 if (nextmax == fst::LogArc::Weight::Zero().Value())
return exp(max);
524 bool Error()
const {
return error_; }
531 return ngram::NegLogDiff(a, b, &error_);
536 for (
int i = 0; i < nstates_; i++)
537 state_counts->push_back(ScalarValue(Arc::Weight::Zero()));
538 WalkStatesForCount(state_counts);
543 std::vector<double> *bo_arc_weight)
const {
544 fst::Matcher<fst::Fst<Arc>> matcher(
545 fst_, fst::MATCH_INPUT);
546 matcher.SetState(bo);
547 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
549 Arc arc = aiter.Value();
550 if (arc.ilabel == backoff_label_)
continue;
551 if (matcher.Find(arc.ilabel)) {
552 Arc barc = matcher.Value();
561 bo_arc_weight->push_back(
ScalarValue(barc.weight) +
562 FactorValue(arc.weight));
564 NGRAMERROR() <<
"NGramModel: lower order arc missing: " << st;
573 bool FindArc(fst::ArcIterator<fst::Fst<Arc>> *biter,
575 while (!biter->Done()) {
576 Arc barc = biter->Value();
577 if (barc.ilabel == label)
579 else if (barc.ilabel < label)
589 Weight cost = Arc::Weight::Zero();
590 fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
591 matcher.SetState(st);
592 if (matcher.Find(label)) {
593 Arc arc = matcher.Value();
601 double *cost)
const {
602 if (label < 0)
return false;
603 StateId currstate = *mst;
607 fst::Matcher<fst::Fst<Arc>> matcher(fst_, fst::MATCH_INPUT);
608 matcher.SetState(currstate);
609 if (matcher.Find(label)) {
610 Arc arc = matcher.Value();
611 *order = state_orders_[currstate];
612 *mst = arc.nextstate;
614 }
else if (matcher.Find(backoff_label_)) {
616 for (; !matcher.Done(); matcher.Next()) {
617 Arc arc = matcher.Value();
618 if (arc.ilabel == backoff_label_) {
619 currstate = arc.nextstate;
623 if (currstate < 0)
return false;
633 double *low_neglog_sum,
bool infinite_backoff =
false,
634 bool unigram =
false)
const {
636 if (bo < 0 && !unigram)
return false;
637 *low_neglog_sum = *hi_neglog_sum =
640 if (bo >= 0 && *hi_neglog_sum !=
ScalarValue(Arc::Weight::Zero()))
643 CalcArcNegLogSums(st, bo, hi_neglog_sum, low_neglog_sum, infinite_backoff);
649 ostrm <<
"state: " << st <<
" order: " << state_orders_[st] <<
" ngram: ";
650 for (
int i = 0; i < state_ngrams_[st].size(); ++i)
651 ostrm << state_ngrams_[st][i] <<
" ";
// Returns the external (printable) representation of weight `wt`, which is
// stored as a negative log: it is converted back to real space when the
// output is not in negative logs or when integral counts are requested
// (`intcnts`), and rounded to the nearest whole number for integral counts.
// The `return wt;` was reconstructed: the extraction dropped that line.
static double WeightRep(double wt, bool neglogs, bool intcnts) {
  if (!neglogs || intcnts) wt = exp(-wt);  // Back to real space.
  if (intcnts) wt = round(wt);             // Whole-number count.
  return wt;
}
665 bool CalcArcNegLogSums(StateId st, StateId bo,
double *hi_sum,
666 double *low_sum,
bool infinite_backoff =
false)
const {
668 double KahanVal1 = 0, KahanVal2 = 0;
669 double init_low = *low_sum;
670 fst::Matcher<fst::Fst<Arc>> matcher(
671 fst_, fst::MATCH_INPUT);
672 if (bo >= 0) matcher.SetState(bo);
673 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
675 Arc arc = aiter.Value();
676 if (arc.ilabel == backoff_label_)
continue;
677 if (bo < 0 || matcher.Find(arc.ilabel)) {
679 Arc barc = matcher.Value();
681 NegLogSum(*low_sum,
ScalarValue(barc.weight), &KahanVal2);
684 NegLogSum(*hi_sum,
ScalarValue(arc.weight), &KahanVal1);
686 NGRAMERROR() <<
"NGramModel: No arc label match in backoff state: " 691 if (bo >= 0 && infinite_backoff && *low_sum == 0.0)
693 if (bo >= 0 && *low_sum <= 0.0) {
694 VLOG(2) <<
"lower order sum less than zero: " << st <<
" " << *low_sum;
695 double start_low =
ScalarValue(Arc::Weight::Zero());
696 if (init_low == start_low) start_low =
ScalarValue(fst_.Final(bo));
697 *low_sum = CalcBruteLowSum(st, bo, start_low);
698 VLOG(2) <<
"new lower order sum: " << st <<
" " << *low_sum;
705 double CalcBruteLowSum(StateId st, StateId bo,
double start_low)
const {
706 double low_sum = start_low, KahanVal = 0;
707 fst::Matcher<fst::Fst<Arc>> matcher(
708 fst_, fst::MATCH_INPUT);
709 matcher.SetState(bo);
710 fst::ArcIterator<fst::Fst<Arc>> biter(fst_, bo);
712 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
714 Arc arc = aiter.Value();
715 if (arc.ilabel == backoff_label_)
continue;
716 barc = biter.Value();
717 while (!biter.Done() && barc.ilabel < arc.ilabel) {
718 if (barc.ilabel != backoff_label_)
720 NegLogSum(low_sum,
ScalarValue(barc.weight), &KahanVal);
722 if (!biter.Done()) barc = biter.Value();
724 if (!biter.Done() && barc.ilabel == arc.ilabel) {
726 if (!biter.Done()) barc = biter.Value();
728 if (biter.Done())
break;
730 while (!biter.Done()) {
731 if (barc.ilabel != backoff_label_)
733 NegLogSum(low_sum,
ScalarValue(barc.weight), &KahanVal);
735 if (!biter.Done()) barc = biter.Value();
741 void ComputeStateOrders() {
742 state_orders_.clear();
743 state_orders_.resize(nstates_, -1);
745 if (have_state_ngrams_) {
746 state_ngrams_.clear();
747 state_ngrams_.resize(nstates_);
751 std::deque<StateId> state_queue;
752 if (unigram_ != fst::kNoStateId) {
753 state_orders_[unigram_] = 1;
754 state_queue.push_back(unigram_);
755 state_orders_[fst_.Start()] = hi_order_ = 2;
756 state_queue.push_back(fst_.Start());
757 if (have_state_ngrams_)
758 state_ngrams_[fst_.Start()].push_back(0);
760 state_orders_[fst_.Start()] = 1;
761 state_queue.push_back(fst_.Start());
764 while (!state_queue.empty()) {
765 StateId state = state_queue.front();
766 state_queue.pop_front();
767 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, state);
768 !aiter.Done(); aiter.Next()) {
769 const Arc &arc = aiter.Value();
770 if (state_orders_[arc.nextstate] == -1) {
771 state_orders_[arc.nextstate] = state_orders_[state] + 1;
772 if (have_state_ngrams_) {
773 state_ngrams_[arc.nextstate] = state_ngrams_[state];
774 state_ngrams_[arc.nextstate].push_back(arc.ilabel);
776 if (state_orders_[state] >= hi_order_)
777 hi_order_ = state_orders_[state] + 1;
778 state_queue.push_back(arc.nextstate);
785 bool CheckTopologyState(StateId st)
const {
786 if (unigram_ == -1) {
787 if (fst_.Final(fst_.Start()) == Arc::Weight::Zero()) {
788 VLOG(1) <<
"CheckTopology: bad final weight for start state";
796 fst::Matcher<fst::Fst<Arc>> matcher(
797 fst_, fst::MATCH_INPUT);
799 if (st == unigram_) {
800 if (fst_.Final(unigram_) == Arc::Weight::Zero()) {
801 VLOG(1) <<
"CheckTopology: bad final weight for unigram state: " 804 }
else if (have_state_ngrams_ && !state_ngrams_[unigram_].empty()) {
805 VLOG(1) <<
"CheckTopology: bad unigram state: " << unigram_;
810 VLOG(1) <<
"CheckTopology: no backoff state: " << st;
814 if (fst_.Final(st) != Arc::Weight::Zero() &&
815 fst_.Final(bos) == Arc::Weight::Zero()) {
816 VLOG(1) <<
"CheckTopology: bad final weight for backoff state: " << st;
821 VLOG(1) <<
"CheckTopology: bad backoff arc from: " << st
822 <<
" with order: " <<
StateOrder(st) <<
" to state: " << bos
826 matcher.SetState(bos);
829 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
831 Arc arc = aiter.Value();
835 if (have_state_ngrams_ && !CheckStateNGrams(st, arc)) {
836 VLOG(1) <<
"CheckTopology: inconsistent n-gram states: " << st <<
" -- " 837 << arc.ilabel <<
"/" << arc.weight <<
" -> " << arc.nextstate;
841 if (st != unigram_) {
842 if (arc.ilabel == backoff_label_)
continue;
843 if (!matcher.Find(arc.ilabel)) {
844 VLOG(1) <<
"CheckTopology: unmatched arc at backoff state: " 845 << arc.ilabel <<
"/" << arc.weight <<
" for state: " << st;
854 bool CheckStateNGrams(StateId st,
const Arc &arc)
const {
855 std::vector<Label> state_ngram;
856 bool boa = arc.ilabel == backoff_label_;
858 int j = state_orders_[st] - state_orders_[arc.nextstate] + (boa ? 0 : 1);
859 if (j < 0)
return false;
861 for (
int i = j; i < state_ngrams_[st].size(); ++i)
862 state_ngram.push_back(state_ngrams_[st][i]);
863 if (!boa && j <= state_ngrams_[st].size())
864 state_ngram.push_back(arc.ilabel);
866 return state_ngram == state_ngrams_[arc.nextstate];
871 bool CheckNormalizationState(StateId st)
const {
873 Weight bocost = Weight::NoWeight();
877 if (bo >= 0 && Norm !=
ScalarValue(Arc::Weight::Zero()))
879 if (!CalcArcNegLogSums(st, bo, &Norm, &Norm1,
883 return EvaluateNormalization(st, bo,
ScalarValue(bocost), Norm, Norm1);
887 bool EvaluateNormalization(StateId st, StateId bo,
double bocost,
double norm,
888 double norm1)
const {
889 double newnorm = norm;
891 newnorm = NegLogSum(norm, bocost);
892 if (newnorm < norm1 + bocost)
893 newnorm =
NegLogDiff(newnorm, norm1 + bocost);
895 newnorm =
NegLogDiff(norm1 + bocost, newnorm);
898 if (fabs(newnorm) > norm_eps_ &&
899 (bo < 0 || !ReevaluateNormalization(st, bocost, norm, norm1))) {
900 VLOG(2) <<
"State ID: " << st <<
"; " << fst_.NumArcs(st) <<
" arcs;" 901 <<
" -log(sum(P)) = " << newnorm <<
", should be 0";
902 VLOG(2) << norm <<
" " << norm1;
911 bool ReevaluateNormalization(StateId st,
double bocost,
double norm,
912 double norm1)
const {
915 VLOG(2) <<
"Required re-evaluation of normalization: state " << st <<
" " 916 << norm <<
" " << norm1 <<
" " << newalpha <<
" " << norm_eps_;
917 if (fabs(newalpha - bocost) > norm_eps_)
return false;
922 void CollectPrefixCounts(std::vector<double> *state_counts,
924 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
926 Arc arc = aiter.Value();
927 if (arc.ilabel != backoff_label_ &&
928 state_orders_[st] < state_orders_[arc.nextstate]) {
929 (*state_counts)[arc.nextstate] =
ScalarValue(arc.weight);
930 CollectPrefixCounts(state_counts, arc.nextstate);
936 void WalkStatesForCount(std::vector<double> *state_counts)
const {
937 if (unigram_ != -1) {
938 (*state_counts)[fst_.Start()] =
ScalarValue(fst_.Final(unigram_));
939 CollectPrefixCounts(state_counts, unigram_);
941 CollectPrefixCounts(state_counts, fst_.Start());
947 bool MixtureConsistent()
const {
948 fst::Matcher<fst::Fst<Arc>> matcher(
949 fst_, fst::MATCH_INPUT);
950 for (StateId st = 0; st < nstates_; ++st) {
956 matcher.SetState(bo);
957 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st);
958 !aiter.Done(); aiter.Next()) {
959 Arc arc = aiter.Value();
960 if (arc.ilabel == backoff_label_) {
963 if (matcher.Find(arc.ilabel)) {
964 Arc barc = matcher.Value();
966 ScalarValue(barc.weight) + ScalarValue(bocost)) {
970 NGRAMERROR() <<
"NGramModel: lower order arc missing: " << st;
986 void NGramStateProb(StateId st, std::vector<double> *probs)
const {
987 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
989 Arc arc = aiter.Value();
990 if (arc.ilabel == backoff_label_)
continue;
991 if (state_orders_[arc.nextstate] > state_orders_[st]) {
992 (*probs)[arc.nextstate] = (*probs)[st] * exp(-
ScalarValue(arc.weight));
993 NGramStateProb(arc.nextstate, probs);
1001 void NGramStateProbs(std::vector<double> *probs,
bool norm =
false)
const {
1003 probs->resize(nstates_, 0.0);
1006 (*probs)[fst_.Start()] = 1.0;
1009 (*probs)[unigram_] = 1.0;
1010 NGramStateProb(unigram_, probs);
1012 (*probs)[fst_.Start()] = exp(-
ScalarValue(fst_.Final(unigram_)));
1014 NGramStateProb(fst_.Start(), probs);
1018 for (
size_t st = 0; st < probs->size(); ++st) sum += (*probs)[st];
1019 for (
size_t st = 0; st < probs->size(); ++st) (*probs)[st] /= sum;
1027 void StationaryStateProb(StateId st, std::vector<double> *init_probs,
1028 std::vector<double> *probs,
double alpha)
const {
1029 fst::Matcher<fst::Fst<Arc>> matcher(
1030 fst_, fst::MATCH_INPUT);
1035 matcher.SetState(bo);
1036 (*init_probs)[bo] += (*init_probs)[st] * exp(-
ScalarValue(bocost));
1039 for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst_, st); !aiter.Done();
1041 Arc arc = aiter.Value();
1042 if (arc.ilabel == backoff_label_)
continue;
1043 (*probs)[arc.nextstate] +=
1044 (*init_probs)[st] * exp(-
ScalarValue(arc.weight));
1045 if (bo != -1 && matcher.Find(arc.ilabel)) {
1047 const Arc &barc = matcher.Value();
1048 (*probs)[barc.nextstate] -=
1050 exp(-
ScalarValue(barc.weight) - ScalarValue(bocost));
1055 (*probs)[fst_.Start()] +=
1056 (*init_probs)[st] * exp(-
ScalarValue(fst_.Final(st))) * alpha;
1059 (*probs)[fst_.Start()] -=
1071 bool StationaryStateProbs(std::vector<double> *probs,
double alpha,
1072 double converge_eps,
size_t maxiters)
const {
1073 std::vector<double> init_probs, last_probs;
1075 NGramStateProbs(&init_probs,
true);
1076 last_probs = init_probs;
1082 probs->resize(nstates_, 0.0);
1084 for (
size_t st = 0; st < nstates_; ++st) {
1085 if (state_orders_[st] ==
order)
1086 StationaryStateProb(st, &init_probs, probs, alpha);
1091 for (
size_t st = 0; st < nstates_; ++st) {
1092 if (fabs((*probs)[st] - last_probs[st]) > converge_eps * last_probs[st])
1094 last_probs[st] = init_probs[st] = (*probs)[st];
1096 VLOG(2) <<
"NGramModel::StationaryStateProbs: state probs changed: " 1098 if (++iters > maxiters)
return false;
1099 }
while (changed > 0);
1103 const fst::Fst<Arc> &fst_;
1105 Label backoff_label_;
1109 std::vector<int> state_orders_;
1110 bool have_state_ngrams_;
1111 mutable size_t ascending_ngrams_;
1112 std::vector<std::vector<Label>>
1114 const std::vector<Label> empty_label_vector_;
1115 mutable bool error_;
1121 template <
typename T>
1129 return w.Value(0).Value();
1132 template <
typename Arc>
1134 return Arc::Weight::One();
1139 std::array<fst::StdArc::Weight, kHistogramBins> weights;
1140 weights.fill(fst::StdArc::Weight::Zero());
1143 static const fst::PowerWeight<fst::StdArc::Weight, kHistogramBins>
1144 one(weights.begin(), weights.end());
1148 template <
typename T>
1156 return w.Value(1).Value();
1161 #endif // NGRAM_NGRAM_MODEL_H_ static double WeightRep(double wt, bool neglogs, bool intcnts)
static double FactorValue(Weight w)
bool CalculateBackoffFactors(double hi_neglog_sum, double low_neglog_sum, double *nlog_backoff_num, double *nlog_backoff_denom, bool infinite_backoff=false) const
StateId GetBackoff(StateId st, Weight *bocost) const
bool CalculateStateProbs(std::vector< double > *probs, bool stationary=false, size_t maxiters=10000) const
constexpr size_t kHistogramBins
Weight GetBackoffCost(StateId st) const
Weight FindArcWeight(StateId st, Label label) const
bool CheckNormalization() const
Weight FinalCostInModel(StateId mst, int *order) const
int NumNGrams(StateId st)
double CalculateBackoffCost(double hi_neglog_sum, double low_neglog_sum, bool infinite_backoff=false) const
bool CalcBONegLogSums(StateId st, double *hi_neglog_sum, double *low_neglog_sum, bool infinite_backoff=false, bool unigram=false) const
StateId NumStates() const
void UpdateState(StateId st, int order, bool unigram_state, const std::vector< Label > *ngram=0)
< epsilon >< epsilon > Infinity a a
NGramModel(const fst::Fst< Arc > &infst, Label backoff_label)
bool FindArc(fst::ArcIterator< fst::Fst< Arc >> *biter, Label label) const
static double ScalarValue(Weight w)
< epsilon >< epsilon > Infinity a Infinity b b
virtual ~NGramModel()=default
double NegLogDiff(double a, double b) const
bool FillBackoffArcWeights(StateId st, StateId bo, std::vector< double > *bo_arc_weight) const
static Weight UnitCount()
bool FindNGramInModel(StateId *mst, int *order, Label label, double *cost) const
double EstimateTotalUnigramCount() const
Weight GetNGramCost(const std::vector< Label > &ngram) const
bool PrintStateNGram(StateId st, std::ostream &ostrm=std::cerr) const
Label BackoffLabel() const
bool CheckTopology() const
const std::vector< Label > & StateNGram(StateId state) const
Weight GetFinalWeight(StateId st) const
StateId UnigramState() const
double GetSymbolUnigramCost(Label symbol) const
int StateOrder(StateId state) const
const fst::Fst< Arc > & GetFst() const
NGramModel(const fst::Fst< Arc > &infst, Label backoff_label, double norm_eps, bool state_ngrams)
StateId NGramState(const std::vector< Label > &ngram) const
NGramModel(const fst::Fst< Arc > &infst)
void FillStateCounts(std::vector< double > *state_counts)