GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
count.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Algorithm to count from a stochastic FST w.r.t. a specified FST topology.
15 
16 #ifndef NLP_GRM2_SFST_COUNT_H_
17 #define NLP_GRM2_SFST_COUNT_H_
18 
19 #include <sys/types.h>
20 
21 #include <cstdint>
22 #include <vector>
23 
24 #include <fst/log.h>
25 #include <fst/arc.h>
26 #include <fst/arcsort.h>
27 #include <fst/compose.h>
28 #include <fst/expanded-fst.h>
29 #include <fst/float-weight.h>
30 #include <fst/fst.h>
31 #include <fst/matcher.h>
32 #include <fst/mutable-fst.h>
33 #include <fst/properties.h>
34 #include <fst/signed-log-weight.h>
35 #include <fst/util.h>
36 #include <fst/weight.h>
37 #include <sfst/canonical.h>
38 #include <sfst/phi2matcher.h>
39 #include <sfst/rmphi.h>
40 #include <sfst/sfst.h>
41 #include <sfst/shortest-distance.h>
42 
43 namespace sfst {
44 
45 // Counter class for counting from a stochastic FSA w.r.t a specified FSA
46 // topology. Computes the counts C(x,q) with x \in L[q] as described in
47 // Suresh, Roark, Riley, Schogol, "Approximating probabilistic models
48 // as weighted automata" (experimental/fst/papers/approx/approx.pdf)
49 template <class Arc>
50 class Counter {
51  public:
52  using StateId = typename Arc::StateId;
53  using Label = typename Arc::Label;
54  using Weight = typename Arc::Weight;
55  using SLArc = fst::SignedLog64Arc;
56  using SLStateId = fst::SignedLog64Arc::StateId;
57  using SLWeight = fst::SignedLog64Weight;
58 
59  // The 'phi_label' is the failure label (fst::kNoLabel ->
60  // none). The 'delta' parameter controls the degree
61  // of algorithm convergence. The topology FST is passed in and the
62  // result is returned in 'ofst'. This FST will have counts placed on
63  // it. It should be an epsilon-free (if phi_label != 0),
64  // deterministic, canonical SFSA (see canonical.h); these
65  // requirements are not fully checked. Any incoming arc weights on
66  // the topology FST are removed.
67  Counter(Label phi_label, float delta, fst::MutableFst<Arc> *ofst);
68 
69  // Provides an FST to be counted; may be called repeatedly. Assumes
70  // (but does not fully check) the input is a canonical stochastic
71  // FSA (see canonical.h). If the input is cyclic, the SFSA should
72  // be a normalized (see normalize.h). Also assumes input has no
73  // (non-phi) epsilons (or treats such epsilons w.r.t. the failure
74  // semantics as if they were regular, uniquely-labeled symbols).
75  void Count(const fst::Fst<Arc> &ifst);
76 
77  // Terminates counting and finalizes the FST result.
78  void Finalize() {
79  PhiCounts();
80  FinalizeFst();
81  }
82 
83  private:
84  // Initializes topology FST with markup to associated arc counts.
85  void InitFst();
86 
87  // Adds super-final state and arcs to an FST.
88  StateId AddSuperFinal(fst::MutableFst<Arc> *fst, bool *have_phi);
89 
90  // Composes input with topology FST.
91  void BuildComp(const fst::Fst<Arc> &ifst,
92  fst::MutableFst<SLArc> *ofst) const;
93 
94 
95  // Checks if signed-log input is normalized (no failure transitions).
96  bool CheckNorm(const fst::Fst<SLArc> &fst) const {
97  namespace f = fst;
98  using StateItr = f::StateIterator<f::Fst<SLArc>>;
99  using ArcItr = f::ArcIterator<f::Fst<SLArc>>;
100  for (StateItr siter(fst); !siter.Done(); siter.Next()) {
101  StateId s = siter.Value();
102  f::Adder<SLWeight> adder(fst.Final(s));
103  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
104  const SLArc &arc = aiter.Value();
105  adder.Add(arc.weight);
106  }
107  if (!ApproxEqual(adder.Sum(), SLWeight::One()))
108  return false;
109  }
110  return true;
111  }
112 
113  // Computes the distances from the initial and to the final
114  // state in cfst. If cfst is normalized, 'norm' is true.
115  bool ComputeDistances(bool norm, fst::MutableFst<SLArc> *cfst,
116  std::vector<SLWeight> *distance,
117  std::vector<SLWeight> *rdistance) const;
118 
119  // Counts arcs in provided (signed log) FST
120  void CountArcs(const fst::ExpandedFst<SLArc> &fst,
121  const std::vector<SLWeight> &distance,
122  const std::vector<SLWeight> &rdistance);
123 
124  // Fills in failure counts.
125  void PhiCounts();
126 
127  // Removes markup from result.
128  void FinalizeFst();
129 
130  Label phi_label_; // failure label
131  float delta_; // delta value used by shortest-distance
132  fst::MutableFst<Arc> *ofst_; // output topology FST with counts
133  Label sf_label_; // super-final label
134  std::vector<SLWeight> arc_counts_; // maps from arc IDs to arc counts
135  std::vector<SLWeight> phi_counts_; // maps from state IDs to phi counts
136  StateId superfinal_; // ID of super-final state
137  bool have_iphi_; // phis on input FST
138  bool have_ophi_; // phis on topology FST
139  fst::WeightConvert<SLWeight, Weight> from_sl_; // convert from SLWeight
140 
141  Counter(const Counter &) = delete;
142  Counter &operator=(const Counter &) = delete;
143 };
144 
145 template <class Arc>
146 Counter<Arc>::Counter(Label phi_label, float delta,
147  fst::MutableFst<Arc> *ofst)
148  : phi_label_(phi_label), delta_(delta), ofst_(ofst), sf_label_(-2) {
149  namespace f = fst;
150 
151  if (ofst_->Start() == f::kNoStateId) {
152  FSTERROR() << "Counter: topology FST has no states";
153  ofst_->SetProperties(f::kError, f::kError);
154  return;
155  }
156  // Must be an epsilon-free deterministic canonical SFSA. Check properties
157  // one-at-a-time to get a useful error message.
158  const uint64_t props = ofst_->Properties(
159  f::kAcceptor | f::kIDeterministic | f::kILabelSorted | f::kNoEpsilons,
160  true);
161  if ((props & f::kAcceptor) != f::kAcceptor) {
162  FSTERROR() << "Counter: topology FST not an acceptor";
163  ofst_->SetProperties(f::kError, f::kError);
164  return;
165  }
166  if ((props & f::kIDeterministic) != f::kIDeterministic) {
167  FSTERROR() << "Counter: topology FST not i-deterministic";
168  ofst_->SetProperties(f::kError, f::kError);
169  return;
170  }
171  if ((props & f::kILabelSorted) != f::kILabelSorted) {
172  FSTERROR() << "Counter: topology FST not i-label sorted";
173  ofst_->SetProperties(f::kError, f::kError);
174  return;
175  }
176  if (phi_label != 0 && (props & f::kNoEpsilons) != f::kNoEpsilons) {
177  FSTERROR() << "Counter: topology FST not epsilon-free";
178  ofst_->SetProperties(f::kError, f::kError);
179  return;
180  }
181 
182  InitFst();
183 }
184 
185 template <class Arc>
186 void Counter<Arc>::InitFst() {
187  namespace f = fst;
188  superfinal_ = AddSuperFinal(ofst_, &have_ophi_);
189 
190  for (StateId s = 0; s < ofst_->NumStates(); ++s) {
191  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(ofst_, s);
192  !aiter.Done(); aiter.Next()) {
193  Arc arc = aiter.Value();
194  arc.weight = Weight::One(); // Removes any arc weight
195  if (arc.ilabel != phi_label_) {
196  ssize_t arc_id = arc_counts_.size();
197  // Output label stores the arc ID during the construction.
198  arc.olabel = arc_id;
199 
200  if (arc.olabel != arc_id) {
201  FSTERROR() << "Counter: arc label size needs to be large"
202  << " enough to store arc IDs in construction";
203  ofst_->SetProperties(f::kError, f::kError);
204  return;
205  }
206  arc_counts_.push_back(SLWeight::Zero());
207  }
208  aiter.SetValue(arc);
209  }
210  }
211 }
212 
213 template <class Arc>
214 typename Arc::StateId Counter<Arc>::AddSuperFinal(
215  fst::MutableFst<Arc> *fst, bool *have_phi) {
216  namespace f = fst;
217  *have_phi = false;
218 
219  // Checks for failure and superfinal labels.
220  // Makes all final states point to super-final state
221  StateId sf_state = fst->AddState();
222  fst->SetFinal(sf_state, Weight::One());
223  for (StateId s = 0; s < fst->NumStates() - 1; ++s) {
224  for (f::ArcIterator<f::Fst<Arc>> aiter(*fst, s); !aiter.Done();
225  aiter.Next()) {
226  const Arc &arc = aiter.Value();
227  if (arc.ilabel == sf_label_) {
228  FSTERROR() << "Counter: label " << sf_label_
229  << " reserved for the superfinal label";
230  ofst_->SetProperties(f::kError, f::kError);
231  return f::kNoStateId;
232  }
233  if (arc.ilabel == phi_label_) *have_phi = true;
234  }
235 
236  if (fst->Final(s) != Weight::Zero()) {
237  fst->AddArc(s, Arc(sf_label_, sf_label_, fst->Final(s), sf_state));
238  fst->SetFinal(s, Weight::Zero());
239  }
240  }
241  f::ArcSort(fst, f::ILabelCompare<Arc>());
242  return sf_state;
243 }
244 
245 template <class Arc>
246 void Counter<Arc>::Count(const fst::Fst<Arc> &ifst) {
247  namespace f = fst;
248  if (ifst.Start() == f::kNoStateId) {
249  FSTERROR() << "Counter: input FST has no states";
250  ofst_->SetProperties(f::kError, f::kError);
251  return;
252  }
253  if (!ifst.Properties(f::kAcceptor, true)) {
254  FSTERROR() << "Counter: input FST not an acceptor";
255  ofst_->SetProperties(f::kError, f::kError);
256  return;
257  }
258 
259  f::VectorFst<SLArc> cfst;
260  {
261  // Adds super-final transitions and state to the input FST.
262  f::VectorFst<Arc> sffst(ifst);
263  AddSuperFinal(&sffst, &have_iphi_);
264 
265  // Builds the composition of ifst and ofst and converts
266  // to signed log.
267  BuildComp(sffst, &cfst);
268  }
269 
270  // Checks if composition is normalized.
271  bool cnorm = CheckNorm(cfst);
272 
273  std::vector<SLWeight> distance, rdistance;
274  if (!ComputeDistances(cnorm, &cfst, &distance, &rdistance)) {
275  FSTERROR() << "Counter: shortest-distance computation failed.";
276  ofst_->SetProperties(f::kError, f::kError);
277  return;
278  }
279 
280  // Extracts counts from the composition using the shortest distance
281  CountArcs(cfst, distance, rdistance);
282 }
283 
// Composes 'ifst' (the input with super-final markup) with the topology FST
// ofst_. Writes the result, converted to the signed-log semiring by
// internal::RmPhi, to 'ofst'. When either side has failure arcs, custom
// matchers carrying that side's phi label are installed in the compose
// options.
template <class Arc>
void Counter<Arc>::BuildComp(const fst::Fst<Arc> &ifst,
                             fst::MutableFst<SLArc> *ofst) const {
  namespace f = fst;
  // NOTE(review): the matcher alias 'PM' used below is not defined in this
  // listing (the embedded numbering skips the line that presumably declared
  // it, likely a Phi2Matcher over f::Matcher<f::Fst<Arc>>). Confirm against
  // the original source.
  using PF = Phi2Filter<PM>;
  f::ComposeFstOptions<Arc, PM, PF> copts;
  // NOTE(review): presumably limits the composition's cache. Confirm the
  // intent of a zero GC limit.
  copts.gc_limit = 0;
  if (have_iphi_ || have_ophi_) {
    // The input is matched on its output side and the topology on its input
    // side. A side without failure arcs gets kNoLabel and MATCH_NONE.
    Label iphi_label = have_iphi_ ? phi_label_ : f::kNoLabel;
    Label ophi_label = have_ophi_ ? phi_label_ : f::kNoLabel;
    f::MatchType imatch_type = have_iphi_ ? f::MATCH_OUTPUT : f::MATCH_NONE;
    f::MatchType omatch_type = have_ophi_ ? f::MATCH_INPUT : f::MATCH_NONE;
    // Heap-allocated matchers handed to the compose options. Presumably the
    // ComposeFst below takes ownership (OpenFst convention); confirm.
    copts.matcher1 = new PM(ifst, imatch_type, iphi_label);
    copts.matcher2 = new PM(*ofst_, omatch_type, ophi_label);
  }

  f::ComposeFst<Arc> cfst(ifst, *ofst_, copts);
  // Moves to the signed-log converting phi-labels to epsilons.
  // The phi label is passed to RmPhi only when both FSTs have failure arcs;
  // otherwise kNoLabel is passed.
  const Label rm_phi_label =
      have_iphi_ && have_ophi_ ? phi_label_ : f::kNoLabel;
  internal::RmPhi(cfst, ofst, rm_phi_label, fst::MATCHER_REWRITE_NEVER);
}
307 
// Computes signed-log shortest distances in 'cfst': 'distance' holds the
// distances from the initial state, and 'rdistance' holds the distances to
// the final states. If cfst is normalized ('norm' true), only the forward
// pass is run and 'rdistance' is resized to distance's size, with new
// entries set to One(). Returns false if any shortest-distance computation
// fails.
template <class Arc>
bool Counter<Arc>::ComputeDistances(bool norm, fst::MutableFst<SLArc> *cfst,
                                    std::vector<SLWeight> *distance,
                                    std::vector<SLWeight> *rdistance) const {
  namespace f = fst;
  // NOTE(review): the embedded numbering skips a line here in this listing
  // (presumably a type alias for the shortest-distance class). Confirm
  // against the original source.

  // Computes shortest distance on signed log semiring
  if (norm) {
    // Normalized: we only need the s.d. from the initial states.
    // Presumably every reverse distance is One() because each state's
    // weights sum to One() (see CheckNorm); confirm for cyclic inputs.
    // NOTE(review): the declaration of 'sdist' begins on a line missing from
    // this listing; only its constructor arguments are visible below.
    sdist(cfst, phi_label_, delta_);
    if (!sdist.ComputeDistance(distance, false)) {
      return false;
    }
    rdistance->resize(distance->size(), SLWeight::One());
    return true;
  } else {
    // Not normalized: we need the s.d. in both directions
    // NOTE(review): as above, the 'sdist' declaration line is missing here.
    sdist(cfst, phi_label_, delta_);
    // The second ComputeDistance argument selects the reverse direction.
    return sdist.ComputeDistance(distance, false) &&
           sdist.ComputeDistance(rdistance, true);
  }
}
333 
334 template <class Arc>
335 void Counter<Arc>::CountArcs(const fst::ExpandedFst<SLArc> &fst,
336  const std::vector<SLWeight> &distance,
337  const std::vector<SLWeight> &rdistance) {
338  namespace f = fst;
339  for (StateId s = 0; s < fst.NumStates(); ++s) {
340  if (s >= distance.size()) continue; // distance is Zero()
341  SLWeight dist = distance[s];
342  if (ApproxZero(dist)) continue;
343  for (f::ArcIterator<f::Fst<SLArc>> aiter(fst, s); !aiter.Done();
344  aiter.Next()) {
345  const auto &arc = aiter.Value();
346  if (arc.ilabel && arc.ilabel != phi_label_) {
347  SLWeight count = Times(dist, arc.weight);
348  if (arc.nextstate >= rdistance.size()) continue; // rdistance is Zero()
349  SLWeight rdist = rdistance[arc.nextstate];
350  if (ApproxZero(rdist)) continue;
351  count = Times(count, rdist);
352  // olabel stores arc ID
353  ssize_t arc_id = arc.olabel;
354  arc_counts_[arc_id] = Plus(arc_counts_[arc_id], count);
355  }
356  }
357  }
358 }
359 
360 template <class Arc>
362  namespace f = fst;
363  using Matr = f::ExplicitMatcher<f::Matcher<f::Fst<Arc>>>;
364 
365  StateId initial = ofst_->Start();
366  phi_counts_.resize(ofst_->NumStates(), SLWeight::Zero());
367 
368  if (phi_label_ == f::kNoLabel) return;
369 
370  for (StateId s = 0; s < ofst_->NumStates(); ++s) {
371  for (f::ArcIterator<f::Fst<Arc>> aiter(*ofst_, s); !aiter.Done();
372  aiter.Next()) {
373  const Arc &arc = aiter.Value();
374  if (arc.ilabel == phi_label_) continue;
375  ssize_t arc_id = arc.olabel;
376  phi_counts_[arc.nextstate] =
377  Plus(phi_counts_[arc.nextstate], arc_counts_[arc_id]);
378  phi_counts_[s] = Minus(phi_counts_[s], arc_counts_[arc_id]);
379  }
380  }
381 
382  // Incoming count mass to superfinal state => incoming mass
383  // to initial state.
384  phi_counts_[initial] = Plus(phi_counts_[initial], phi_counts_[superfinal_]);
385 
386  std::vector<StateId> top_order;
387  bool acyclic = PhiTopOrder(*ofst_, phi_label_, &top_order);
388  if (!acyclic) {
389  FSTERROR() << "Counter: topology FST is not canonical (phi-cyclic)";
390  ofst_->SetProperties(f::kError, f::kError);
391  return;
392  }
393 
394  Matr matcher(*ofst_, f::MATCH_INPUT);
395  for (StateId i = 0; i < ofst_->NumStates(); ++i) {
396  StateId s = top_order[i]; // ith state in phi-top order
397  matcher.SetState(s);
398  if (matcher.Find(phi_label_)) {
399  const Arc &arc = matcher.Value();
400  StateId fails = arc.nextstate;
401  phi_counts_[fails] = Plus(phi_counts_[fails], phi_counts_[s]);
402  }
403  }
404 }
405 
406 template <class Arc>
408  namespace f = fst;
409 
410  for (StateId s = 0; s < ofst_->NumStates(); ++s) {
411  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(ofst_, s);
412  !aiter.Done(); aiter.Next()) {
413  Arc arc = aiter.Value();
414  SLWeight arc_count;
415  if (arc.ilabel != phi_label_) {
416  arc_count = arc_counts_[arc.olabel]; // olabel holds arc_id
417  } else {
418  arc_count = phi_counts_[s];
419  }
420 
421  if (Less(arc_count, SLWeight::Zero())) {
422  arc.weight = Weight::Zero(); // ensures sanity
423  } else {
424  arc.weight = from_sl_(arc_count);
425  }
426 
427  if (arc.ilabel != sf_label_) {
428  arc.olabel = arc.ilabel; // projects on to input
429  aiter.SetValue(arc);
430  } else {
431  ofst_->SetFinal(s, arc.weight);
432  }
433  }
434  }
435 
436  // Deletes super-final state and transitions entering it.
437  std::vector<StateId> dstates(1, superfinal_);
438  ofst_->DeleteStates(dstates);
439  f::ArcSort(ofst_, f::ILabelCompare<Arc>());
440 }
441 
442 
443 // Tests that the total weight entering each state
444 // equals the total weight leaving that state. This should
445 // be true, for example, of the result of Counter().
446 template <class Arc>
447 bool IsConservative(const fst::Fst<Arc> &fst,
448  float delta = fst::kDelta) {
449  namespace f = fst;
450  using StateId = typename Arc::StateId;
451  using Weight = typename Arc::Weight;
452  using ArcItr = f::ArcIterator<f::Fst<Arc>>;
453  using Log64Weight = f::Log64Weight;
454 
455  StateId nstates = CountStates(fst);
456 
457  f::WeightConvert<Weight, Log64Weight> to_log64;
458  std::vector<Log64Weight> in_weight(nstates, Log64Weight::Zero());
459  std::vector<Log64Weight> out_weight(nstates, Log64Weight::Zero());
460 
461  for (StateId s = 0; s < nstates; ++s) {
462  for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
463  const Arc &arc = aiter.Value();
464  const Log64Weight weight = to_log64(arc.weight);
465  in_weight[arc.nextstate] = Plus(in_weight[arc.nextstate], weight);
466  out_weight[s] = Plus(out_weight[s], weight);
467  }
468  if (fst.Final(s) != Weight::Zero()) {
469  const Log64Weight weight = to_log64(fst.Final(s));
470  out_weight[s] = Plus(out_weight[s], weight);
471  in_weight[fst.Start()] = Plus(in_weight[fst.Start()], weight);
472  }
473  }
474  for (StateId s = 0; s < nstates; ++s) {
475  if (!ApproxEqual(in_weight[s], out_weight[s], delta)) {
476  VLOG(1) << "state: " << s
477  << " in_weight: " << in_weight[s]
478  << " out_weight: " << out_weight[s];
479  return false;
480  }
481  }
482  return true;
483 }
484 
485 
486 } // namespace sfst
487 
488 #endif // NLP_GRM2_SFST_COUNT_H_
Counter(Label phi_label, float delta, fst::MutableFst< Arc > *ofst)
Definition: count.h:146
void Finalize()
Definition: count.h:78
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
Definition: sfst.h:38
typename Arc::StateId StateId
Definition: count.h:52
typename Arc::Weight Weight
Definition: count.h:54
Definition: perplexity.h:32
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Definition: canonical.h:37
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
Definition: perplexity.h:129
Definition: sfstinfo.cc:39
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
Definition: sfst.h:84
bool IsConservative(const fst::Fst< Arc > &fst, float delta=fst::kDelta)
Definition: count.h:447
typename Arc::Label Label
Definition: count.h:53
fst::SignedLog64Weight SLWeight
Definition: count.h:57
void RmPhi(const fst::Fst< IArc > &ifst, fst::MutableFst< OArc > *ofst, typename IArc::Label phi_label=fst::kNoLabel, fst::MatcherRewriteMode rewrite_mode=fst::MATCHER_REWRITE_AUTO, const WC &weight_convert=WC())
Definition: rmphi.h:234
fst::SignedLog64Arc::StateId SLStateId
Definition: count.h:56
bool ComputeDistance(std::vector< Weight > *distance, bool reverse=false)
fst::SignedLog64Arc SLArc
Definition: count.h:55
void Count(const fst::Fst< Arc > &ifst)
Definition: count.h:246
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
Definition: perplexity.h:41