NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-model-merge.h
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // NGram model class for merging smoothed model FSTs.
17 
18 #ifndef NGRAM_NGRAM_MODEL_MERGE_H_
19 #define NGRAM_NGRAM_MODEL_MERGE_H_
20 
#include <cmath>

#include <fst/arc.h>
#include <ngram/ngram-merge.h>
#include <ngram/ngram-model.h>
#include <ngram/util.h>
25 
26 namespace ngram {
27 
28 class NGramModelMerge : public NGramMerge<fst::StdArc> {
29  public:
30  typedef fst::StdArc::StateId StateId;
31  typedef fst::StdArc::Label Label;
32 
33  // Constructs an NGramModelMerge object consisting of ngram model
34  // to be merged.
35  // Ownership of FST is retained by the caller.
36  explicit NGramModelMerge(fst::StdMutableFst *infst1,
37  Label backoff_label = 0, double norm_eps = kNormEps,
38  bool check_consistency = false)
39  : NGramMerge(infst1, backoff_label, norm_eps, check_consistency),
40  merge_norm_(true) {
41  if (!CheckNormalization()) {
42  NGRAMERROR() << "NGramModelMerge: Model 1 must be normalized to"
43  << " use smoothing in merging";
45  }
46  }
47 
48  // Perform smooth-model merge with n-gram model specified by the FST argument
49  // and mixing weights alpha and beta.
50  void MergeNGramModels(const fst::StdFst &infst2, double alpha,
51  double beta, bool norm = false) {
52  if (Error()) return;
53  NGramModel<fst::StdArc> mod2(infst2);
54  if (!mod2.CheckNormalization()) {
55  NGRAMERROR() << "NGramModelMerge: Model 2 must be normalized to"
56  << " use smoothing in merging";
58  return;
59  }
60  alpha_ = -log(alpha);
61  beta_ = -log(beta);
62  if (!NGramMerge<fst::StdArc>::MergeNGramModels(infst2, norm)) {
63  NGRAMERROR() << "NGramModelMerge: Model merging failed";
65  return;
66  }
67  if (!norm) merge_norm_ = false;
68  }
69 
70  protected:
71  // Specifies resultant weight when combining a weight from each FST
72  Weight MergeWeights(StateId s1, StateId s2, Label label, Weight w1, Weight w2,
73  bool in_fst1, bool in_fst2) const override {
74  if (label == BackoffLabel()) { // don't modify (needed) backoff weights
75  return in_fst1 ? w1.Value() : w2.Value();
76  } else {
77  return NegLogSum(w1.Value() + alpha_, w2.Value() + beta_);
78  }
79  }
80 
81  // Specifies normalization constant per state 'st' depending whether
82  // state was present in one or boths FSTs.
83  double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override {
84  return -NegLogSum(alpha_, beta_);
85  }
86 
87  private:
88  double alpha_; // weight to scale model ngram1
89  double beta_; // weight to scale model ngram2
90  bool merge_norm_; // is the (possibly intermediate) result normalized?
91 };
92 
93 } // namespace ngram
94 
95 #endif // NGRAM_NGRAM_MODEL_MERGE_H_
NGramModelMerge(fst::StdMutableFst *infst1, Label backoff_label=0, double norm_eps=kNormEps, bool check_consistency=false)
fst::StdArc::Weight Weight
Definition: ngram-merge.h:41
double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
const double kNormEps
Definition: ngram-model.h:37
fst::StdArc::Label Label
#define NGRAMERROR()
Definition: util.h:26
fst::StdArc::StateId StateId
Weight MergeWeights(StateId s1, StateId s2, Label label, Weight w1, Weight w2, bool in_fst1, bool in_fst2) const override