NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-complete.h
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Complete a partial model by adding transitions to ensure proper topology
17 // as required by NGramModel.
18 
19 #ifndef NGRAM_NGRAM_COMPLETE_H_
20 #define NGRAM_NGRAM_COMPLETE_H_
21 
#include <algorithm>
#include <memory>
#include <set>
#include <vector>

#include <fst/arcsort.h>
#include <fst/fst.h>
#include <fst/matcher.h>
#include <fst/mutable-fst.h>
#include <ngram/ngram-model.h>
#include <ngram/util.h>
31 
32 namespace ngram {
33 
34 // Ascends the NGram WFST from lower order states and collects state info.
35 template <class Arc>
37  const fst::Fst<Arc> &fst, int order, typename Arc::Label backoff_label,
38  std::vector<std::vector<typename Arc::StateId>> *order_states,
39  std::vector<int> *state_orders,
40  std::vector<typename Arc::StateId> *backoff_states) {
41  if (order >= order_states->size()) return false;
42  for (int i = 0; i < (*order_states)[order].size(); ++i) {
43  auto s = (*order_states)[order][i];
44  if ((*state_orders)[s] != order) {
45  NGRAMERROR() << "State " << s << " included in vector of states with "
46  << "order " << order << ", but that is not the case";
47  return false;
48  }
49  for (fst::ArcIterator<fst::Fst<Arc>> aiter(fst, s); !aiter.Done();
50  aiter.Next()) {
51  const Arc &arc = aiter.Value();
52  if (arc.ilabel == backoff_label) {
53  (*backoff_states)[s] = arc.nextstate;
54  } else if ((*state_orders)[arc.nextstate] <= 0) {
55  if (order_states->size() <= order + 1) order_states->resize(order + 2);
56  (*state_orders)[arc.nextstate] = order + 1;
57  (*order_states)[order + 1].push_back(arc.nextstate);
58  }
59  }
60  if (order > 1 && (*backoff_states)[s] == fst::kNoStateId) {
61  NGRAMERROR() << "No backoff state for higher order state " << s;
62  return false;
63  }
64  }
65  return true;
66 }
67 
68 template <class Arc>
69 bool NGramComplete(fst::MutableFst<Arc> *fst,
70  typename Arc::Label backoff_label = 0) {
71  typedef typename Arc::StateId StateId;
72  typedef typename Arc::Weight Weight;
73  typedef typename Arc::Label Label;
74  if (fst->NumStates() < 2) return true;
75  fst::ArcIterator<fst::Fst<Arc>> aiter(*fst, fst->Start());
76  const Arc &arc = aiter.Value();
77  if (arc.ilabel != backoff_label) {
78  NGRAMERROR() << "First arc out of start state is not backoff arc";
79  return false;
80  }
81  StateId unigram_state = arc.nextstate;
82  std::vector<int> state_orders(fst->NumStates());
83  std::vector<StateId> backoff_states(fst->NumStates(), fst::kNoStateId);
84  std::vector<std::vector<StateId>> order_states(3);
85  order_states[1].push_back(unigram_state);
86  state_orders[unigram_state] = 1;
87  order_states[2].push_back(fst->Start());
88  state_orders[fst->Start()] = 2;
89  backoff_states[fst->Start()] = unigram_state;
90  int st_order = 1;
91  while (st_order < order_states.size() && !order_states[st_order].empty()) {
92  if (!AscendAndCollectStateInfo(*fst, st_order++, backoff_label,
93  &order_states, &state_orders,
94  &backoff_states)) {
95  return false;
96  }
97  }
98 
99  for (StateId s = 0; s < fst->NumStates(); ++s) {
100  if (s != unigram_state && backoff_states[s] == fst::kNoStateId) {
101  NGRAMERROR() << "state with no backoff state: " << s;
102  return false;
103  }
104  if (state_orders[s] <= 0) {
105  NGRAMERROR() << "state not on ascending path: " << s;
106  return false;
107  }
108  }
109 
110  std::vector<std::set<Label>> label_sets(fst->NumStates());
111  std::set<StateId> new_final_states;
112 
113  std::unique_ptr<fst::Matcher<fst::Fst<Arc>>> matcher(
114  new fst::Matcher<fst::Fst<Arc>>(*fst, fst::MATCH_INPUT));
115 
116  for (int order = order_states.size() - 1; order > 1; --order) {
117  for (int idx = 0; idx < order_states[order].size(); ++idx) {
118  StateId s = order_states[order][idx];
119  StateId bs = backoff_states[s];
120  if (state_orders[s] != state_orders[bs] + 1) {
121  NGRAMERROR() << "State " << s << " backs off more than one order";
122  return false;
123  }
124  matcher->SetState(bs);
125  for (fst::ArcIterator<fst::Fst<Arc>> aiter(*fst, s);
126  !aiter.Done(); aiter.Next()) {
127  const Arc &arc = aiter.Value();
128  if (arc.ilabel == backoff_label || matcher->Find(arc.ilabel)) continue;
129  label_sets[bs].insert(arc.ilabel);
130  }
131 
132  for (auto iter = label_sets[s].begin(); iter != label_sets[s].end();
133  ++iter) {
134  if (matcher->Find(*iter)) continue;
135  label_sets[bs].insert(*iter);
136  }
137 
138  if ((NGramModel<Arc>::ScalarValue(fst->Final(s)) !=
139  NGramModel<Arc>::ScalarValue(Weight::Zero()) ||
140  new_final_states.count(s) != 0) &&
141  NGramModel<Arc>::ScalarValue(fst->Final(bs)) ==
142  NGramModel<Arc>::ScalarValue(Weight::Zero())) {
143  new_final_states.insert(bs);
144  }
145  }
146  }
147 
148  for (int order = 1; order < order_states.size() - 1; ++order) {
149  for (int idx = 0; idx < order_states[order].size(); ++idx) {
150  StateId s = order_states[order][idx];
151  if (label_sets[s].empty()) continue;
152  StateId bs = backoff_states[s];
153  std::unique_ptr<fst::Matcher<fst::Fst<Arc>>> updated_matcher(
154  new fst::Matcher<fst::Fst<Arc>>(*fst, fst::MATCH_INPUT));
155  if (bs < 0) {
156  updated_matcher->SetState(s);
157  } else {
158  updated_matcher->SetState(bs);
159  }
160 
161  std::vector<Arc> arcs;
162  arcs.reserve(fst->NumArcs(s) + label_sets[s].size());
163 
164  for (auto iter = label_sets[s].begin(); iter != label_sets[s].end();
165  ++iter) {
166  StateId nextstate = unigram_state;
167  if (bs >= 0 && updated_matcher->Find(*iter))
168  nextstate = updated_matcher->Value().nextstate;
169  arcs.push_back(
170  Arc(*iter, *iter, NGramModel<Arc>::UnitCount(), nextstate));
171  }
172 
173  if (arcs.empty()) {
174  NGRAMERROR() << "No arcs found";
175  return false;
176  }
177 
178  for (fst::ArcIterator<fst::Fst<Arc>> aiter(*fst, s);
179  !aiter.Done(); aiter.Next())
180  arcs.push_back(aiter.Value());
181 
182  fst->DeleteArcs(s);
183  std::sort(arcs.begin(), arcs.end(), fst::ILabelCompare<Arc>());
184 
185  for (size_t i = 0; i < arcs.size(); ++i) fst->AddArc(s, arcs[i]);
186  }
187  }
188 
189  for (auto iter = new_final_states.begin(); iter != new_final_states.end();
190  ++iter) {
191  fst->SetFinal(*iter, NGramModel<Arc>::UnitCount());
192  }
193 
194  if (NGramModel<Arc>::ScalarValue(fst->Final(unigram_state)) ==
195  NGramModel<Arc>::ScalarValue(Weight::Zero()))
196  fst->SetFinal(unigram_state, NGramModel<Arc>::UnitCount());
197  return true;
198 }
199 
200 } // namespace ngram
201 
202 #endif // NGRAM_NGRAM_COMPLETE_H_
bool NGramComplete(fst::MutableFst< Arc > *fst, typename Arc::Label backoff_label=0)
bool AscendAndCollectStateInfo(const fst::Fst< Arc > &fst, int order, typename Arc::Label backoff_label, std::vector< std::vector< typename Arc::StateId >> *order_states, std::vector< int > *state_orders, std::vector< typename Arc::StateId > *backoff_states)
#define NGRAMERROR()
Definition: util.h:26