GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
backoff.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Algorithms for backoff automata.
15 
16 #ifndef NLP_GRM2_SFST_BACKOFF_H_
17 #define NLP_GRM2_SFST_BACKOFF_H_
18 
#include <sys/types.h>

#include <cstddef>
#include <unordered_map>
#include <utility>
#include <vector>

#include <fst/log.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/matcher.h>
#include <fst/util.h>
#include <sfst/canonical.h>
#include <sfst/sfst.h>
32 
33 namespace sfst {
34 
35 // A backoff-complete transducer is a canonical SFST (see canonical.h) that
36 // additionally has the properties:
37 //
38 // (1) it is (input) deterministic and (non-phi) epsilon-free
39 //
40 // (2) if there is a failure transition from state q to 'backoff
41 // state' q', then every transition from q has a 'backed-off'
42 // transition from q'. I.e. if a transition from q is (input) labeled
43 // with x, then there is a transition from q' also (input) labeled
44 // with x.
45 //
46 // This class provides convenient access to the backoff states and
47 // backed off arcs.
48 template <class Arc>
49 class Backoff {
50  public:
51  using StateId = typename Arc::StateId;
52  using Label = typename Arc::Label;
53  using Weight = typename Arc::Weight;
54  using Matr = fst::SortedMatcher<fst::Fst<Arc>>;
55 
56  // Input is required to be a canonical SFST. It must also be a
57  // backoff-complete SFST unless require_backoff_complete is false.
58  Backoff(const fst::Fst<Arc> &fst, Label phi_label,
59  bool require_backoff_complete = true);
60 
61  // Returns the number of states in the input.
62  int NumStates() const { return states_.size(); }
63 
64  // The maximum number of states (length + 1) in a failure path.
65  int MaxOrder() const { return order_; }
66 
67  // The number of states (length + 1) in a failure path from
68  // that state.
69  int Order(StateId s) const { return states_[s].order; }
70 
71  // As returned by canonical.h::PhiTopOrder().
72  StateId GetPhiTopOrder(StateId i) const { return top_order_[i]; }
73 
74  // Gives the state that is failed to from s (or fst::kNoStateId).
75  StateId GetBackoffState(StateId s) const { return states_[s].backoff_state; }
76 
77  // Gives the position of the failure arc from s (or -1).
79  return states_[s].backoff_position;
80  }
81 
82  // Returns the position at the 'backed-off' arc for the arc
83  // at position pos from state s. Returns -1 if none (e.g. for the
84  // phi arc).
85  ssize_t GetBackedOffArc(StateId s, ssize_t pos) const {
86  auto iter = arc_map_.find(std::make_pair(s, pos));
87  return iter != arc_map_.end() ? iter->second : -1;
88  }
89 
90  // Returns true if the input is a backoff-complete FST.
91  bool IsBackoffComplete() const { return is_backoff_; }
92 
93  // Returns true if in a bad state.
94  bool Error() const { return error_; }
95 
96  private:
97  // Locates the backoff states and backed-off arcs.
98  void FindBackedOffArcs(StateId s);
99 
100  void SetError() const { error_ = true; }
101 
102  // Backoff state data
103  struct BackoffState {
104  StateId backoff_state; // backoff state ID for the referenced state
105  ssize_t backoff_position; // arc position of backoff arc
106  int order; // order of the referenced state
107 
108  BackoffState(StateId s, ssize_t p, int o)
109  : backoff_state(s), backoff_position(p), order(o) {}
110 
111  BackoffState()
112  : backoff_state(fst::kNoStateId), backoff_position(-1), order(1) {}
113  };
114 
115  // (state, arc position)
116  using Pair = std::pair<StateId, ssize_t>;
117 
118  struct PairHash {
119  size_t operator()(const Pair &p) const {
120  return (static_cast<size_t>(p.first) * 55697) ^
121  (static_cast<size_t>(p.second) * 54631);
122  }
123  };
124 
125  using PairArcMap = std::unordered_map<Pair, ssize_t, PairHash>;
126 
127  const fst::Fst<Arc> &fst_; // input FST
128  Label phi_label_; // failure label
129  int order_; // maximal phi-order
130  std::vector<StateId> top_order_; //
131  std::vector<BackoffState> states_; // maps from state ID to backoff info
132  PairArcMap arc_map_; // maps from state id and position to
133  // ... backed-off arc position
134  bool require_backoff_complete_;
135  mutable bool is_backoff_;
136  mutable bool error_;
137 
138  Backoff(const Backoff &) = delete;
139  Backoff &operator=(const Backoff &) = delete;
140 };
141 
142 template <class Arc>
143 Backoff<Arc>::Backoff(const fst::Fst<Arc> &fst, Label phi_label,
144  bool require_backoff_complete)
145  : fst_(fst),
146  phi_label_(phi_label),
147  order_(1),
148  require_backoff_complete_(require_backoff_complete),
149  is_backoff_(true),
150  error_(false) {
151  namespace f = fst;
152 
153  if (fst_.Start() == f::kNoStateId) {
154  is_backoff_ = false;
155  if (require_backoff_complete_) {
156  FSTERROR() << "Backoff: FST has no states";
157  SetError();
158  }
159  return;
160  }
161 
162  if (!fst.Properties(f::kILabelSorted, true)) {
163  FSTERROR() << "Backoff: FST is not canonical (ilabel-sorted)";
164  SetError();
165  return;
166  }
167 
168  bool acyclic = PhiTopOrder(fst_, phi_label_, &top_order_);
169  if (!acyclic) {
170  FSTERROR() << "Backoff: FST is not canonical (phi-cyclic)";
171  SetError();
172  return;
173  }
174 
175  states_.resize(top_order_.size());
176 
177  if (phi_label_ == f::kNoStateId) return;
178 
179  for (StateId i = top_order_.size() - 1; i >= 0; --i) {
180  StateId s = top_order_[i]; // ith state in reverse phi-top order
181  BackoffState &bo_state = states_[s];
182  Matr matcher(fst_, f::MATCH_INPUT);
183  matcher.SetState(s);
184  if (matcher.Find(phi_label_)) {
185  for (; !matcher.Done(); matcher.Next()) {
186  const Arc &arc = matcher.Value();
187  // Continues on implicit match.
188  if (arc.ilabel == f::kNoLabel) continue;
189  // Error if phi non-determinism.
190  if (bo_state.backoff_state != f::kNoStateId) {
191  FSTERROR() << "Backoff: FST is not canonical (phi-non-determinism)";
192  SetError();
193  continue;
194  }
195  bo_state.backoff_state = arc.nextstate;
196  bo_state.backoff_position = matcher.Position();
197  bo_state.order = states_[bo_state.backoff_state].order + 1;
198  if (bo_state.order > order_) order_ = bo_state.order;
199  FindBackedOffArcs(s);
200  }
201  }
202  }
203 }
204 
205 template <class Arc>
207  namespace f = fst;
208  Matr matcher(fst_, f::MATCH_INPUT);
209  matcher.SetState(states_[s].backoff_state);
210  for (f::ArcIterator<f::Fst<Arc>> aiter(fst_, s); !aiter.Done();
211  aiter.Next()) {
212  const Arc &arc = aiter.Value();
213  if (arc.ilabel == phi_label_) continue;
214  if (arc.ilabel == 0) {
215  is_backoff_ = false;
216  if (require_backoff_complete_) {
217  FSTERROR() << "Backoff: non-failure epsilons disallowed";
218  SetError();
219  return;
220  }
221  continue;
222  }
223  if (matcher.Find(arc.ilabel)) {
224  Pair p(s, aiter.Position());
225  arc_map_[p] = matcher.Position();
226  } else {
227  is_backoff_ = false;
228  if (require_backoff_complete_) {
229  FSTERROR() << "Backoff: no backed-off arc with label " << arc.ilabel
230  << " from state " << s;
231  SetError();
232  return;
233  }
234  }
235  }
236 }
237 
238 // Tests that the input is a backoff SFST (see above).
239 template <class Arc>
240 bool IsBackoffComplete(const fst::Fst<Arc> &fst,
241  typename Arc::Label phi_label) {
242  Backoff<Arc> backoff(fst, phi_label, false);
243  return backoff.IsBackoffComplete();
244 }
245 
246 // 'Phi-sums' a backoff-complete SFST: adds the higher-order arc weights
247 // of a backoff-complete automaton onto lower-order backed-off transitions
248 // Returns true on success.
249 template <class Arc, class CompWeight = fst::Log64Weight>
250 bool SumBackoff(fst::MutableFst<Arc> *fst, typename Arc::Label phi_label) {
251  namespace f = fst;
252  using Weight = typename Arc::Weight;
253  using StateId = typename Arc::StateId;
254  fst::WeightConvert<CompWeight, Weight> from_compute_weight;
255  fst::WeightConvert<Weight, CompWeight> to_compute_weight;
256 
257  Backoff<Arc> backoff(*fst, phi_label);
258  for (StateId i = 0; i < fst->NumStates(); ++i) {
259  // ith state in phi-top order
260  StateId s = backoff.GetPhiTopOrder(i);
261  StateId bos = backoff.GetBackoffState(s);
262  if (bos == f::kNoStateId) continue;
263  Weight final = fst->Final(s);
264  if (final != Weight::Zero()) {
265  Weight bo_final = fst->Final(bos);
266  CompWeight w1 = to_compute_weight(bo_final);
267  CompWeight w2 = to_compute_weight(final);
268  fst->SetFinal(bos, from_compute_weight(Plus(w1, w2)));
269  }
270  f::MutableArcIterator<f::MutableFst<Arc>> miter(fst, bos);
271  for (f::ArcIterator<f::Fst<Arc>> aiter(*fst, s); !aiter.Done();
272  aiter.Next()) {
273  const Arc &arc = aiter.Value();
274  ssize_t pos = aiter.Position();
275  if (arc.ilabel == phi_label) continue;
276  ssize_t bo_pos = backoff.GetBackedOffArc(s, pos);
277  CHECK_NE(bo_pos, -1);
278  miter.Seek(bo_pos);
279  Arc bo_arc = miter.Value();
280  CompWeight w1 = to_compute_weight(bo_arc.weight);
281  CompWeight w2 = to_compute_weight(arc.weight);
282  bo_arc.weight = from_compute_weight(Plus(w1, w2));
283  miter.SetValue(bo_arc);
284  }
285  }
286 
287  return !backoff.Error();
288 }
289 
290 // Undoes the 'phi-summation' of a backoff-complete SFST: subtracts the
291 // higher-order arc weights a backoff-complete automaton from lower-order
292 // backed-off transitions The bo_zero weight is used for effectively
293 // zero backoffed-off weights. This should be non-Zero() if a
294 // returned backoff-complete topology is to be ensured (cf. super-final
295 // weights). Returns true on success.
296 template <class Arc>
297 bool DiffBackoff(fst::MutableFst<Arc> *fst, typename Arc::Label phi_label,
298  typename Arc::Weight bo_zero = Arc::Weight::Zero()) {
299  namespace f = fst;
300  using Weight = typename Arc::Weight;
301  using StateId = typename Arc::StateId;
302  f::WeightConvert<f::Log64Weight, Weight> from_log;
303  f::WeightConvert<Weight, f::Log64Weight> to_log;
304 
305  Backoff<Arc> backoff(*fst, phi_label);
306  for (StateId i = fst->NumStates() - 1; i >= 0; --i) {
307  // ith state in reverse phi-top order
308  StateId s = backoff.GetPhiTopOrder(i);
309  StateId bos = backoff.GetBackoffState(s);
310  if (bos == f::kNoStateId) continue;
311  Weight final = fst->Final(s);
312  if (final != Weight::Zero()) {
313  Weight bo_final = fst->Final(bos);
314  f::Log64Weight w1 = to_log(bo_final);
315  f::Log64Weight w2 = to_log(final);
316  fst->SetFinal(bos, Less(w2, w1) ? from_log(Minus(w1, w2)) : bo_zero);
317  }
318  f::MutableArcIterator<f::MutableFst<Arc>> miter(fst, bos);
319  for (f::ArcIterator<f::Fst<Arc>> aiter(*fst, s); !aiter.Done();
320  aiter.Next()) {
321  const Arc &arc = aiter.Value();
322  ssize_t pos = aiter.Position();
323  if (arc.ilabel == phi_label) continue;
324  ssize_t bo_pos = backoff.GetBackedOffArc(s, pos);
325  CHECK_NE(bo_pos, -1);
326  miter.Seek(bo_pos);
327  Arc bo_arc = miter.Value();
328  f::Log64Weight w1 = to_log(bo_arc.weight);
329  f::Log64Weight w2 = to_log(arc.weight);
330  bo_arc.weight = Less(w2, w1) ? from_log(Minus(w1, w2)) : bo_zero;
331  miter.SetValue(bo_arc);
332  }
333  }
334 
335  return !backoff.Error();
336 }
337 
338 } // namespace sfst
339 
340 #endif // NLP_GRM2_SFST_BACKOFF_H_
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
Definition: sfst.h:38
typename Arc::Label Label
Definition: backoff.h:52
Definition: perplexity.h:32
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Definition: canonical.h:37
Backoff(const fst::Fst< Arc > &fst, Label phi_label, bool require_backoff_complete=true)
Definition: backoff.h:143
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
Definition: perplexity.h:129
Definition: sfstinfo.cc:39
bool SumBackoff(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label)
Definition: backoff.h:250
typename Arc::Weight Weight
Definition: backoff.h:53
StateId GetBackoffState(StateId s) const
Definition: backoff.h:75
bool IsBackoffComplete() const
Definition: backoff.h:91
int MaxOrder() const
Definition: backoff.h:65
StateId GetBackoffPosition(StateId s) const
Definition: backoff.h:78
typename Arc::StateId StateId
Definition: backoff.h:51
ssize_t GetBackedOffArc(StateId s, ssize_t pos) const
Definition: backoff.h:85
StateId GetPhiTopOrder(StateId i) const
Definition: backoff.h:72
int NumStates() const
Definition: backoff.h:62
int Order(StateId s) const
Definition: backoff.h:69
bool DiffBackoff(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label, typename Arc::Weight bo_zero=Arc::Weight::Zero())
Definition: backoff.h:297
bool Error() const
Definition: backoff.h:94
fst::SortedMatcher< fst::Fst< Arc >> Matr
Definition: backoff.h:54