GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
trim.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Classes to identify and remove useless states and transitions in
15 // a stochastic FST.
16 
17 #ifndef NLP_GRM2_SFST_TRIM_H_
18 #define NLP_GRM2_SFST_TRIM_H_
19 
#include <sys/types.h>

#include <cstdint>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <fst/log.h>
#include <fst/arcsort.h>
#include <fst/connect.h>
#include <fst/dfs-visit.h>
#include <fst/float-weight.h>
#include <fst/fst.h>
#include <fst/matcher.h>
#include <fst/properties.h>
#include <fst/queue.h>
#include <fst/util.h>
#include <fst/vector-fst.h>
#include <sfst/backoff.h>
#include <sfst/canonical.h>
#include <sfst/sfst.h>
40 
41 namespace sfst {
42 
43 // Identifies inaccessible transitions due to failure labels.
44 // Time complexity: O(V + E * max_phi_order).
45 // (Actually V -> V log(V) due to queue 'optimization' for the
46 // expected cases). Assumes (but does not fully check) that the input
47 // has the canonical topology (see canonical.h).
48 template <class Arc>
49 class PhiAccess {
50  public:
51  using StateId = typename Arc::StateId;
52  using Weight = typename Arc::Weight;
53  using Label = typename Arc::Label;
54 
55  // Status of states in visitation.
56  using Status = enum {
57  UNDISCOVERED, // not yet seen; not enqueued
58  PHI_DISCOVERED, // seen by a phi-transition; enqueued
59  DISCOVERED, // seen by a non-phi transition; enqueued
60  PHI_VISITED, // expanded with phi discovery; dequeued
61  PHI_REDISCOVERED, // expanded with phi discovery; re-enqueued
62  VISITED // expanded with discovery; dequeued
63  };
64 
65  PhiAccess(const fst::Fst<Arc> &fst, typename Arc::Label phi_label)
66  : fst_(fst),
67  phi_label_(phi_label),
68  nstates_(0),
69  error_(false),
70  s_(fst::kNoStateId) {
71  namespace f = fst;
72  if (fst_.Start() == f::kNoStateId) return;
73 
74  ForbiddenLabels forbidden_labels;
75  FindForbiddenLabels(&forbidden_labels);
76  FindForbiddenPositions(forbidden_labels);
77  }
78 
79  StateId NumStates() const { return nstates_; }
80 
81  // Returns true if there are 'forbidden transitions' at some states.
82  bool HasForbidden() const { return !forbidden_positions_.empty(); }
83 
84  // Returns true if state is accessible on an allowed path.
85  bool Accessible(StateId s) const { return status_[s] != UNDISCOVERED; }
86 
87  // These are used to iterate over the 'forbidden positions' at a state.
88 
89  // Sets current state for identifying 'forbidden positions'.
90  void SetState(StateId s) {
91  s_ = s;
92  forbidden_positions_iter_ = forbidden_positions_.find(s);
93  }
94 
95  // Returns an arc position that is not accessible from the
96  // current state: a 'forbidden position'. Arc position 0
97  // is the first arc from the state, etc. Arc position -1
98  // denotes the superfinal transition.
99  ssize_t ForbiddenPosition() const {
100  return forbidden_positions_iter_->second;
101  }
102 
103  // Next forbidden position at the current state.
104  void Next() { ++forbidden_positions_iter_; }
105 
106  // Any more forbidden positions?
107  bool Done() const {
108  return forbidden_positions_iter_ == forbidden_positions_.end() ||
109  forbidden_positions_iter_->first != s_;
110  }
111 
112  // Error seen with this input.
113  bool Error() const { return error_; }
114 
115  private:
116  // Comparison function for phi-access queue.
117  class Compare {
118  public:
119  Compare(const std::vector<Status> &status,
120  const std::vector<StateId> &top_order)
121  : status_(status), top_order_(top_order.size(), fst::kNoStateId) {
122  // top_order_ gives the topological position of state id s.
123  for (StateId s = 0; s < top_order.size(); ++s) {
124  top_order_[top_order[s]] = s;
125  }
126  }
127 
128  // Discovered is first, then phi-discovered/phi-visited by top order.
129  bool operator()(const StateId s1, const StateId s2) const {
130  if (status_[s1] == DISCOVERED && status_[s2] == DISCOVERED) {
131  return s1 < s2;
132  } else if (status_[s1] == DISCOVERED) {
133  return s1;
134  } else if (status_[s2] == DISCOVERED) {
135  return s2;
136  } else {
137  return top_order_[s1] < top_order_[s2];
138  }
139  }
140 
141  private:
142  const std::vector<Status> &status_;
143  std::vector<StateId> top_order_;
144  };
145 
146  // Queue for visitation
147  using PhiAccessQueue = fst::ShortestFirstQueue<StateId, Compare>;
148 
149  // Maps from StateId to set of labels that are phi-inaccessible.
150  using ForbiddenLabels = std::vector<std::unordered_set<Label>>;
151 
152  // Maps from state id to list of newly-allowed labels.
153  using AllowedLabels = std::unordered_multimap<StateId, Label>;
154 
155  // Maps from state id to list of forbidden positions;
156  using ForbiddenPositions = std::multimap<StateId, ssize_t>;
157 
158  // Explicit matcher (no epsilon loops)
159  using Matr = fst::ExplicitMatcher<fst::Matcher<fst::Fst<Arc>>>;
160 
161  // Finds the forbidden label set.
162  void FindForbiddenLabels(ForbiddenLabels *forbidden_labels);
163 
164  // Converts forbidden labels to forbidden positions.
165  void FindForbiddenPositions(const ForbiddenLabels &forbidden_labels);
166 
167  // Processes a phi-discovered state.
168  void PhiDiscover(StateId s, StateId source, PhiAccessQueue *queue,
169  ForbiddenLabels *forbidden_labels,
170  AllowedLabels *allowed_labels);
171 
172  // Processes a newly discovered state.
173  void Discover(StateId s, PhiAccessQueue *queue,
174  ForbiddenLabels *forbidden_labels);
175 
176  // Sets forbidden labels to (phi) source's labels and forbidden labels.
177  void SetForbidden(StateId s, StateId source,
178  ForbiddenLabels *forbidden_labels) const;
179 
180  // Deletes forbidden labels at s that are not forbidden at (phi) source
181  // or by source. If alabels non-null, adds them to allowed labels.
182  // Returns true if newly allowed labels at this state.
183  bool UpdateForbidden(StateId s, StateId source,
184  ForbiddenLabels *forbidden_labels,
185  AllowedLabels *allowed_labels) const;
186 
187  void SetError() { error_ = true; }
188 
189  const fst::Fst<Arc> &fst_;
190  Label phi_label_;
191  ssize_t nstates_;
192  bool error_;
193 
194  // Status of visitation at a state.
195  std::vector<Status> status_;
196 
197  // Indicates arc positions at a state that are forbidden.
198  ForbiddenPositions forbidden_positions_;
199 
200  // Iteration state
201  StateId s_;
202  typename ForbiddenPositions::iterator forbidden_positions_iter_;
203 
204  PhiAccess(const PhiAccess &) = delete;
205  PhiAccess &operator=(const PhiAccess &) = delete;
206 };
207 
208 template <class Arc>
209 void PhiAccess<Arc>::FindForbiddenLabels(ForbiddenLabels *forbidden_labels) {
210  namespace f = fst;
211 
212  std::vector<StateId> top_order;
213  bool acyclic = PhiTopOrder(fst_, phi_label_, &top_order);
214  if (!acyclic) {
215  FSTERROR() << "PhiAccess: FST is not canonical (phi-cyclic)";
216  SetError();
217  return;
218  }
219 
220  nstates_ = top_order.size();
221  forbidden_labels->resize(nstates_);
222 
223  if (phi_label_ == f::kNoLabel) return;
224 
225  AllowedLabels allowed_labels;
226  status_.resize(nstates_, UNDISCOVERED);
227  Compare comp(status_, top_order);
228  PhiAccessQueue queue(comp);
229  StateId initial = fst_.Start();
230  status_[initial] = DISCOVERED;
231  queue.Enqueue(initial);
232 
233  // Visitation that favors (non-phi) discovered states first, then
234  // phi-discovered states by phi top-order.
235  while (!queue.Empty()) {
236  StateId s = queue.Head();
237  queue.Dequeue();
238 
239  if (status_[s] == DISCOVERED || status_[s] == PHI_DISCOVERED) {
240  // Iterates over all arcs if newly (phi-)discovered.
241  for (f::ArcIterator<f::Fst<Arc>> aiter(fst_, s); !aiter.Done();
242  aiter.Next()) {
243  const Arc &arc = aiter.Value();
244  if (arc.ilabel == phi_label_) {
245  PhiDiscover(arc.nextstate, s, &queue, forbidden_labels,
246  &allowed_labels);
247  } else if (status_[s] == DISCOVERED ||
248  (*forbidden_labels)[s].count(arc.ilabel) == 0) {
249  Discover(arc.nextstate, &queue, forbidden_labels);
250  }
251  }
252  status_[s] = (status_[s] == DISCOVERED) ? VISITED : PHI_VISITED;
253  } else { // PHI_REDISCOVERED
254  // Iterates only over newly allowed arcs if already phi-visited.
255  Matr matcher(fst_, f::MATCH_INPUT);
256  matcher.SetState(s);
257  auto liter = allowed_labels.find(s);
258  while (liter != allowed_labels.end() && liter->first == s) {
259  if (matcher.Find(liter->second)) {
260  for (; !matcher.Done(); matcher.Next()) {
261  const Arc &arc = matcher.Value();
262  Discover(arc.nextstate, &queue, forbidden_labels);
263  }
264  allowed_labels.erase(liter++);
265  } else {
266  ++liter;
267  }
268  }
269 
270  // Updates along the failure path.
271  if (matcher.Find(phi_label_)) {
272  const Arc &arc = matcher.Value();
273  PhiDiscover(arc.nextstate, s, &queue, forbidden_labels,
274  &allowed_labels);
275  }
276  status_[s] = PHI_VISITED;
277  }
278  }
279 }
280 
281 template <class Arc>
283  const ForbiddenLabels &forbidden_labels) {
284  namespace f = fst;
285  f::SortedMatcher<f::Fst<Arc>> matcher(fst_, f::MATCH_INPUT);
286  for (StateId s = 0; s < nstates_; ++s) {
287  if (forbidden_labels[s].empty()) continue;
288  matcher.SetState(s);
289  for (auto liter = forbidden_labels[s].begin();
290  liter != forbidden_labels[s].end(); ++liter) {
291  if (*liter != f::kNoLabel && matcher.Find(*liter)) {
292  // forbidden label
293  for (; !matcher.Done(); matcher.Next()) {
294  ssize_t pos = matcher.Position();
295  forbidden_positions_.insert(std::make_pair(s, pos));
296  }
297  } else if (*liter == f::kNoLabel && fst_.Final(s) != Weight::Zero()) {
298  // forbidden superfinal label
299  forbidden_positions_.insert(std::make_pair(s, -1));
300  }
301  }
302  }
303 }
304 
305 template <class Arc>
307  PhiAccessQueue *queue,
308  ForbiddenLabels *forbidden_labels,
309  AllowedLabels *allowed_labels) {
310  switch (status_[s]) {
311  case UNDISCOVERED:
312  SetForbidden(s, source, forbidden_labels);
313  status_[s] = PHI_DISCOVERED;
314  queue->Enqueue(s);
315  break;
316  case PHI_DISCOVERED:
317  UpdateForbidden(s, source, forbidden_labels, nullptr);
318  break;
319  case PHI_VISITED:
320  if (UpdateForbidden(s, source, forbidden_labels, allowed_labels)) {
321  status_[s] = PHI_REDISCOVERED;
322  queue->Enqueue(s);
323  }
324  break;
325  case PHI_REDISCOVERED:
326  UpdateForbidden(s, source, forbidden_labels, allowed_labels);
327  break;
328  default:
329  break;
330  }
331 }
332 
333 template <class Arc>
334 void PhiAccess<Arc>::Discover(StateId s, PhiAccessQueue *queue,
335  ForbiddenLabels *forbidden_labels) {
336  switch (status_[s]) {
337  case UNDISCOVERED:
338  status_[s] = DISCOVERED;
339  queue->Enqueue(s);
340  break;
341  case PHI_DISCOVERED:
342  case PHI_VISITED:
343  case PHI_REDISCOVERED:
344  status_[s] = DISCOVERED;
345  queue->Update(s);
346  (*forbidden_labels)[s].clear();
347  break;
348  default:
349  break;
350  }
351 }
352 
353 template <class Arc>
355  ForbiddenLabels *forbidden_labels) const {
356  namespace f = fst;
357  for (f::ArcIterator<f::Fst<Arc>> aiter(fst_, source); !aiter.Done();
358  aiter.Next()) {
359  const Arc &arc = aiter.Value();
360  if (arc.ilabel != phi_label_ && arc.ilabel != 0)
361  (*forbidden_labels)[s].insert(arc.ilabel);
362  }
363  if (fst_.Final(source) != Weight::Zero())
364  (*forbidden_labels)[s].insert(f::kNoLabel);
365 
366  (*forbidden_labels)[s].insert((*forbidden_labels)[source].begin(),
367  (*forbidden_labels)[source].end());
368 }
369 
370 template <class Arc>
372  ForbiddenLabels *forbidden_labels,
373  AllowedLabels *allowed_labels) const {
374  namespace f = fst;
375  Matr matcher(fst_, f::MATCH_INPUT);
376  bool newly_allowed = false;
377  matcher.SetState(source);
378  auto liter = (*forbidden_labels)[s].begin();
379  while (liter != (*forbidden_labels)[s].end()) {
380  if (*liter != f::kNoLabel && matcher.Find(*liter)) {
381  ++liter;
382  } else if (*liter == f::kNoLabel && fst_.Final(source) != Weight::Zero()) {
383  ++liter;
384  } else if ((*forbidden_labels)[source].find(*liter) !=
385  (*forbidden_labels)[source].end()) {
386  ++liter;
387  } else {
388  // no longer forbidden
389  if (*liter != f::kNoLabel) {
390  if (allowed_labels != nullptr)
391  allowed_labels->insert(std::make_pair(s, *liter));
392  }
393  newly_allowed = true;
394  (*forbidden_labels)[s].erase(liter++);
395  }
396  }
397  return newly_allowed;
398 }
399 
400 // Tests if input is trim - i.e., no inaccessible or coaccessible
401 // states or transitions due to failure transitions or otherwise
402 // (irrespective of weights). Assumes (but does not fully check) that
403 // the input has the canonical topology (see canonical.h).
404 template <class Arc>
405 bool IsTrim(const fst::Fst<Arc> &fst, typename Arc::Label phi_label) {
406  namespace f = fst;
407  uint64_t props = 0;
408  f::SccVisitor<Arc> scc_visitor(nullptr, nullptr, nullptr, &props);
409  f::DfsVisit(fst, &scc_visitor);
410  if (props & (f::kNotAccessible | f::kNotCoAccessible)) return false;
411  PhiAccess<Arc> phi_access(fst, phi_label);
412  return !phi_access.HasForbidden();
413 }
414 
// When trimming, removing some otherwise useless transitions would change
// the failure semantics. In those cases, the trimming algorithm will:
//
// TRIM_NEEDED_TRIM: go ahead and remove them. IsTrim() will be true,
// but the result may not be equivalent to the input.
//
// Otherwise such transitions are redirected to a unique new state n and:
//
// TRIM_NEEDED_FINAL: n is made final and the weights of the transitions to
// it are set to kApproxZeroWeight. IsTrim() on the result will be true,
// but additional successful paths with near-zero weight may be added.
//
// TRIM_NEEDED_NONFINAL: n is left non-final. The successful paths are
// unchanged, but the result may have one non-coaccessible state, so
// IsTrim() may be false.
enum TrimType {
  TRIM_NEEDED_TRIM = 0,
  TRIM_NEEDED_FINAL = 1,
  TRIM_NEEDED_NONFINAL = 2
};
436 
437 // Class to remove useless states and transitions in stochastic
438 // automata. Assumes (but does not fully check) that the input has the
439 // canonical topology (see canonical.h).
440 template <class Arc>
441 class Trimmer {
442  public:
443  using StateId = typename Arc::StateId;
444  using Weight = typename Arc::Weight;
445  using Label = typename Arc::Label;
446  // Explicit matcher (no epsilon loops)
447  using Matr = fst::ExplicitMatcher<fst::Matcher<fst::Fst<Arc>>>;
448  using PhiMatr = fst::PhiMatcher<fst::Matcher<fst::Fst<Arc>>>;
449 
450  Trimmer(fst::MutableFst<Arc> *fst, typename Arc::Label phi_label,
451  TrimType trim_type = TRIM_NEEDED_FINAL)
452  : fst_(fst),
453  phi_label_(phi_label),
454  dead_state_(fst::kNoStateId),
455  dead_if_not_needed_state_(fst::kNoStateId),
456  needed_state_(fst::kNoStateId),
457  trim_type_(trim_type),
458  chk_access_(false),
459  chk_coaccess_(false) {}
460 
461  // Removes inaccessible transitions due to failure labels
462  // If both PhiTrim() and (Sum)WeightTrim() are called, call PhiTrim() first.
463  void PhiTrim();
464 
465  // Removes ApproxZero() weight transitions where possible
466  // (optionally including phi_labeled ones) and connects.
467  void WeightTrim(bool include_phi, Weight approx_zero = ApproxZeroWeight());
468 
469  // Equivalent to SumBackoff() + WeightTrim() followed by restoring original
470  // weights on untrimmed transitions. Useful since it preserves
471  // backoff-completeness.
472  void SumWeightTrim(bool include_phi, Weight approx_zero = ApproxZeroWeight());
473 
474  // Removes inaccessible and non-coassessible states treating
475  // failure labels as regular labels.
476  void Connect() {
477  chk_access_ = true;
478  chk_coaccess_ = true;
479  }
480 
481  // Finalizes result.
482  void Finalize();
483 
484  private:
485  // Returns the asborbing state that will be deleted, adding one if needed.
486  StateId DeadState() {
487  if (dead_state_ == fst::kNoStateId) {
488  dead_state_ = fst_->AddState();
489  del_states_.push_back(dead_state_);
490  }
491  return dead_state_;
492  }
493 
494  // Returns the asborbing state that will be deleted if not needed,
495  // adding one if needed.
496  StateId DeadIfNotNeededState() {
497  if (dead_if_not_needed_state_ == fst::kNoStateId) {
498  dead_if_not_needed_state_ = fst_->AddState();
499  del_states_.push_back(dead_if_not_needed_state_);
500  }
501  return dead_if_not_needed_state_;
502  }
503 
504  // Returns the asborbing state that will be kept, adding one if needed.
505  StateId NeededState() {
506  if (needed_state_ == fst::kNoStateId) {
507  needed_state_ = fst_->AddState();
508  if (trim_type_ == TRIM_NEEDED_FINAL)
509  fst_->SetFinal(needed_state_, Weight::One());
510  }
511  return needed_state_;
512  }
513 
514  // Returns state ID of (first) failure state from s or kNoStateId
515  // in none.
516  StateId PhiNextState(StateId s) const {
517  namespace f = fst;
518  if (phi_label_ == f::kNoLabel) return f::kNoStateId;
519 
520  Matr matcher(fst_, f::MATCH_INPUT);
521  matcher.SetState(s);
522  return matcher.Find(phi_label_) ? matcher.Value().nextstate : f::kNoStateId;
523  }
524 
525  // Returns weight near Zero()
526  static Weight ApproxZeroWeight() {
527  namespace f = fst;
528  f::WeightConvert<f::Log64Weight, Weight> from_log;
529  return from_log(kApproxZeroWeight);
530  }
531 
532  fst::MutableFst<Arc> *fst_;
533  Label phi_label_;
534  StateId dead_state_; // dead state
535  StateId dead_if_not_needed_state_; // dead-if-not-needed state
536  StateId needed_state_; // needed state
537  TrimType trim_type_;
538  std::vector<StateId> del_states_; // states to delete
539  bool chk_access_; // checks accessibility
540  bool chk_coaccess_; // checks coaccesssibility
541 
542  fst::WeightConvert<Weight, fst::Log64Weight> to_log_;
543 
544  Trimmer(const Trimmer &) = delete;
545  Trimmer &operator=(const Trimmer &) = delete;
546 };
547 
548 template <class Arc>
550  namespace f = fst;
551  PhiAccess<Arc> phi_access(*fst_, phi_label_);
552  if (phi_access.Error()) {
553  fst_->SetProperties(f::kError, f::kError);
554  return;
555  }
556 
557  // Points forbidden arcs (including superfinal) to the dead state.
558  for (StateId s = 0; s < phi_access.NumStates(); ++s) {
559  // Marks an inaccessible state to be deleted.
560  if (!phi_access.Accessible(s)) {
561  del_states_.push_back(s);
562  continue;
563  }
564 
565  // Points forbidden arcs (including superfinal) to the dead state.
566  f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
567  phi_access.SetState(s);
568  for (; !phi_access.Done(); phi_access.Next()) {
569  ssize_t pos = phi_access.ForbiddenPosition();
570  if (pos != -1) {
571  aiter.Seek(pos);
572  Arc arc = aiter.Value();
573  arc.nextstate = DeadState();
574  arc.ilabel = f::kNoLabel;
575  arc.olabel = f::kNoLabel;
576  aiter.SetValue(arc);
577  } else {
578  fst_->SetFinal(s, Weight::Zero());
579  }
580  }
581  }
582  chk_coaccess_ = true;
583  if (!del_states_.empty()) {
584  if (dead_if_not_needed_state_ != f::kNoStateId) {
585  FSTERROR() << "Trimmer::PhiTrim() called after (Sum)WeightTrim()";
586  fst_->SetProperties(f::kError, f::kError);
587  }
588  fst_->DeleteStates(del_states_);
589  f::ArcSort(fst_, f::ILabelCompare<Arc>());
590  del_states_.clear();
591  dead_state_ = f::kNoLabel;
592  }
593 }
594 
595 template <class Arc>
596 void Trimmer<Arc>::WeightTrim(bool include_phi, Weight approx_zero) {
597  namespace f = fst;
598  const f::Log64Weight log_approx_zero = to_log_(approx_zero);
599  const StateId nstates = fst_->NumStates(); // excludes any added states
600  for (StateId s = 0; s < nstates; ++s) {
601  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
602  !aiter.Done(); aiter.Next()) {
603  Arc arc = aiter.Value();
604  if ((include_phi || arc.ilabel != phi_label_) &&
605  arc.nextstate != dead_state_) {
606  f::Log64Weight arc_weight = to_log_(arc.weight);
607  if (ApproxZero(arc_weight, log_approx_zero)) {
608  // redirects to dead-if-not-needed state
609  arc.nextstate = DeadIfNotNeededState();
610  aiter.SetValue(arc);
611  }
612  }
613  }
614  }
615  chk_access_ = true;
616  chk_coaccess_ = true;
617 }
618 
619 template <class Arc>
620 void Trimmer<Arc>::SumWeightTrim(bool include_phi, Weight approx_zero) {
621  namespace f = fst;
622  const f::Log64Weight log_approx_zero = to_log_(approx_zero);
623  f::VectorFst<Arc> sum_fst(*fst_);
624  SumBackoff(&sum_fst, phi_label_);
625  const StateId nstates = fst_->NumStates(); // excludes any added states
626  for (StateId s = 0; s < nstates; ++s) {
627  f::ArcIterator<f::Fst<Arc>> sum_aiter(sum_fst, s);
628  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
629  !aiter.Done(); aiter.Next(), sum_aiter.Next()) {
630  Arc arc = aiter.Value();
631  if ((include_phi || arc.ilabel != phi_label_) &&
632  arc.nextstate != dead_state_) {
633  const Arc &sum_arc = sum_aiter.Value();
634  f::Log64Weight sum_arc_weight = to_log_(sum_arc.weight);
635  if (ApproxZero(sum_arc_weight, log_approx_zero)) {
636  // redirects to dead-if-not-needed state
637  arc.nextstate = DeadIfNotNeededState();
638  aiter.SetValue(arc);
639  }
640  }
641  }
642  }
643  chk_access_ = true;
644  chk_coaccess_ = true;
645 }
646 
647 template <class Arc>
649  namespace f = fst;
650 
651  std::vector<bool> access;
652  std::vector<bool> coaccess;
653  uint64_t props = 0;
654  if (chk_access_ || chk_coaccess_) {
655  f::SccVisitor<Arc> scc_visitor(nullptr, &access, &coaccess, &props);
656  f::DfsVisit(*fst_, &scc_visitor);
657  PhiMatr phi_matcher(fst_, f::MATCH_INPUT, phi_label_);
658 
659  const ssize_t nstates = fst_->NumStates();
660  for (StateId s = 0; s < nstates; ++s) {
661  // Marks an inaccessible state for deletion.
662  if (chk_access_ && !access[s]) del_states_.push_back(s);
663 
664  // Marks a non-coaccessible, unneeded state for deletion.
665  if (chk_coaccess_ && s != needed_state_ && !coaccess[s])
666  del_states_.push_back(s);
667 
668  if (trim_type_ == TRIM_NEEDED_TRIM) continue;
669 
670  // Points arcs needed to preserve the failure semantics to the
671  // needed state. Gives them kApproxZeroWeight.
672  StateId phi_nextstate = PhiNextState(s);
673  if (chk_coaccess_ && phi_nextstate != f::kNoStateId) {
674  phi_matcher.SetState(phi_nextstate);
675  f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
676  for (; !aiter.Done(); aiter.Next()) {
677  Arc arc = aiter.Value();
678  if (arc.nextstate == dead_state_ || coaccess[arc.nextstate] ||
679  arc.ilabel == phi_label_ || arc.ilabel == 0) {
680  continue;
681  }
682  for (phi_matcher.Find(arc.ilabel); !phi_matcher.Done();
683  phi_matcher.Next()) {
684  const Arc &failed_arc = phi_matcher.Value();
685  if (coaccess[failed_arc.nextstate]) {
686  // redirects to the needed state
687  arc.nextstate = NeededState();
688  arc.weight = ApproxZeroWeight();
689  aiter.SetValue(arc);
690  break;
691  }
692  }
693  }
694  }
695  }
696  }
697 
698  if (!del_states_.empty()) fst_->DeleteStates(del_states_);
699  if (!del_states_.empty() || needed_state_ != f::kNoStateId)
700  f::ArcSort(fst_, f::ILabelCompare<Arc>());
701 }
702 
703 // Simple interface to trimming. Removes useless states and
704 // transitions in stochastic automata (irrespective of weights).
705 // Returns true on success.
706 template <class Arc>
707 bool Trim(fst::MutableFst<Arc> *fst,
708  typename Arc::Label phi_label = fst::kNoLabel,
709  TrimType trim_type = TRIM_NEEDED_FINAL) {
710  if (!IsCanonical(*fst, phi_label)) {
711  LOG(ERROR) << "Trim: FST is not canonical";
712  return false;
713  }
714  Trimmer<Arc> trim(fst, phi_label, trim_type);
715  trim.PhiTrim();
716  trim.Finalize();
717  return !fst->Properties(fst::kError, false);
718 }
719 
720 } // namespace sfst
721 
722 #endif // NLP_GRM2_SFST_TRIM_H_
Trimmer(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label, TrimType trim_type=TRIM_NEEDED_FINAL)
Definition: trim.h:450
ssize_t ForbiddenPosition() const
Definition: trim.h:99
bool Accessible(StateId s) const
Definition: trim.h:85
void Connect()
Definition: trim.h:476
typename Arc::Label Label
Definition: trim.h:445
void PhiTrim()
Definition: trim.h:549
Definition: perplexity.h:32
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Definition: canonical.h:37
typename Arc::Weight Weight
Definition: trim.h:52
StateId NumStates() const
Definition: trim.h:79
Definition: sfstinfo.cc:39
fst::PhiMatcher< fst::Matcher< fst::Fst< Arc >>> PhiMatr
Definition: trim.h:448
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
Definition: sfst.h:84
bool SumBackoff(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label)
Definition: backoff.h:250
typename Arc::StateId StateId
Definition: trim.h:51
void Finalize()
Definition: trim.h:648
bool IsCanonical(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Definition: canonical.h:71
typename Arc::Weight Weight
Definition: trim.h:444
bool Trim(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel, TrimType trim_type=TRIM_NEEDED_FINAL)
Definition: trim.h:707
PhiAccess(const fst::Fst< Arc > &fst, typename Arc::Label phi_label)
Definition: trim.h:65
void WeightTrim(bool include_phi, Weight approx_zero=ApproxZeroWeight())
Definition: trim.h:596
void Next()
Definition: trim.h:104
typename Arc::Label Label
Definition: trim.h:53
bool Error() const
Definition: trim.h:113
void SetState(StateId s)
Definition: trim.h:90
const fst::Log64Weight kApproxZeroWeight
Definition: sfst.h:34
void SumWeightTrim(bool include_phi, Weight approx_zero=ApproxZeroWeight())
Definition: trim.h:620
bool Done() const
Definition: trim.h:107
TrimType
Definition: trim.h:431
bool IsTrim(const fst::Fst< Arc > &fst, typename Arc::Label phi_label)
Definition: trim.h:405
fst::ExplicitMatcher< fst::Matcher< fst::Fst< Arc >>> Matr
Definition: trim.h:447
typename Arc::StateId StateId
Definition: trim.h:443
enum{UNDISCOVERED, PHI_DISCOVERED, DISCOVERED, PHI_VISITED, PHI_REDISCOVERED, VISITED} Status
Definition: trim.h:63
bool HasForbidden() const
Definition: trim.h:82