NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-replace-merge.h
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // NGram model class for merging FSTs via replacing weights.
17 
18 #ifndef NGRAM_NGRAM_REPLACE_MERGE_H_
19 #define NGRAM_NGRAM_REPLACE_MERGE_H_
20 
21 #include <vector>
22 
23 #include <fst/arc.h>
24 #include <fst/connect.h>
25 #include <ngram/ngram-merge.h>
26 #include <ngram/ngram-model.h>
27 #include <ngram/util.h>
28 
29 namespace ngram {
30 
31 class NGramReplaceMerge : public NGramMerge<fst::StdArc> {
32  public:
33  typedef fst::StdArc::StateId StateId;
34  typedef fst::StdArc::Label Label;
35 
36  // Constructs an NGramReplaceMerge object consisting of ngram model
37  // to be merged. Since normalization in this case is only handled by
38  // recalculating the backoff weights, the general merging mechanism is not
39  // asked to normalize. Ownership of FST is retained by the caller.
40  explicit NGramReplaceMerge(fst::StdMutableFst *infst1,
41  Label backoff_label = 0,
42  double norm_eps = kNormEps,
43  bool check_consistency = false)
44  : NGramMerge(infst1, backoff_label, norm_eps, check_consistency) {}
45 
46  // Performs replacement merge with n-gram model specified by the FST argument
47  // and a maximum order. For all orders up to the max_replace_order, it is
48  // assumed that infst2 has a superset of n-grams contained in the first model.
49  // Resulting model will have up to and including max_replace_order orders of
50  // the model from infst2, and any orders above that max from infst1.
51  void MergeNGramModels(const fst::StdFst &infst2,
52  int max_replace_order = -1, bool norm = false) {
53  if (Error()) return;
54  NGramModel<fst::StdArc> mod2(infst2);
56  infst2, /* norm = */ false, max_replace_order)) {
57  NGRAMERROR() << "NGramReplaceMerge: Model merging failed";
59  return;
60  }
61  TrimTopology();
62  if (norm) {
63  RecalcBackoff();
64  if (!CheckNormalization()) {
65  NGRAMERROR() << "NGramReplaceMerge: Merged model not fully normalized";
67  return;
68  }
69  }
70  }
71 
72  protected:
73  // Specifies resultant weight when combining a weight from each FST
74  Weight MergeWeights(StateId s1, StateId s2, Label label, Weight w1, Weight w2,
75  bool in_fst1, bool in_fst2) const override {
76  if (!in_fst2) {
77  NGRAMERROR() << "n-grams in original model must be a subset of n-grams "
78  "in the updating model.";
79  }
80  return w2;
81  }
82 
83  // Specifies normalization constant per state 'st' depending whether
84  // state was present in one or boths FSTs.
85  double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override {
86  return 0.0;
87  }
88 
89  // Returns true to allow for max orders less than the model's high order.
90  bool MaxOrderOkay(int order) const override { return true; }
91 
92  private:
93  // Discards extra states from orders higher than the maximum merged order,
94  // and ensures that preserved arcs point to correct preserved states.
95  void TrimTopology() {
96  std::vector<StateId> dest_states(NumStates());
97  for (StateId st = 0; st < NumStates(); ++st) {
98  // Destinations states are identity unless state is a dead end. Sets
99  // destination state to -1 for dead end states, to be adjusted later.
100  dest_states[st] = GetFst().NumArcs(st) == 0 &&
101  ScalarValue(GetFst().Final(st)) ==
102  ScalarValue(fst::StdArc::Weight::Zero())
103  ? -1
104  : st;
105  }
106  for (StateId st = 0; st < NumStates(); ++st) {
107  // Ignores dead end states when adjusting arcs.
108  if (dest_states[st] != st) continue;
109  StateId bo = GetBackoff(st, nullptr);
110  for (fst::MutableArcIterator<fst::MutableFst<fst::StdArc>>
111  aiter(GetMutableFst(), st);
112  !aiter.Done(); aiter.Next()) {
113  fst::StdArc arc = aiter.Value();
114  if (dest_states[arc.nextstate] != arc.nextstate) {
115  // If the arc is pointing to a dead end state, change destination.
116  if (dest_states[arc.nextstate] < 0) {
117  // If the destination has not been set yet, find and store it.
118  UpdateDestStates(bo, arc, &dest_states);
119  }
120  if (dest_states[arc.nextstate] < 0) {
121  NGRAMERROR() << "Destination state not set.";
123  return;
124  }
125  arc.nextstate = dest_states[arc.nextstate];
126  aiter.SetValue(arc);
127  }
128  }
129  }
130 
131  // Discards all dead-end states and re-initializes model information.
132  Connect(GetMutableFst());
133  InitModel();
134  }
135 
136  // Finds correct destination for arcs pointing to dead end states.
137  void UpdateDestStates(StateId st, const fst::StdArc &in_arc,
138  std::vector<StateId> *dest_states) {
139  if ((*dest_states)[in_arc.nextstate] >= 0) {
140  NGRAMERROR() << "Destination state already set.";
142  return;
143  }
144  if (st < 0) {
145  // No backoff state found, should point arc to unigram state.
146  (*dest_states)[in_arc.nextstate] = UnigramState();
147  } else {
148  StateId bo = GetBackoff(st, nullptr);
149  fst::Matcher<fst::StdFst> matcher(GetFst(), fst::MATCH_INPUT);
150  matcher.SetState(st);
151  if (!matcher.Find(in_arc.ilabel)) {
152  NGRAMERROR() << "Could not find n-gram arc at backoff state.";
154  return;
155  }
156  fst::StdArc arc = matcher.Value();
157  if ((*dest_states)[arc.nextstate] < 0) {
158  // This arc also needs a new destination state.
159  UpdateDestStates(bo, arc, dest_states);
160  }
161  if ((*dest_states)[arc.nextstate] < 0) {
162  NGRAMERROR() << "destination state not set.";
164  return;
165  }
166  if (arc.nextstate != in_arc.nextstate) {
167  // Takes destination from valid destination state of backoff n-gram.
168  (*dest_states)[in_arc.nextstate] = (*dest_states)[arc.nextstate];
169  }
170  }
171  }
172 };
173 
174 } // namespace ngram
175 
176 #endif // NGRAM_NGRAM_REPLACE_MERGE_H_
fst::MutableFst< fst::StdArc > * GetMutableFst()
StateId GetBackoff(StateId st, Weight *bocost) const
Definition: ngram-model.h:205
NGramReplaceMerge(fst::StdMutableFst *infst1, Label backoff_label=0, double norm_eps=kNormEps, bool check_consistency=false)
fst::StdArc::Weight Weight
Definition: ngram-merge.h:41
fst::StdArc::StateId StateId
bool MaxOrderOkay(int order) const override
static double ScalarValue(Weight w)
Weight MergeWeights(StateId s1, StateId s2, Label label, Weight w1, Weight w2, bool in_fst1, bool in_fst2) const override
const double kNormEps
Definition: ngram-model.h:37
#define NGRAMERROR()
Definition: util.h:26
double NormWeight(StateId st, bool in_fst1, bool in_fst2) const override
const fst::Fst< fst::StdArc > & GetFst() const
Definition: ngram-model.h:302
void MergeNGramModels(const fst::StdFst &infst2, int max_replace_order=-1, bool norm=false)