NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-count-merge.h
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // NGram model class for merging count FSTs.
17 
18 #ifndef NGRAM_NGRAM_COUNT_MERGE_H_
19 #define NGRAM_NGRAM_COUNT_MERGE_H_
20 
#include <cmath>

#include <fst/arc.h>
#include <ngram/ngram-merge.h>
#include <ngram/ngram-model.h>
#include <ngram/util.h>
25 
26 namespace ngram {
27 
28 class NGramCountMerge : public NGramMerge<fst::StdArc> {
29  public:
30  typedef fst::StdArc::StateId StateId;
31  typedef fst::StdArc::Label Label;
32 
33  // Constructs an NGramCountMerge object consisting of ngram model
34  // to be merged.
35  // Ownership of FST is retained by the caller.
36  explicit NGramCountMerge(fst::StdMutableFst *infst1,
37  Label backoff_label = 0, double norm_eps = kNormEps,
38  bool check_consistency = false)
39  : NGramMerge(infst1, backoff_label, norm_eps, check_consistency) {}
40 
41  // Perform count-model merger with n-gram model specified by the FST argument
42  // and mixing weights alpha and beta.
43  void MergeNGramModels(const fst::StdFst &infst2, double alpha,
44  double beta, bool norm = false) {
45  alpha_ = -log(alpha);
46  beta_ = -log(beta);
47  if (!NGramMerge<fst::StdArc>::MergeNGramModels(infst2, norm)) {
48  NGRAMERROR() << "Count merging failed";
50  }
51  }
52 
53  protected:
54  // Specifies resultant weight when combining a weight from each FST.
55  Weight MergeWeights(StateId s1, StateId s2, Label Label, Weight w1, Weight w2,
56  bool in_fst1, bool in_fst2) const override {
57  if (in_fst1 && in_fst2) {
58  return NegLogSum(w1.Value() + alpha_, w2.Value() + beta_);
59  } else if (in_fst1) {
60  return w1.Value() + alpha_;
61  } else {
62  return w2.Value() + beta_;
63  }
64  }
65 
66  // Specifies the normalization constant per state 'st' depending whether
67  // state was present in one or boths FSTs.
68  double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override {
69  if (in_fst1 && in_fst2) {
70  return -NegLogSum(alpha_, beta_);
71  } else if (in_fst1) {
72  return -alpha_;
73  } else {
74  return -beta_;
75  }
76  }
77 
78  // Specifies if unshared arcs/final weights between the two
79  // FSTs in a merge have a non-trivial merge. In particular, this
80  // means MergeWeights() changes the arc or final weights; any
81  // destination state changes are not relevant here. When false, more
82  // efficient merging may be performed. If the arc/final_weight
83  // comes from the first FST, then 'in_fst1' is true.
84  bool MergeUnshared(bool in_fst1) const override {
85  return (in_fst1) ? (alpha_ != 0.0) : (beta_ != 0.0);
86  }
87 
88  private:
89  double alpha_; // weight to scale model ngram1
90  double beta_; // weight to scale model ngram2
91 };
92 
93 } // namespace ngram
94 
95 #endif // NGRAM_NGRAM_COUNT_MERGE_H_
fst::StdArc::Weight Weight
Definition: ngram-merge.h:41
fst::StdArc::StateId StateId
NGramCountMerge(fst::StdMutableFst *infst1, Label backoff_label=0, double norm_eps=kNormEps, bool check_consistency=false)
double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
const double kNormEps
Definition: ngram-model.h:37
Weight MergeWeights(StateId s1, StateId s2, Label Label, Weight w1, Weight w2, bool in_fst1, bool in_fst2) const override
#define NGRAMERROR()
Definition: util.h:26
bool MergeUnshared(bool in_fst1) const override
fst::StdArc::Label Label