GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
perplexity.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Computes perplexity for a stochastic FST.
15 
16 #ifndef NLP_GRM2_SFST_PERPLEXITY_H_
17 #define NLP_GRM2_SFST_PERPLEXITY_H_
18 
#include <cmath>
#include <cstddef>
#include <unordered_set>
#include <vector>

#include <fst/log.h>
#include <fst/arc.h>
#include <fst/expectation-weight.h>
#include <fst/signed-log-weight.h>
#include <sfst/intersect.h>
#include <sfst/normalize.h>
#include <sfst/sfst.h>
#include <sfst/shortest-distance.h>
31 
32 namespace fst {
33 
34 // The 'entropy' semiring (p, e) is encoded as (-log p, e)
35 // over the log and real semirings resp.
36 using Entropy64Weight = ExpectationWeight<SignedLog64Weight,
37  Real64Weight>;
38 
39 // We define these operations to match our encoding.
40 
41 inline Real64Weight Times(SignedLog64Weight w1, Real64Weight w2) {
42  using Limits = fst::FloatLimits<double>;
43  if (w1 == SignedLog64Weight::Zero() && w2.Value() == Limits::PosInfinity())
44  return Real64Weight::Zero();
45  double s1 = w1.Value1().Value();
46  double l1 = w1.Value2().Value();
47  double p1 = s1 * exp(-l1);
48  double e = w2.Value();
49  return Real64Weight(p1 * e);
50 }
51 
52 inline Real64Weight Times(Real64Weight w1, SignedLog64Weight w2) {
53  using Limits = fst::FloatLimits<double>;
54  if (w2 == SignedLog64Weight::Zero() && w1.Value() == Limits::PosInfinity())
55  return Real64Weight::Zero();
56  double e = w1.Value();
57  double s2 = w2.Value1().Value();
58  double l2 = w2.Value2().Value();
59  double p2 = s2 * exp(-l2);
60  return Real64Weight(e * p2);
61 }
62 
63 // We use this to transform the weights on an SFST into the entropy semiring.
64 template <class Weight>
66  public:
67  using SLWeight = SignedLog64Weight;
68  using RWeight = Real64Weight;
70 
71  // How to transform an SFST's weights depends on how it will be used.
72  using EntropyType = enum {
73  ENTROPY, // for (self) entropy p in -\sum p log p
74  CROSS_ENTROPY_SOURCE, // for cross entropy p in - \sum p log q
75  CROSS_ENTROPY_TARGET // for cross entropy q in - \sum p log q
76  };
77 
78  explicit WeightTransform(EntropyType type = ENTROPY) : type_(type) { }
79 
80  EWeight operator()(const Weight &w) const {
81  if (w == Weight::Zero())
82  return EWeight::Zero();
83 
84  SLWeight slw = to_sl_(w);
85  RWeight rw = slw.Value2().Value();
86 
87  switch (type_) {
88  case ENTROPY:
89  return EWeight(slw, Times(slw, rw)); // (p, -p log p)
90  case CROSS_ENTROPY_SOURCE:
91  return EWeight(slw, RWeight::Zero()); // (p, 0)
92  default:
93  case CROSS_ENTROPY_TARGET:
94  return EWeight(SLWeight::One(), rw); // (1, -log q)
95  }
96  }
97 
98  private:
99  EntropyType type_;
100  fst::WeightConvert<Weight, SLWeight> to_sl_;
101 };
102 
103 // We use this to project onto the first component of the entropy weight.
104 template <>
105 struct WeightConvert<Entropy64Weight, Log64Weight> {
106  using LWeight = Log64Weight;
108 
109  LWeight operator()(const EWeight &w) const {
110  return w.Value1().Value2();
111  }
112 };
113 
114 // We use this to promote to an entropy weight (p, 0)
115 template <>
116 struct WeightConvert<Log64Weight, Entropy64Weight> {
117  using LWeight = Log64Weight;
118  using SLWeight = SignedLog64Weight;
119  using RWeight = Real64Weight;
121 
122  EWeight operator()(const LWeight &w) const {
123  SLWeight slw(1.0, w.Value());
124  return EWeight(slw, RWeight::Zero());
125  }
126 };
127 
128 // We need and have a ring (for ShortestDistance).
130  const SignedLog64Weight slz = SignedLog64Weight::Zero();
131  const Real64Weight rlz = Real64Weight::Zero();
132  return Plus(w1, Entropy64Weight(Minus(slz, w2.Value1()),
133  Minus(rlz, w2.Value2())));
134 }
135 
136 // Has a negative component?
137 inline bool IsNegative(Entropy64Weight w) {
138  using SLWeight = fst::SignedLog64Weight;
139  using RWeight = fst::Real64Weight;
140  return sfst::Less(w.Value1(), SLWeight::Zero()) ||
141  sfst::Less(w.Value2(), RWeight::Zero());
142 }
143 
145  public:
146  using SLWeight = SignedLog64Weight;
147  using RWeight = Real64Weight;
149 
150  explicit Entropy64WeightApproxEqual(float delta)
151  : slw_equal_(delta),
152  rw_equal_(delta) { }
153 
154  bool operator()(const EWeight &w1, const EWeight &w2) const {
155  return slw_equal_(w1.Value1(), w2.Value1()) &&
156  rw_equal_(w1.Value2(), w2.Value2());
157  }
158 
159  private:
161  WeightApproxEqual rw_equal_;
162 };
163 
164 }; // namespace fst
165 
166 
167 namespace sfst {
168 
// Default convergence tolerance (a float) for the SFST entropy/perplexity
// algorithms.
constexpr float kEntropyDelta = static_cast<float>(1e-8);
171 
172 // Computes the cross perplexity for a stochastic target FST q given one or
173 // more source FSTs p. For a single source and target (and L(p) \subseteq L(q))
174 // the cross entropy between between the two FSTs i H = - \sum p log q
175 // and the perplexity (per symbol) is e^(H/l) where l is the expect path
176 // length. When multiple source p's are provided, average values are returned.
177 // If the target FST is not provided, self perplexity is computed.
178 // The inputs must all be normalized stochastic FSTs.
179 template <class Arc>
180 class Perplexity {
181  public:
182  using StateId = typename Arc::StateId;
183  using Weight = typename Arc::Weight;
184  using Label = typename Arc::Label;
185  using SLArc = fst::SignedLog64Arc;
186  using RWeight = fst::Real64Weight;
187  using SLWeight = fst::SignedLog64Weight;
188  using EArc = fst::ExpectationArc<SLArc, RWeight>;
192  using WCM = fst::WeightConvertMapper<Arc, EArc, WT>;
193  using WCM1 = fst::WeightConvertMapper<EArc, fst::Log64Arc>;
194 
195  // For computing cross perplexity.
196  explicit Perplexity(const fst::Fst<Arc> &fst,
197  Label phi_label = fst::kNoLabel,
198  Label unknown_label = fst::kNoLabel,
199  float delta = fst::kDelta,
200  float entropy_delta = kEntropyDelta)
201  : phi_label_(phi_label),
202  unknown_label_(unknown_label),
203  delta_(delta),
204  entropy_delta_(entropy_delta),
205  error_(false),
206  sent_count_(0),
207  oov_count_(0) {
208  SetTarget(fst);
209  }
210 
211  // Computes cross perplexity if SetTarget is called.
212  // O.w. self perplexity is computed.
213  explicit Perplexity(Label phi_label = fst::kNoLabel,
214  Label unknown_label = fst::kNoLabel,
215  float delta = fst::kDelta,
216  float entropy_delta = kEntropyDelta)
217  : phi_label_(phi_label),
218  unknown_label_(unknown_label),
219  delta_(delta),
220  entropy_delta_(entropy_delta),
221  error_(false),
222  sent_count_(0),
223  oov_count_(0) { }
224 
225 
226  // Sets target FST for cross-entropy computation. Invokes
227  // Reset. Called by the first constructor.
228  void SetTarget(const fst::Fst<Arc> &ifst);
229 
230  // Applies a source FST to the perplexity computation.
231  bool Apply(const fst::Fst<Arc> &fst);
232 
233  // Returns cross/self entropy per source FST.
234  double GetEntropy() const {
235  double te = GetTotalEntropy();
236  double sc = GetSourceCount();
237  return te / sc;
238  }
239 
240  // Returns perplexity per symbol.
241  double GetPerplexity() const {
242  double te = GetTotalEntropy();
243  double sc = GetTotalStateCount();
244  return exp(te / sc);
245  }
246 
247  // Returns total # of source FSTs since construction
248  // or last reset.
249  size_t NumSources() const { return sent_count_; }
250 
251  // Returns count of skipped whole or partial source FSTs.
252  // since construction or last reset. Paths can be skipped
253  // when L(p) \subsetneq L(q) and the count indicates the
254  // probability mass of the source p's that are skipped.
255  double SkipCount() const {
256  return sent_count_ - GetSourceCount();
257  }
258 
259  // Returns # of OOVs since construction or last reset (ignoring
260  // fully-skipped sources).
261  size_t NumOOVs() const { return oov_count_; }
262 
263  // Resets perplexity computation.
264  void Reset() {
265  entropy_.Reset();
266  state_count_.Reset();
267  sent_count_ = 0;
268  oov_count_ = 0;
269  }
270 
271  private:
272  // Finds in-vocabulary set.
273  void FindVocabSet(const fst::Fst<Arc> &sfst) {
274  namespace f = fst;
275  vocab_.clear();
276 
277  for (f::StateIterator<f::Fst<Arc>> siter(sfst);
278  !siter.Done();
279  siter.Next()) {
280  StateId s = siter.Value();
281  for (f::ArcIterator<f::Fst<Arc>> aiter(sfst, s);
282  !aiter.Done();
283  aiter.Next()) {
284  const Arc &arc = aiter.Value();
285  if (arc.ilabel != phi_label_ && arc.ilabel != 0)
286  vocab_.insert(arc.ilabel);
287  }
288  }
289  }
290 
291  // Converts a source FST to entropy semiring and converts OOVs to
292  // unknown_label when defined.
293  void PrepareSource(const fst::Fst<Arc> &ifst,
294  fst::MutableFst<EArc> *ofst);
295 
296  // Self or cross entropy being computed?
297  bool IsSelfEntropy() const {
298  namespace f = fst;
299  return qfst_.Start() == f::kNoStateId;
300  }
301 
302  // Returns accumulated cross/self entropy since construction
303  // or last reset.
304  double GetTotalEntropy() const {
305  return entropy_.Sum().Value2().Value();
306  }
307 
308  // Returns count of partial or whole source FSTs used
309  // in the entropy computation since construction or last
310  // reset. This can be fractional if L(p) \subsetneq L(q). This
311  // is used to normalize the total entropy to a per source FST
312  // entropy.
313  double GetSourceCount() const {
314  double sign_count = entropy_.Sum().Value1().Value1().Value();
315  double mag_count = entropy_.Sum().Value1().Value2().Value();
316  return sign_count * exp(-mag_count);
317  }
318 
319  // Returns the accumulated state count mass since construction or last
320  // reset. This is equal to the expected length * source count, which
321  // is the normalizaton factor needed for perplexity. Note that the expected
322  // length of each accepted string includes the super-final label.
323  double GetTotalStateCount() const {
324  double sign_count = state_count_.Sum().Value1().Value();
325  double mag_count = state_count_.Sum().Value2().Value();
326  return sign_count * exp(-mag_count);
327  }
328 
329  Label phi_label_;
330  Label unknown_label_;
331  float delta_;
332  float entropy_delta_;
333  bool error_;
334  fst::VectorFst<EArc> qfst_;
335  fst::Adder<EWeight> entropy_;
336  fst::Adder<SLWeight> state_count_;
337  size_t sent_count_;
338  size_t oov_count_;
339  size_t fst_oov_count_;
340  std::unordered_set<Label> vocab_;
341 
342  Perplexity(const Perplexity &) = delete;
343  Perplexity &operator=(const Perplexity &) = delete;
344 };
345 
346 template <class Arc>
347 void Perplexity<Arc>::SetTarget(const fst::Fst<Arc> &ifst) {
348  namespace f = fst;
349 
350  if (ifst.Start() == f::kNoLabel) {
351  LOG(ERROR) << "Perplexity: target FST has no states";
352  error_ = true;
353  return;
354  }
355 
356  if (!IsNormalized(ifst, phi_label_, delta_)) {
357  LOG(ERROR) << "Perplexity: target is not a normalized stochastic FST";
358  error_ = true;
359  return;
360  }
361 
362  WT to_e(WT::CROSS_ENTROPY_TARGET);
363  WCM wc_mapper(to_e);
364  f::ArcMap(ifst, &qfst_, wc_mapper);
365 
366  if (unknown_label_ != f::kNoLabel)
367  FindVocabSet(ifst);
368 
369  Reset();
370 }
371 
372 // Converts FST to entropy semiring and converts OOVs to unknown_label
373 // when defined.
374 template <class Arc>
375 void Perplexity<Arc>::PrepareSource(const fst::Fst<Arc> &ifst,
376  fst::MutableFst<EArc> *ofst) {
377  namespace f = fst;
378 
379  if (!IsNormalized(ifst, phi_label_, delta_)) {
380  LOG(ERROR) << "Perplexity: source (" << sent_count_
381  << ") is not a normalized stochastic FST";
382  error_ = true;
383  return;
384  }
385 
386  const auto etype =
387  (IsSelfEntropy() ? WT::ENTROPY : WT::CROSS_ENTROPY_SOURCE);
388 
389  WT to_e(etype);
390  WCM wc_mapper(to_e);
391  f::ArcMap(ifst, ofst, wc_mapper);
392  fst_oov_count_ = 0;
393 
394  if (unknown_label_ == f::kNoLabel || IsSelfEntropy())
395  return;
396 
397  for (StateId s = 0; s < ofst->NumStates(); ++s) {
398  for (f::MutableArcIterator<f::MutableFst<EArc> > aiter(ofst, s);
399  !aiter.Done();
400  aiter.Next()) {
401  EArc arc = aiter.Value();
402  if (arc.ilabel != 0) {
403  if (arc.ilabel == unknown_label_ || vocab_.count(arc.ilabel) == 0) {
404  arc.ilabel = arc.olabel = unknown_label_;
405  aiter.SetValue(arc);
406  ++fst_oov_count_;
407  }
408  }
409  }
410  }
411 }
412 
413 template <class Arc>
414 bool Perplexity<Arc>::Apply(const fst::Fst<Arc> &fst) {
415  namespace f = fst;
416  if (error_)
417  return false;
418 
419  f::VectorFst<EArc> pfst, plogq_fst;
420  if (IsSelfEntropy()) {
421  PrepareSource(fst, &plogq_fst);
422  } else {
423  PrepareSource(fst, &pfst);
424  Intersect(pfst, qfst_, &plogq_fst, phi_label_, true);
425  }
426 
427  sent_count_ += 1;
428  std::vector<EWeight> distance;
429  EWeight entropy = ShortestDistance<EArc, EArc, WEq>(
430  plogq_fst, &distance, phi_label_, false, entropy_delta_);
431 
432  // To address a corner case where higher-indexed states are effectively
433  // unreachable (very high cost). This ensures that all states will have
434  // distance populated.
435  distance.resize(plogq_fst.NumStates(), EWeight::Zero());
436 
437  if (entropy.Member() && entropy != EWeight::Zero()) {
438  entropy_.Add(entropy);
439  oov_count_ += fst_oov_count_;
440  // This removes the incoming failure mass to a state, which we
441  // do not want to count here. This is equivalent to having
442  // done the shortest distance computation of the (explicitly)
443  // phi-removed automaton.
444  DiffStateWeights(plogq_fst, &distance, phi_label_, true);
445  for (auto w : distance)
446  state_count_.Add(w.Value1());
447  }
448  return true;
449 }
450 
451 } // namespace sfst
452 
453 #endif // NLP_GRM2_SFST_PERPLEXITY_H_
fst::Entropy64Weight EWeight
Definition: perplexity.h:189
SignedLog64Weight SLWeight
Definition: perplexity.h:146
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
Definition: sfst.h:38
fst::Real64Weight RWeight
Definition: perplexity.h:186
size_t NumSources() const
Definition: perplexity.h:249
Entropy64WeightApproxEqual(float delta)
Definition: perplexity.h:150
double GetPerplexity() const
Definition: perplexity.h:241
void DiffStateWeights(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *weights, typename Arc::Label phi_label, bool fail_arc)
Definition: perplexity.h:32
Perplexity(Label phi_label=fst::kNoLabel, Label unknown_label=fst::kNoLabel, float delta=fst::kDelta, float entropy_delta=kEntropyDelta)
Definition: perplexity.h:213
fst::WeightConvertMapper< Arc, EArc, WT > WCM
Definition: perplexity.h:192
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
Definition: perplexity.h:129
bool Intersect(const fst::Fst< Arc > &ifst1, const fst::Fst< Arc > &ifst2, fst::MutableFst< Arc > *ofst, typename Arc::Label phi_label=fst::kNoLabel, bool trim=true, TrimType trim_type=TRIM_NEEDED_FINAL)
Definition: intersect.h:33
Perplexity(const fst::Fst< Arc > &fst, Label phi_label=fst::kNoLabel, Label unknown_label=fst::kNoLabel, float delta=fst::kDelta, float entropy_delta=kEntropyDelta)
Definition: perplexity.h:196
Definition: sfstinfo.cc:39
size_t NumOOVs() const
Definition: perplexity.h:261
EWeight operator()(const Weight &w) const
Definition: perplexity.h:80
typename Arc::StateId StateId
Definition: perplexity.h:182
fst::SignedLog64Weight SLWeight
Definition: perplexity.h:187
EWeight operator()(const LWeight &w) const
Definition: perplexity.h:122
typename Arc::Label Label
Definition: perplexity.h:184
Real64Weight RWeight
Definition: perplexity.h:68
fst::SignedLog64Arc SLArc
Definition: perplexity.h:185
bool IsNormalized(const fst::Fst< Arc > &fst, typename Arc::Label phi_label=fst::kNoLabel, float delta=fst::kDelta)
Definition: normalize.h:159
double GetEntropy() const
Definition: perplexity.h:234
SignedLog64Weight SLWeight
Definition: perplexity.h:67
bool operator()(const EWeight &w1, const EWeight &w2) const
Definition: perplexity.h:154
ExpectationWeight< SignedLog64Weight, Real64Weight > Entropy64Weight
Definition: perplexity.h:37
constexpr float kEntropyDelta
Definition: perplexity.h:170
WeightTransform(EntropyType type=ENTROPY)
Definition: perplexity.h:78
LWeight operator()(const EWeight &w) const
Definition: perplexity.h:109
bool IsNegative(Entropy64Weight w)
Definition: perplexity.h:137
fst::WeightConvertMapper< EArc, fst::Log64Arc > WCM1
Definition: perplexity.h:193
Entropy64Weight EWeight
Definition: perplexity.h:69
double SkipCount() const
Definition: perplexity.h:255
enum{ENTROPY, CROSS_ENTROPY_SOURCE, CROSS_ENTROPY_TARGET} EntropyType
Definition: perplexity.h:76
typename Arc::Weight Weight
Definition: perplexity.h:183
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
Definition: perplexity.h:41
fst::ExpectationArc< SLArc, RWeight > EArc
Definition: perplexity.h:188