22 #include <unordered_map> 29 #include <fst/signed-log-weight.h> 30 #include <fst/vector-fst.h> 31 #include <fst/weight.h> 40 public fst::QueueBase<fst::SignedLog64Arc::StateId> {
42 using Arc = fst::SignedLog64Arc;
46 using MMap = std::unordered_multimap<StateId, StateId>;
55 const std::vector<Weight> &distance,
56 const std::vector<StateId> &astates,
57 size_t astart,
bool reverse =
false)
58 : QueueBase<
StateId>(fst::OTHER_QUEUE),
60 astart_(reverse ? (astart + 1) : astart),
63 if (fst.Properties(f::kAcyclic,
true)) {
64 f::AnyArcFilter<Arc> arc_filter;
65 lev1_queue_.reset(
new f::TopOrderQueue<StateId>(fst, arc_filter));
67 lev1_queue_.reset(
new f::FifoQueue<StateId>());
72 if (lev2_queue_.Empty())
74 return lev2_queue_.Head();
80 lev1_queue_->Enqueue(s);
82 StateId as = astates_[s - astart_];
84 map_.insert(std::make_pair(as, s));
90 if (lev2_queue_.Empty())
92 lev2_queue_.Dequeue();
96 if (s < astart_) lev1_queue_->Update(s);
100 return lev1_queue_->Empty() && lev2_queue_.Empty();
104 lev1_queue_->Clear();
113 void FillLev2Queue()
const {
114 StateId as = lev1_queue_->Head();
115 auto iter = map_.find(as);
116 while (iter != map_.end() && iter->first == as) {
121 lev2_queue_.Enqueue(s);
125 lev1_queue_->Dequeue();
126 lev2_queue_.Enqueue(as);
129 const std::vector<StateId> &astates_;
135 std::unique_ptr<fst::QueueBase<StateId>> lev1_queue_;
137 mutable fst::FifoQueue<StateId> lev2_queue_;
146 size_t operator()(
const std::pair<StateId, StateId> &p)
const {
147 static constexpr
auto prime = 7853;
148 return p.first + p.second * prime;
152 void SLShortestDistance::BalancePaths(fst::MutableFst<Arc> *
fst) {
155 if (!astates_.empty())
return;
157 if (phi_label_ == f::kNoLabel || fst->Properties(f::kAcyclic,
true))
160 std::unordered_map<std::pair<StateId, StateId>,
StateId,
AMapHash> amap;
162 for (StateId s = 0; s < astart_; ++s) {
163 StateId as = f::kNoStateId;
164 std::unordered_map<StateId, Weight> ns_weight;
166 for (f::ArcIterator<f::MutableFst<Arc>> aiter(*fst, s);
167 !aiter.Done(); aiter.Next()) {
168 const Arc &arc = aiter.Value();
170 if (arc.ilabel == 0 && arc.olabel == phi_label_)
172 auto it = ns_weight.find(arc.nextstate);
173 if (it != ns_weight.end()) {
174 it->second = Plus(it->second, arc.weight);
176 ns_weight[arc.nextstate] = arc.weight;
180 if (as == f::kNoStateId)
continue;
181 for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst, s);
182 !aiter.Done(); aiter.Next()) {
183 Arc arc = aiter.Value();
185 if (
Less(arc.weight, Weight::Zero()) &&
186 (!
Less(Weight::Zero(), ns_weight[arc.nextstate]))) {
191 std::pair<StateId, StateId> p(as, arc.nextstate);
192 auto it = amap.find(p);
193 if (it == amap.end()) {
194 StateId t = fst->AddState();
195 fst->AddArc(t,
Arc(0, 0, Weight::One(), arc.nextstate));
197 astates_.push_back(as);
200 arc.nextstate = it->second;
209 std::vector<fst::SignedLog64Weight> *distance,
bool reverse) {
213 f::AnyArcFilter<Arc> arc_filter;
216 f::VectorFst<Arc> rfst;
217 f::Reverse(fst_, &rfst);
218 std::vector<Weight> rdistance;
220 astates_, astart_,
true);
221 f::ShortestDistanceOptions<Arc, f::QueueBase<StateId>,
222 f::AnyArcFilter<Arc>>
223 opts(&queue, arc_filter);
225 ShortestDistance(rfst, &rdistance, opts);
227 while (distance->size() < rdistance.size() - 1)
228 distance->push_back(rdistance[distance->size() + 1]);
231 astates_, astart_,
false);
232 f::ShortestDistanceOptions<Arc, f::QueueBase<StateId>,
233 f::AnyArcFilter<Arc>>
234 opts(&queue, arc_filter);
236 ShortestDistance(fst_, distance, opts);
239 if (!distance->empty() && !(*distance)[0].Member()) {
240 LOG(ERROR) <<
"SLShortestDistance: shortest distance computation failed";
std::unordered_multimap< StateId, StateId > MMap
StateId Head() const override
bool Empty() const override
bool Less(fst::Log64Weight weight1, fst::Log64Weight weight2)
fst::SignedLog64Arc::StateId StateId
void Update(StateId s) override
size_t operator()(const std::pair< StateId, StateId > &p) const
void Enqueue(StateId s) override
bool ComputeDistance(std::vector< fst::SignedLog64Weight > *distance, bool reverse=false)
SLShortestDistanceQueue(const fst::Fst< Arc > &fst, const std::vector< Weight > &distance, const std::vector< StateId > &astates, size_t astart, bool reverse=false)