NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-hist-merge.h
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // NGram model class for merging histogram FSTs.
17 
18 #ifndef NGRAM_NGRAM_HIST_MERGE_H_
19 #define NGRAM_NGRAM_HIST_MERGE_H_
20 
21 #include <array>
22 #include <cmath>
23 
24 #include <fst/mutable-fst.h>
25 #include <ngram/hist-arc.h>
26 #include <ngram/ngram-merge.h>
27 #include <ngram/ngram-model.h>
28 #include <ngram/util.h>
29 
30 namespace ngram {
31 
32 class NGramHistMerge : public NGramMerge<HistogramArc> {
33  public:
34  typedef HistogramArc::StateId StateId;
35  typedef HistogramArc::Label Label;
36 
37  // Constructs an NGramCountMerge object consisting of ngram model
38  // to be merged.
39  // Ownership of FST is retained by the caller.
40  explicit NGramHistMerge(fst::MutableFst<HistogramArc> *infst1,
41  Label backoff_label = 0, double norm_eps = kNormEps,
42  bool check_consistency = false)
43  : NGramMerge(infst1, backoff_label, norm_eps, check_consistency) {}
44 
45  // Perform count-model merger with n-gram model specified by the FST argument
46  // and mixing weights alpha and beta.
47  void MergeNGramModels(const fst::Fst<HistogramArc> &infst2, double alpha,
48  double beta, bool norm = false) {
49  alpha_ = -log(alpha);
50  beta_ = -log(beta);
51  if (!NGramMerge::MergeNGramModels(infst2, norm)) {
52  NGRAMERROR() << "Histogram count merging failed";
54  }
55  }
56 
57  protected:
58  // Specifies resultant weight when combining a weight from each FST.
59  Weight MergeWeights(StateId s1, StateId s2, Label Label, Weight w1, Weight w2,
60  bool in_fst1, bool in_fst2) const override {
61  if (in_fst1 && in_fst2) {
62  return NGramHistMerge::WeightSum(w1, w2);
63  } else if (in_fst1) {
64  return w1;
65  } else {
66  return w2;
67  }
68  }
69 
70  // TODO(vitalyk): this does nothing!
71  // Specifies the normalization constant per state 'st' depending whether
72  // state was present in one or boths FSTs.
73  double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override {
74  if (in_fst1 && in_fst2) {
75  return -NegLogSum(alpha_, beta_);
76  } else if (in_fst1) {
77  return -alpha_;
78  } else {
79  return -beta_;
80  }
81  }
82 
83  // Specifies if unshared arcs/final weights between the two
84  // FSTs in a merge have a non-trivial merge. In particular, this
85  // means MergeWeights() changes the arc or final weights; any
86  // destination state changes are not relevant here. When false, more
87  // efficient merging may be performed. If the arc/final_weight
88  // comes from the first FST, then 'in_fst1' is true.
89  bool MergeUnshared(bool in_fst1) const override {
90  return (in_fst1) ? (alpha_ != 0.0) : (beta_ != 0.0);
91  }
92 
93  private:
94  // Add together two weights using addition for histogram weights.
95  // Histogram weight is a tuple where first coordinate corresponds
96  // to expected count and the rest K+1 coordinates indicate
97  // the probability of observing index-1 occurrences of the n-gram
98  // (associated with this weight).
99  Weight WeightSum(Weight w1, Weight w2) const {
100  std::array<fst::TropicalWeight, kHistogramBins> v;
101  v.fill(fst::TropicalWeight::Zero());
102  v[0] = NegLogSum(w1.Value(0).Value(), w2.Value(0).Value());
103 
104  for (int k = 0; k < v.size() - 1; k++) {
105  for (int j = 0; j <= k; j++) {
106  v[k + 1] = NegLogSum(v[k + 1].Value(), w1.Value(j + 1).Value() +
107  w2.Value(k - j + 1).Value());
108  }
109  }
110  return fst::PowerWeight<fst::TropicalWeight, kHistogramBins>(
111  v.begin(), v.end());
112  }
113 
114  double alpha_; // weight to scale model ngram1
115  double beta_; // weight to scale model ngram2
116 };
117 
118 } // namespace ngram
119 
120 #endif // NGRAM_NGRAM_HIST_MERGE_H_
bool MergeUnshared(bool in_fst1) const override
HistogramArc::Weight Weight
Definition: ngram-merge.h:41
NGramHistMerge(fst::MutableFst< HistogramArc > *infst1, Label backoff_label=0, double norm_eps=kNormEps, bool check_consistency=false)
bool MergeNGramModels(const fst::Fst< Arc > &infst2, bool norm=false, int max_order=-1)
Definition: ngram-merge.h:89
HistogramArc::Label Label
void MergeNGramModels(const fst::Fst< HistogramArc > &infst2, double alpha, double beta, bool norm=false)
Weight MergeWeights(StateId s1, StateId s2, Label Label, Weight w1, Weight w2, bool in_fst1, bool in_fst2) const override
HistogramArc::StateId StateId
const double kNormEps
Definition: ngram-model.h:37
#define NGRAMERROR()
Definition: util.h:26
double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override