GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
sfst.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Common definitions for stochastic FSTs.
15 
16 #include <sys/types.h>
17 
18 #include <cmath>
19 #include <cstddef>
20 #include <cstdlib>
21 #include <vector>
22 
23 #include <fst/float-weight.h>
24 #include <fst/fst.h>
25 #include <fst/matcher.h>
26 #include <fst/signed-log-weight.h>
27 
28 #ifndef NLP_GRM2_SFST_SFST_H_
29 #define NLP_GRM2_SFST_SFST_H_
30 
31 namespace sfst {
32 
33 // Matches ngram library choice
34 const fst::Log64Weight kApproxZeroWeight = 99.0;
35 
36 // Order w.r.t. probability: exp(-weight)
37 template <class T>
38 inline bool Less(fst::LogWeightTpl<T> weight1,
39  fst::LogWeightTpl<T> weight2) {
40  return weight1.Value() > weight2.Value();
41 }
42 
43 // Order w.r.t. probability: exp(-weight) for the Tropical weight.
44 template <class T>
45 inline bool Less(fst::TropicalWeightTpl<T> weight1,
46  fst::TropicalWeightTpl<T> weight2) {
47  return weight1.Value() > weight2.Value();
48 }
49 
50 // Order w.r.t. probability: weight
51 template <class T>
52 inline bool Less(fst::RealWeightTpl<T> weight1,
53  fst::RealWeightTpl<T> weight2) {
54  return weight1.Value() < weight2.Value();
55 }
56 
57 template <class T>
58 inline bool Less(fst::SignedLogWeightTpl<T> weight1,
59  fst::SignedLogWeightTpl<T> weight2) {
60  bool s1 = weight1.Value1().Value() > 0.0;
61  bool s2 = weight2.Value1().Value() > 0.0;
62 
63  if (!s1 && s2) {
64  return true;
65  } else if (s1 && !s2) {
66  return false;
67  } else if (s1 && s2) {
68  return Less(weight1.Value2(), weight2.Value2());
69  } else {
70  return Less(weight2.Value2(), weight1.Value2());
71  }
72 }
73 
// Returns true if w1 is Less() than w2 or the two weights compare equal
// with operator==.
template <class Weight>
bool LessOrEqual(Weight w1, Weight w2) {
  if (Less(w1, w2)) return true;
  return w1 == w2;
}
78 
79 inline bool IsNegative(fst::SignedLog64Weight w) {
80  using SLWeight = fst::SignedLog64Weight;
81  return Less(w, SLWeight::Zero());
82 }
83 
84 inline bool ApproxZero(
85  fst::Log64Weight weight,
86  fst::Log64Weight approx_zero = kApproxZeroWeight) {
87  return LessOrEqual(weight, approx_zero);
88 }
89 
90 inline bool ApproxZero(
91  fst::SignedLog64Weight weight,
92  fst::Log64Weight pos_approx_zero = kApproxZeroWeight,
93  fst::Log64Weight neg_approx_zero = 10.0) {
94  if (weight.Value1().Value() > 0.0) {
95  return LessOrEqual(weight.Value2(), pos_approx_zero);
96  } else {
97  return LessOrEqual(weight.Value2(), neg_approx_zero);
98  }
99 }
100 
101 
// Approximate-equality functor that compares w.r.t. exponentiated values
// ('probabilities' rather than '- log probabilities'). Appropriate for
// SignedLog(64) weights: Weight must provide Value1().Value() (the sign
// factor) and Value2().Value() (the negative log magnitude).
template <class Weight>
class SignedLogWeightApproxEqual {
 public:
  // `delta` is the strict upper bound on the absolute difference between
  // the two exponentiated values for them to be considered equal.
  explicit SignedLogWeightApproxEqual(float delta) : delta_(delta) {}

  // Returns true iff |sgn1 * exp(-val1) - sgn2 * exp(-val2)| < delta.
  bool operator()(const Weight &w1, const Weight &w2) const {
    double sgn1 = w1.Value1().Value();
    double sgn2 = w2.Value1().Value();
    double val1 = w1.Value2().Value();
    double val2 = w2.Value2().Value();
    double exp1 = sgn1 * std::exp(-val1);
    double exp2 = sgn2 * std::exp(-val2);

    return std::abs(exp1 - exp2) < delta_;
  }

 private:
  const float delta_;
};
123 
124 // Class to get information about the failure path leaving a state.
125 // Assumes (but does not check) input FST has a canonical topology
126 // (see canonical.h) for a stochastic FST (when match_input = true,
127 // o.w. fst^-1 is assumed canonical).
128 template <class Arc>
129 class FailurePath {
130  public:
131  typedef typename Arc::StateId StateId;
132  typedef typename Arc::Label Label;
133  typedef typename Arc::Weight Weight;
134  typedef fst::ExplicitMatcher<fst::SortedMatcher<fst::Fst<Arc>>>
136 
137  FailurePath(const fst::Fst<Arc> &fst, Label phi_label, bool match_input)
138  : fst_(fst),
139  phi_label_(phi_label),
140  matcher_(fst,
141  match_input ? fst::MATCH_INPUT : fst::MATCH_OUTPUT),
142  s_(fst::kNoStateId) {}
143 
144  // Sets the current state.
145  void SetState(StateId s);
146 
147  // Length of the failure path from current state.
148  // Same as the state order - 1.
149  size_t Length() const { return faildest_.size(); }
150 
151  // Destination state of the ith transition on the current failure path.
152  StateId GetNextState(size_t i) const { return faildest_[i]; }
153  // Weight of the ith transition on the current failure path.
154  Weight GetWeight(size_t i) const { return failweight_[i]; }
155  // Arc position of the ith transition on the current failure path.
156  size_t GetPosition(size_t i) const { return failpos_[i]; }
157 
158  private:
159  // Finds failure arc for a state and returns the arc and arc position.
160  // If no failure arc, uses (kNoLabel, kNoLabel, Zero(), kNoStateId)
161  // and position -1.
162  ssize_t GetFailureArc(StateId s, Arc *arc);
163 
164  const fst::Fst<Arc> &fst_;
165  Label phi_label_;
166  Matr matcher_;
167 
168  StateId s_;
169 
170  std::vector<StateId> faildest_; // phi destination states
171  std::vector<Weight> failweight_; // phi weights
172  std::vector<size_t> failpos_; // phi positions
173 
174  FailurePath(const FailurePath &) = delete;
175  FailurePath &operator=(const FailurePath &) = delete;
176 };
177 
178 template <class Arc>
180  namespace f = fst;
181  if (s == s_) // Already setup
182  return;
183 
184  s_ = s;
185  faildest_.clear();
186  failweight_.clear();
187  failpos_.clear();
188  Arc failarc;
189  for (StateId r = s;; r = failarc.nextstate) {
190  ssize_t pos = GetFailureArc(r, &failarc);
191  if (failarc.nextstate == f::kNoStateId) break;
192  faildest_.push_back(failarc.nextstate);
193  failweight_.push_back(failarc.weight);
194  failpos_.push_back(pos);
195  }
196 }
197 
198 template <class Arc>
199 ssize_t FailurePath<Arc>::GetFailureArc(StateId s, Arc *failarc) {
200  namespace f = fst;
201  *failarc = Arc(f::kNoLabel, f::kNoLabel, Weight::Zero(), f::kNoStateId);
202  ssize_t pos = -1;
203  if (phi_label_ == f::kNoLabel) return pos;
204 
205  matcher_.SetState(s);
206  if (matcher_.Find(phi_label_)) {
207  *failarc = matcher_.Value();
208  pos = matcher_.GetMatcher()->Position();
209  }
210  return pos;
211 }
212 
213 } // namespace sfst
214 
215 #endif // NLP_GRM2_SFST_SFST_H_
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
Definition: sfst.h:38
size_t Length() const
Definition: sfst.h:149
Definition: perplexity.h:32
Weight GetWeight(size_t i) const
Definition: sfst.h:154
Arc::Weight Weight
Definition: sfst.h:133
fst::ExplicitMatcher< fst::SortedMatcher< fst::Fst< Arc > > > Matr
Definition: sfst.h:135
bool IsNegative(fst::SignedLog64Weight w)
Definition: sfst.h:79
Definition: sfstinfo.cc:39
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
Definition: sfst.h:84
Arc::Label Label
Definition: sfst.h:132
bool LessOrEqual(Weight w1, Weight w2)
Definition: sfst.h:75
FailurePath(const fst::Fst< Arc > &fst, Label phi_label, bool match_input)
Definition: sfst.h:137
StateId GetNextState(size_t i) const
Definition: sfst.h:152
size_t GetPosition(size_t i) const
Definition: sfst.h:156
const fst::Log64Weight kApproxZeroWeight
Definition: sfst.h:34
void SetState(StateId s)
Definition: sfst.h:179
bool operator()(const Weight &w1, const Weight &w2) const
Definition: sfst.h:109
SignedLogWeightApproxEqual(float delta)
Definition: sfst.h:107
Arc::StateId StateId
Definition: sfst.h:131