17 #ifndef NLP_GRM2_SFST_TRIM_H_ 18 #define NLP_GRM2_SFST_TRIM_H_ 20 #include <sys/types.h> 24 #include <unordered_map> 29 #include <fst/arcsort.h> 30 #include <fst/dfs-visit.h> 31 #include <fst/float-weight.h> 33 #include <fst/matcher.h> 34 #include <fst/properties.h> 35 #include <fst/queue.h> 53 using Label =
typename Arc::Label;
67 phi_label_(phi_label),
72 if (fst_.Start() == f::kNoStateId)
return;
74 ForbiddenLabels forbidden_labels;
75 FindForbiddenLabels(&forbidden_labels);
76 FindForbiddenPositions(forbidden_labels);
82 bool HasForbidden()
const {
return !forbidden_positions_.empty(); }
92 forbidden_positions_iter_ = forbidden_positions_.find(s);
100 return forbidden_positions_iter_->second;
104 void Next() { ++forbidden_positions_iter_; }
108 return forbidden_positions_iter_ == forbidden_positions_.end() ||
109 forbidden_positions_iter_->first != s_;
113 bool Error()
const {
return error_; }
119 Compare(
const std::vector<Status> &status,
120 const std::vector<StateId> &top_order)
121 : status_(status), top_order_(top_order.size(), fst::kNoStateId) {
123 for (
StateId s = 0; s < top_order.size(); ++s) {
124 top_order_[top_order[s]] = s;
130 if (status_[s1] == DISCOVERED && status_[s2] == DISCOVERED) {
132 }
else if (status_[s1] == DISCOVERED) {
134 }
else if (status_[s2] == DISCOVERED) {
137 return top_order_[s1] < top_order_[s2];
142 const std::vector<Status> &status_;
143 std::vector<StateId> top_order_;
147 using PhiAccessQueue = fst::ShortestFirstQueue<StateId, Compare>;
150 using ForbiddenLabels = std::vector<std::unordered_set<Label>>;
153 using AllowedLabels = std::unordered_multimap<StateId, Label>;
156 using ForbiddenPositions = std::multimap<StateId, ssize_t>;
159 using Matr = fst::ExplicitMatcher<fst::Matcher<fst::Fst<Arc>>>;
162 void FindForbiddenLabels(ForbiddenLabels *forbidden_labels);
165 void FindForbiddenPositions(
const ForbiddenLabels &forbidden_labels);
168 void PhiDiscover(
StateId s,
StateId source, PhiAccessQueue *queue,
169 ForbiddenLabels *forbidden_labels,
170 AllowedLabels *allowed_labels);
173 void Discover(
StateId s, PhiAccessQueue *queue,
174 ForbiddenLabels *forbidden_labels);
178 ForbiddenLabels *forbidden_labels)
const;
184 ForbiddenLabels *forbidden_labels,
185 AllowedLabels *allowed_labels)
const;
187 void SetError() { error_ =
true; }
189 const fst::Fst<Arc> &fst_;
195 std::vector<Status> status_;
198 ForbiddenPositions forbidden_positions_;
202 typename ForbiddenPositions::iterator forbidden_positions_iter_;
212 std::vector<StateId> top_order;
213 bool acyclic =
PhiTopOrder(fst_, phi_label_, &top_order);
215 FSTERROR() <<
"PhiAccess: FST is not canonical (phi-cyclic)";
220 nstates_ = top_order.size();
221 forbidden_labels->resize(nstates_);
223 if (phi_label_ == f::kNoLabel)
return;
225 AllowedLabels allowed_labels;
226 status_.resize(nstates_, UNDISCOVERED);
227 Compare comp(status_, top_order);
228 PhiAccessQueue queue(comp);
229 StateId initial = fst_.Start();
230 status_[initial] = DISCOVERED;
231 queue.Enqueue(initial);
235 while (!queue.Empty()) {
239 if (status_[s] == DISCOVERED || status_[s] == PHI_DISCOVERED) {
241 for (f::ArcIterator<f::Fst<Arc>> aiter(fst_, s); !aiter.Done();
243 const Arc &arc = aiter.Value();
244 if (arc.ilabel == phi_label_) {
245 PhiDiscover(arc.nextstate, s, &queue, forbidden_labels,
247 }
else if (status_[s] == DISCOVERED ||
248 (*forbidden_labels)[s].count(arc.ilabel) == 0) {
249 Discover(arc.nextstate, &queue, forbidden_labels);
252 status_[s] = (status_[s] == DISCOVERED) ? VISITED : PHI_VISITED;
255 Matr matcher(fst_, f::MATCH_INPUT);
257 auto liter = allowed_labels.find(s);
258 while (liter != allowed_labels.end() && liter->first == s) {
259 if (matcher.Find(liter->second)) {
260 for (; !matcher.Done(); matcher.Next()) {
261 const Arc &arc = matcher.Value();
262 Discover(arc.nextstate, &queue, forbidden_labels);
264 allowed_labels.erase(liter++);
271 if (matcher.Find(phi_label_)) {
272 const Arc &arc = matcher.Value();
273 PhiDiscover(arc.nextstate, s, &queue, forbidden_labels,
276 status_[s] = PHI_VISITED;
283 const ForbiddenLabels &forbidden_labels) {
285 f::SortedMatcher<f::Fst<Arc>> matcher(fst_, f::MATCH_INPUT);
286 for (
StateId s = 0; s < nstates_; ++s) {
287 if (forbidden_labels[s].empty())
continue;
289 for (
auto liter = forbidden_labels[s].begin();
290 liter != forbidden_labels[s].end(); ++liter) {
291 if (*liter != f::kNoLabel && matcher.Find(*liter)) {
293 for (; !matcher.Done(); matcher.Next()) {
294 ssize_t pos = matcher.Position();
295 forbidden_positions_.insert(std::make_pair(s, pos));
297 }
else if (*liter == f::kNoLabel && fst_.Final(s) != Weight::Zero()) {
299 forbidden_positions_.insert(std::make_pair(s, -1));
307 PhiAccessQueue *queue,
308 ForbiddenLabels *forbidden_labels,
309 AllowedLabels *allowed_labels) {
310 switch (status_[s]) {
312 SetForbidden(s, source, forbidden_labels);
313 status_[s] = PHI_DISCOVERED;
317 UpdateForbidden(s, source, forbidden_labels,
nullptr);
320 if (UpdateForbidden(s, source, forbidden_labels, allowed_labels)) {
321 status_[s] = PHI_REDISCOVERED;
325 case PHI_REDISCOVERED:
326 UpdateForbidden(s, source, forbidden_labels, allowed_labels);
335 ForbiddenLabels *forbidden_labels) {
336 switch (status_[s]) {
338 status_[s] = DISCOVERED;
343 case PHI_REDISCOVERED:
344 status_[s] = DISCOVERED;
346 (*forbidden_labels)[s].clear();
355 ForbiddenLabels *forbidden_labels)
const {
357 for (f::ArcIterator<f::Fst<Arc>> aiter(fst_, source); !aiter.Done();
359 const Arc &arc = aiter.Value();
360 if (arc.ilabel != phi_label_ && arc.ilabel != 0)
361 (*forbidden_labels)[s].insert(arc.ilabel);
363 if (fst_.Final(source) != Weight::Zero())
364 (*forbidden_labels)[s].insert(f::kNoLabel);
366 (*forbidden_labels)[s].insert((*forbidden_labels)[source].begin(),
367 (*forbidden_labels)[source].end());
372 ForbiddenLabels *forbidden_labels,
373 AllowedLabels *allowed_labels)
const {
375 Matr matcher(fst_, f::MATCH_INPUT);
376 bool newly_allowed =
false;
377 matcher.SetState(source);
378 auto liter = (*forbidden_labels)[s].begin();
379 while (liter != (*forbidden_labels)[s].end()) {
380 if (*liter != f::kNoLabel && matcher.Find(*liter)) {
382 }
else if (*liter == f::kNoLabel && fst_.Final(source) != Weight::Zero()) {
384 }
else if ((*forbidden_labels)[source].find(*liter) !=
385 (*forbidden_labels)[source].end()) {
389 if (*liter != f::kNoLabel) {
390 if (allowed_labels !=
nullptr)
391 allowed_labels->insert(std::make_pair(s, *liter));
393 newly_allowed =
true;
394 (*forbidden_labels)[s].erase(liter++);
397 return newly_allowed;
405 bool IsTrim(
const fst::Fst<Arc> &
fst,
typename Arc::Label phi_label) {
408 f::SccVisitor<Arc> scc_visitor(
nullptr,
nullptr,
nullptr, &props);
409 f::DfsVisit(fst, &scc_visitor);
410 if (props & (f::kNotAccessible | f::kNotCoAccessible))
return false;
447 using Matr = fst::ExplicitMatcher<fst::Matcher<fst::Fst<Arc>>>;
448 using PhiMatr = fst::PhiMatcher<fst::Matcher<fst::Fst<Arc>>>;
450 Trimmer(fst::MutableFst<Arc> *
fst,
typename Arc::Label phi_label,
453 phi_label_(phi_label),
454 dead_state_(fst::kNoStateId),
455 dead_if_not_needed_state_(fst::kNoStateId),
456 needed_state_(fst::kNoStateId),
457 trim_type_(trim_type),
459 chk_coaccess_(false) {}
467 void WeightTrim(
bool include_phi,
Weight approx_zero = ApproxZeroWeight());
472 void SumWeightTrim(
bool include_phi,
Weight approx_zero = ApproxZeroWeight());
478 chk_coaccess_ =
true;
487 if (dead_state_ == fst::kNoStateId) {
488 dead_state_ = fst_->AddState();
489 del_states_.push_back(dead_state_);
496 StateId DeadIfNotNeededState() {
497 if (dead_if_not_needed_state_ == fst::kNoStateId) {
498 dead_if_not_needed_state_ = fst_->AddState();
499 del_states_.push_back(dead_if_not_needed_state_);
501 return dead_if_not_needed_state_;
506 if (needed_state_ == fst::kNoStateId) {
507 needed_state_ = fst_->AddState();
509 fst_->SetFinal(needed_state_, Weight::One());
511 return needed_state_;
518 if (phi_label_ == f::kNoLabel)
return f::kNoStateId;
520 Matr matcher(fst_, f::MATCH_INPUT);
522 return matcher.Find(phi_label_) ? matcher.Value().nextstate : f::kNoStateId;
526 static Weight ApproxZeroWeight() {
528 f::WeightConvert<f::Log64Weight, Weight> from_log;
532 fst::MutableFst<Arc> *fst_;
535 StateId dead_if_not_needed_state_;
538 std::vector<StateId> del_states_;
542 fst::WeightConvert<Weight, fst::Log64Weight> to_log_;
552 if (phi_access.
Error()) {
553 fst_->SetProperties(f::kError, f::kError);
561 del_states_.push_back(s);
566 f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
568 for (; !phi_access.
Done(); phi_access.
Next()) {
572 Arc arc = aiter.Value();
573 arc.nextstate = DeadState();
574 arc.ilabel = f::kNoLabel;
575 arc.olabel = f::kNoLabel;
578 fst_->SetFinal(s, Weight::Zero());
582 chk_coaccess_ =
true;
583 if (!del_states_.empty()) {
584 if (dead_if_not_needed_state_ != f::kNoStateId) {
585 FSTERROR() <<
"Trimmer::PhiTrim() called after (Sum)WeightTrim()";
586 fst_->SetProperties(f::kError, f::kError);
588 fst_->DeleteStates(del_states_);
589 f::ArcSort(fst_, f::ILabelCompare<Arc>());
591 dead_state_ = f::kNoLabel;
598 const f::Log64Weight log_approx_zero = to_log_(approx_zero);
599 const StateId nstates = fst_->NumStates();
600 for (
StateId s = 0; s < nstates; ++s) {
601 for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
602 !aiter.Done(); aiter.Next()) {
603 Arc arc = aiter.Value();
604 if ((include_phi || arc.ilabel != phi_label_) &&
605 arc.nextstate != dead_state_) {
606 f::Log64Weight arc_weight = to_log_(arc.weight);
607 if (
ApproxZero(arc_weight, log_approx_zero)) {
609 arc.nextstate = DeadIfNotNeededState();
616 chk_coaccess_ =
true;
622 const f::Log64Weight log_approx_zero = to_log_(approx_zero);
623 f::VectorFst<Arc> sum_fst(*fst_);
625 const StateId nstates = fst_->NumStates();
626 for (
StateId s = 0; s < nstates; ++s) {
627 f::ArcIterator<f::Fst<Arc>> sum_aiter(sum_fst, s);
628 for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
629 !aiter.Done(); aiter.Next(), sum_aiter.Next()) {
630 Arc arc = aiter.Value();
631 if ((include_phi || arc.ilabel != phi_label_) &&
632 arc.nextstate != dead_state_) {
633 const Arc &sum_arc = sum_aiter.Value();
634 f::Log64Weight sum_arc_weight = to_log_(sum_arc.weight);
635 if (
ApproxZero(sum_arc_weight, log_approx_zero)) {
637 arc.nextstate = DeadIfNotNeededState();
644 chk_coaccess_ =
true;
651 std::vector<bool> access;
652 std::vector<bool> coaccess;
654 if (chk_access_ || chk_coaccess_) {
655 f::SccVisitor<Arc> scc_visitor(
nullptr, &access, &coaccess, &props);
656 f::DfsVisit(*fst_, &scc_visitor);
657 PhiMatr phi_matcher(fst_, f::MATCH_INPUT, phi_label_);
659 const ssize_t nstates = fst_->NumStates();
660 for (
StateId s = 0; s < nstates; ++s) {
662 if (chk_access_ && !access[s]) del_states_.push_back(s);
665 if (chk_coaccess_ && s != needed_state_ && !coaccess[s])
666 del_states_.push_back(s);
672 StateId phi_nextstate = PhiNextState(s);
673 if (chk_coaccess_ && phi_nextstate != f::kNoStateId) {
674 phi_matcher.SetState(phi_nextstate);
675 f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst_, s);
676 for (; !aiter.Done(); aiter.Next()) {
677 Arc arc = aiter.Value();
678 if (arc.nextstate == dead_state_ || coaccess[arc.nextstate] ||
679 arc.ilabel == phi_label_ || arc.ilabel == 0) {
682 for (phi_matcher.Find(arc.ilabel); !phi_matcher.Done();
683 phi_matcher.Next()) {
684 const Arc &failed_arc = phi_matcher.Value();
685 if (coaccess[failed_arc.nextstate]) {
687 arc.nextstate = NeededState();
688 arc.weight = ApproxZeroWeight();
698 if (!del_states_.empty()) fst_->DeleteStates(del_states_);
699 if (!del_states_.empty() || needed_state_ != f::kNoStateId)
700 f::ArcSort(fst_, f::ILabelCompare<Arc>());
708 typename Arc::Label phi_label = fst::kNoLabel,
711 LOG(ERROR) <<
"Trim: FST is not canonical";
717 return !fst->Properties(fst::kError,
false);
722 #endif // NLP_GRM2_SFST_TRIM_H_ Trimmer(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label, TrimType trim_type=TRIM_NEEDED_FINAL)
ssize_t ForbiddenPosition() const
bool Accessible(StateId s) const
typename Arc::Label Label
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
typename Arc::Weight Weight
StateId NumStates() const
fst::PhiMatcher< fst::Matcher< fst::Fst< Arc >>> PhiMatr
bool ApproxZero(fst::Log64Weight weight, fst::Log64Weight approx_zero=kApproxZeroWeight)
bool SumBackoff(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label)
typename Arc::StateId StateId
bool IsCanonical(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
typename Arc::Weight Weight
bool Trim(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel, TrimType trim_type=TRIM_NEEDED_FINAL)
PhiAccess(const fst::Fst< Arc > &fst, typename Arc::Label phi_label)
void WeightTrim(bool include_phi, Weight approx_zero=ApproxZeroWeight())
typename Arc::Label Label
const fst::Log64Weight kApproxZeroWeight
void SumWeightTrim(bool include_phi, Weight approx_zero=ApproxZeroWeight())
bool IsTrim(const fst::Fst< Arc > &fst, typename Arc::Label phi_label)
fst::ExplicitMatcher< fst::Matcher< fst::Fst< Arc >>> Matr
typename Arc::StateId StateId
enum{UNDISCOVERED, PHI_DISCOVERED, DISCOVERED, PHI_VISITED, PHI_REDISCOVERED, VISITED} Status
bool HasForbidden() const