GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
topology.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Algorithms for constructing specific FST topologies.
15 
16 #ifndef NLP_GRM2_SFST_TOPOLOGY_H_
17 #define NLP_GRM2_SFST_TOPOLOGY_H_
18 
#include <sys/types.h>

#include <cstddef>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <fst/log.h>
#include <fst/arcsort.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/weight.h>
31 
32 namespace sfst {
33 
34 // Finds the ngrams up to a given order in a canonical SFSA.
35 // N-grams that contain embedded failure or epsilon transitions
36 // are not considered.
37 template <class Arc>
39  public:
40  using StateId = typename Arc::StateId;
41  using Label = typename Arc::Label;
42  using Weight = typename Arc::Weight;
43 
44  // Constructs an NGramTopology object for a given n-gram 'order'.
45  // The 'phi_label' is the failure label. Result is returned
46  // in 'ofst'. Output will be a canonical OpenGrm ngram FST model.
47  NGramTopology(int order, Label phi_label, fst::MutableFst<Arc> *ofst);
48 
49  // Computes ngram topology of 'ifst'. Assumes (but does not fully
50  // check) the input is a canonical stochastic FSA (see canonical.h).
51  void FindNGrams(const fst::Fst<Arc> &ifst);
52 
53  private:
54  // Ngram state data.
55  struct NGramState {
56  StateId backoff_state; // ID of the backoff state for the current state.
57  int order; // N-gram order of the state (of the outgoing arcs)
58  NGramState(StateId s, int o) : backoff_state(s), order(o) { }
59  };
60 
61  using Pair = std::pair<ssize_t, ssize_t>;
62 
63  struct PairHash {
64  size_t operator()(const Pair &p) const {
65  return (static_cast<size_t>(p.first) * 55697) ^
66  (static_cast<size_t>(p.second) * 54631);
67  }
68  };
69 
70  using DestMap = std::unordered_map<Pair, ssize_t, PairHash>;
71 
72  // Adds n-gram corresponding to the arc labeled 'label'
73  // out of state with ID 'state_id' if it doesn't already exist.
74  // Returns its destination state.
75  StateId UpdateNGram(StateId state_id, Label label, DestMap *dest_map);
76 
77  // UpdateNGram for n-grams that end in the super-final label.
78  void UpdateFinalNGram(StateId state_id) {
79  namespace f = fst;
80  while (state_id != f::kNoStateId &&
81  ofst_->Final(state_id) == Weight::Zero()) {
82  ofst_->SetFinal(state_id, Weight::One());
83  state_id = states_[state_id].backoff_state;
84  }
85  }
86 
87  int order_; // Max. order of n-gram being counted
88  Label phi_label_; // Failure label
89  fst::MutableFst<Arc> *ofst_; // Output n-gram FST
90  std::vector<NGramState> states_; // Maps state IDs to NGramStates
91  StateId initial_; // ID of start state
92  StateId backoff_; // ID of unigram/backoff state
93  fst::WeightConvert<fst::Log64Weight, Weight> from_log_;
94  fst::WeightConvert<Weight, fst::Log64Weight> to_log_;
95 
96  NGramTopology(const NGramTopology &) = delete;
97  NGramTopology &operator=(const NGramTopology &) = delete;
98 };
99 
100 template <class Arc>
102  int order, Label phi_label, fst::MutableFst<Arc> *ofst)
103  : order_(order),
104  phi_label_(phi_label),
105  ofst_(ofst) {
106  namespace f = fst;
107  ofst->DeleteStates();
108 
109  if (order == 0) {
110  FSTERROR() << "NGramTopology: order must be greater than 0";
111  ofst->SetProperties(f::kError, f::kError);
112  return;
113  }
114  if (phi_label_ == f::kNoLabel) {
115  FSTERROR() << "NGramTopology: bad phi label: " << phi_label_;
116  ofst->SetProperties(f::kError, f::kError);
117  return;
118  }
119 
120  backoff_ = ofst->AddState();
121  states_.push_back(NGramState(f::kNoStateId, 1));
122  if (order == 1) {
123  initial_ = backoff_;
124  } else {
125  initial_ = ofst->AddState();
126  states_.push_back(NGramState(backoff_, 2));
127  }
128  ofst->SetStart(initial_);
129 }
130 
131 template <class Arc>
132 void NGramTopology<Arc>::FindNGrams(const fst::Fst<Arc> &ifst) {
133  namespace f = fst;
134 
135  if (ifst.Start() == f::kNoStateId) {
136  FSTERROR() << "NGramTopology: input FST has no states";
137  ofst_->SetProperties(f::kError, f::kError);
138  return;
139  }
140  if (!ifst.Properties(f::kAcceptor, true)) {
141  FSTERROR() << "NGramTopology: input FST not an acceptor";
142  ofst_->SetProperties(f::kError, f::kError);
143  return;
144  }
145 
146  // Finds n-grams up to order n in the input.
147  std::vector<Pair> queue;
148  std::unordered_set<Pair, PairHash> pairset;
149  DestMap dest_map;
150  Pair start_pair(ifst.Start(), initial_);
151  pairset.insert(start_pair);
152  queue.push_back(start_pair);
153  bool non_trivial_label = false;
154  while (!queue.empty()) {
155  Pair current_pair = queue.back();
156  auto fst_state = current_pair.first;
157  StateId ngram_state = current_pair.second;
158  queue.pop_back();
159  for (f::ArcIterator<f::Fst<Arc>> aiter(ifst, fst_state);
160  !aiter.Done(); aiter.Next()) {
161  const auto &arc = aiter.Value();
162  // Next pair will use backoff_ count state if failure arc
163  Pair next_pair(arc.nextstate, backoff_);
164  if (arc.ilabel && arc.ilabel != phi_label_) {
165  // Next pair uses n-gram destination state of non-special arc
166  next_pair.second = UpdateNGram(ngram_state, arc.ilabel, &dest_map);
167  non_trivial_label = true;
168  } else if (phi_label_ != 0 && arc.ilabel == 0) {
169  // Next pair uses n-gram source state for epsilon arc
170  next_pair.second = ngram_state;
171  }
172  auto iter = pairset.find(next_pair);
173  if (iter == pairset.end()) { // If new pair, inserts it.
174  pairset.insert(next_pair);
175  queue.push_back(next_pair);
176  }
177  }
178  if (ifst.Final(fst_state) != Weight::Zero())
179  UpdateFinalNGram(ngram_state);
180  }
181 
182 
183  // Input should have some non-trival labeled arcs
184  if (!non_trivial_label) {
185  FSTERROR() << "NGramTopology: input FST has no non-trivial path";
186  ofst_->SetProperties(f::kError, f::kError);
187  return;
188  }
189 
190  // Sets up backoff arcs and symbol tables.
191  for (StateId s = 0; s < states_.size(); ++s) {
192  if (states_[s].backoff_state != f::kNoStateId) {
193  ofst_->AddArc(s, Arc(phi_label_, phi_label_, Weight::One(),
194  states_[s].backoff_state));
195  }
196  }
197 
198  f::ArcSort(ofst_, f::ILabelCompare<Arc>());
199  ofst_->SetInputSymbols(ifst.InputSymbols());
200  ofst_->SetOutputSymbols(ifst.OutputSymbols());
201 }
202 
203 template <class Arc>
204 typename Arc::StateId NGramTopology<Arc>::UpdateNGram(
205  StateId state_id, Label label, DestMap *dest_map) {
206  namespace f = fst;
207 
208  // First determines if there already exists a corresponding arc.
209  Pair p(state_id, label);
210  auto iter = dest_map->find(p);
211  if (iter != dest_map->end())
212  return iter->second;
213 
214  // Otherwise, creates the arc
215  NGramState ngram_state = states_[state_id];
216 
217  Arc arc(label, label, Weight::One(), initial_);
218 
219  if (order_ != 1) {
220  // Finds the backed-off arc destination.
221  StateId backoff_dest = ngram_state.backoff_state == f::kNoStateId ?
222  f::kNoStateId :
223  UpdateNGram(ngram_state.backoff_state, label, dest_map);
224 
225  // Computes the current arc destination state.
226  if (ngram_state.order == order_) {
227  // The destination state is the destination of the backed-off arc.
228  arc.nextstate = backoff_dest;
229  } else {
230  // The destination state needs to be created.
231  arc.nextstate = ofst_->AddState();
232  NGramState next_ngram_state(
233  backoff_dest == f::kNoStateId ? backoff_ : backoff_dest,
234  ngram_state.order + 1);
235  states_.push_back(next_ngram_state);
236  }
237  }
238  ofst_->AddArc(state_id, arc);
239  (*dest_map)[p] = arc.nextstate;
240  return arc.nextstate;
241 }
242 
243 } // namespace sfst
244 
245 #endif // NLP_GRM2_SFST_TOPOLOGY_H_
Definition: perplexity.h:32
typename Arc::Label Label
Definition: topology.h:41
Definition: sfstinfo.cc:39
void FindNGrams(const fst::Fst< Arc > &ifst)
Definition: topology.h:132
NGramTopology(int order, Label phi_label, fst::MutableFst< Arc > *ofst)
Definition: topology.h:101
typename Arc::Weight Weight
Definition: topology.h:42
typename Arc::StateId StateId
Definition: topology.h:40