NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-bayes-model-merge.h
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Implements 'full' Bayes model merging.
17 
18 #ifndef NGRAM_NGRAM_BAYES_MODEL_MERGE_H_
19 #define NGRAM_NGRAM_BAYES_MODEL_MERGE_H_
20 
21 #include <cmath>
22 #include <vector>
23 
24 #include <fst/arc.h>
25 #include <fst/fst.h>
26 #include <fst/mutable-fst.h>
27 #include <ngram/ngram-merge.h>
28 #include <ngram/ngram-model.h>
29 
30 namespace ngram {
31 
32 class NGramBayesModelMerge : public NGramMerge<fst::StdArc> {
33  public:
34  typedef fst::StdArc::StateId StateId;
35  typedef fst::StdArc::Label Label;
36 
37  // Constructs an NGramBayesModelMerge object consisting of ngram model
38  // to be merged.
39  // Ownership of FST is retained by the caller.
40  explicit NGramBayesModelMerge(fst::StdMutableFst *infst1,
41  Label backoff_label = 0,
42  double norm_eps = kNormEps)
43  : NGramMerge<fst::StdArc>(infst1, backoff_label, norm_eps, true) {
44  if (!CheckNormalization()) {
45  NGRAMERROR() << "NGramBayesModelMerge: Model 1 must be normalized to"
46  << " use smoothing in merging";
48  }
49  }
50 
51  // Perform smooth-model merge with n-gram model specified by the FST argument
52  // and mixing weights alpha and beta. Resultant model will be normalized.
53  void MergeNGramModels(const fst::StdFst &infst2, double alpha,
54  double beta) {
55  NGramModel<fst::StdArc> mod2(infst2);
56  if (!mod2.CheckNormalization()) {
57  NGRAMERROR() << "NGramBayesModelMerge: Model 2 must be normalized to"
58  << " use smoothing in merging";
60  return;
61  }
62  alpha_ = -log(alpha);
63  beta_ = -log(beta);
64  state_alpha_.clear();
65  if (!NGramMerge<fst::StdArc>::MergeNGramModels(infst2, true)) {
66  NGRAMERROR() << "NGramBayesModelMerge: Model merging failed";
68  }
69  }
70 
71  protected:
72  // Specifies resultant weight when combining a weight from each FST
73  fst::StdArc::Weight MergeWeights(StateId s1, StateId s2, Label label,
74  fst::StdArc::Weight w1,
75  fst::StdArc::Weight w2, bool in_fst1,
76  bool in_fst2) const override {
77  if (label == BackoffLabel()) { // don't modify (needed) backoff weights
78  return in_fst1 ? w1.Value() : w2.Value();
79  } else {
80  StateId st = in_fst1 ? s1 : ExactMap2To1(s2);
81  double alpha = StateAlpha(st);
82  double beta = NegLogDiff(0.0, alpha);
83  return NegLogSum(w1.Value() + alpha, w2.Value() + beta);
84  }
85  }
86 
87  private:
88  // normalized state weight to scale model ngram1
89  double StateAlpha(StateId st) const {
90  while (st >= state_alpha_.size()) state_alpha_.push_back(-1.0);
91  if (state_alpha_[st] < 0.0) {
92  const std::vector<Label> &ngram = StateNGram(st);
93 
94  // -log p(h|k), k=1,2
95  double w1 = ScalarValue(GetNGramCost(ngram));
96  double w2 = ScalarValue(NGram2().GetNGramCost(ngram));
97 
98  // p(k|h) = p(h|k) p(k) / sum_k' p(h|k') p(k')
99  state_alpha_[st] = w1 + alpha_;
100 
101  // Only normalize non-infinite cost (to avoid potential NaN issues).
102  // If state_alpha_[st] = inf (i.e., p = 0), then normalized is also inf.
103  if (state_alpha_[st] < fst::StdArc::Weight::Zero().Value())
104  state_alpha_[st] -= NegLogSum(w1 + alpha_, w2 + beta_);
105  }
106  return state_alpha_[st];
107  }
108 
109  double alpha_; // global weight to scale model ngram1
110  double beta_; // global weight to scale model ngram2
111 
112  // stored normalized state weight to scale model ngram1
113  mutable std::vector<double> state_alpha_;
114 };
115 
116 } // namespace ngram
117 
118 #endif // NGRAM_NGRAM_BAYES_MODEL_MERGE_H_
NGramBayesModelMerge(fst::StdMutableFst *infst1, Label backoff_label=0, double norm_eps=kNormEps)
const NGramModel< fst::StdArc > & NGram2() const
Definition: ngram-merge.h:143
fst::StdArc::Weight MergeWeights(StateId s1, StateId s2, Label label, fst::StdArc::Weight w1, fst::StdArc::Weight w2, bool in_fst1, bool in_fst2) const override
StateId ExactMap2To1(StateId s2) const
Definition: ngram-merge.h:154
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta)
static double ScalarValue(Weight w)
double NegLogDiff(double a, double b) const
Definition: ngram-model.h:530
Weight GetNGramCost(const std::vector< Label > &ngram) const
Definition: ngram-model.h:373
const double kNormEps
Definition: ngram-model.h:37
const std::vector< Label > & StateNGram(StateId state) const
Definition: ngram-model.h:176
#define NGRAMERROR()
Definition: util.h:26