16 #ifndef NLP_GRM2_SFST_SHORTEST_DISTANCE_H_ 17 #define NLP_GRM2_SFST_SHORTEST_DISTANCE_H_ 22 #include <unordered_map> 26 #include <fst/arcfilter.h> 28 #include <fst/matcher.h> 29 #include <fst/properties.h> 30 #include <fst/queue.h> 32 #include <fst/vector-fst.h> 33 #include <fst/weight.h> 45 public fst::QueueBase<typename Arc::StateId> {
48 using Label =
typename Arc::Label;
50 using MMap = std::unordered_multimap<StateId, StateId>;
59 const std::vector<Weight> &distance,
60 const std::vector<StateId> &astates,
61 size_t astart,
bool reverse =
false)
62 : fst::QueueBase<
StateId>(fst::OTHER_QUEUE),
64 astart_(reverse ? (astart + 1) : astart),
67 if (fst.Properties(f::kAcyclic,
true)) {
68 f::AnyArcFilter<Arc> arc_filter;
70 std::make_unique<f::TopOrderQueue<StateId>>(fst, arc_filter);
72 lev1_queue_ = std::make_unique<f::FifoQueue<StateId>>();
77 if (lev2_queue_.Empty())
79 return lev2_queue_.Head();
85 lev1_queue_->Enqueue(s);
87 StateId as = astates_[s - astart_];
89 map_.insert(std::make_pair(as, s));
95 if (lev2_queue_.Empty())
97 lev2_queue_.Dequeue();
101 if (s < astart_) lev1_queue_->Update(s);
105 return lev1_queue_->Empty() && lev2_queue_.Empty();
109 lev1_queue_->Clear();
118 void FillLev2Queue()
const {
119 StateId as = lev1_queue_->Head();
120 auto iter = map_.find(as);
121 while (iter != map_.end() && iter->first == as) {
126 lev2_queue_.Enqueue(s);
130 lev1_queue_->Dequeue();
131 lev2_queue_.Enqueue(as);
134 const std::vector<StateId> &astates_;
140 std::unique_ptr<fst::QueueBase<StateId>> lev1_queue_;
142 mutable fst::FifoQueue<StateId> lev2_queue_;
153 template <
class Arc,
class WeightEqual = fst::WeightApproxEqual>
164 const fst::Fst<Arc> &
fst,
165 float delta = fst::kShortestDelta)
167 phi_label_(fst::kNoLabel),
170 astart_ = f::CountStates(fst);
180 fst::MutableFst<Arc> *
fst,
181 typename Arc::Label phi_label = fst::kNoLabel,
182 float delta = fst::kShortestDelta)
184 phi_label_(phi_label),
186 astart_ = fst->NumStates();
198 bool ComputeDistance(
199 std::vector<Weight> *distance,
bool reverse =
false);
203 size_t operator()(
const std::pair<StateId, StateId> &p)
const {
204 static constexpr
auto prime = 7853;
205 return p.first + p.second * prime;
214 void BalancePaths(fst::MutableFst<Arc> *
fst);
216 const fst::Fst<Arc> &fst_;
225 std::vector<StateId> astates_;
231 template <
class Arc,
class WeightEqual>
233 WeightEqual>::BalancePaths(fst::MutableFst<Arc> *
fst) {
236 if (!astates_.empty())
return;
238 if (phi_label_ == f::kNoLabel || fst->Properties(f::kAcyclic,
true))
241 std::unordered_map<std::pair<StateId, StateId>,
StateId, AMapHash> amap;
243 for (StateId s = 0; s < astart_; ++s) {
244 StateId as = f::kNoStateId;
245 std::unordered_map<StateId, Weight> ns_weight;
247 for (f::ArcIterator<f::MutableFst<Arc>> aiter(*fst, s);
248 !aiter.Done(); aiter.Next()) {
249 const Arc &arc = aiter.Value();
251 if (arc.ilabel == 0 && arc.olabel == phi_label_)
253 auto it = ns_weight.find(arc.nextstate);
254 if (it != ns_weight.end()) {
255 it->second = Plus(it->second, arc.weight);
257 ns_weight[arc.nextstate] = arc.weight;
261 if (as == f::kNoStateId)
continue;
262 for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst, s);
263 !aiter.Done(); aiter.Next()) {
264 Arc arc = aiter.Value();
268 ns_weight[arc.nextstate] == Weight::Zero())) {
273 std::pair<StateId, StateId> p(as, arc.nextstate);
274 auto it = amap.find(p);
275 if (it == amap.end()) {
276 StateId t = fst->AddState();
277 fst->AddArc(t, Arc(0, 0, Weight::One(), arc.nextstate));
279 astates_.push_back(as);
282 arc.nextstate = it->second;
290 template <
class Arc,
class WeightEqual>
292 std::vector<Weight> *distance,
bool reverse) {
295 using ArcFilter = f::AnyArcFilter<Arc>;
296 using Queue = f::QueueBase<StateId>;
299 ArcFilter arc_filter;
300 std::unique_ptr<f::Fst<Arc>> fst;
303 auto rfst = std::make_unique<f::VectorFst<Arc>>();
304 f::Reverse(fst_, rfst.get());
305 fst = std::move(rfst);
307 fst.reset(fst_.Copy());
311 astates_, astart_, reverse);
312 f::ShortestDistanceOptions<Arc, Queue, ArcFilter> opts(&queue, arc_filter);
316 f::internal::ShortestDistanceState<Arc, Queue, ArcFilter, WeightEqual>
317 sd_state(*fst, distance, opts,
false);
318 sd_state.ShortestDistance(opts.source);
319 if (sd_state.Error())
323 std::vector<Weight> rdistance;
324 while (rdistance.size() < distance->size() - 1)
325 rdistance.push_back((*distance)[rdistance.size() + 1]);
326 *distance = rdistance;
345 template <
class Arc,
class SignedArc = fst::SignedLog64Arc,
346 class WeightEqual = fst::WeightApproxEqual>
348 const fst::Fst<Arc> &
fst,
349 std::vector<typename Arc::Weight> *distance,
350 typename Arc::Label phi_label = fst::kNoLabel,
351 bool reverse =
false,
float delta = fst::kShortestDelta) {
353 using StateId =
typename Arc::StateId;
354 using Weight =
typename Arc::Weight;
355 using SignedStateId =
typename SignedArc::StateId;
356 using SignedWeight =
typename SignedArc::Weight;
358 f::VectorFst<SignedArc>
sfst;
360 size_t ins = sfst.NumStates();
362 sdist(&sfst, phi_label, delta);
363 std::vector<SignedWeight> sdistance;
365 return Weight::NoWeight();
366 f::WeightConvert<SignedWeight, Weight> from_signed_convert;
368 for (
size_t i = 0; i < sdistance.size(); ++i) {
369 auto dist = sdistance[i];
371 dist = SignedWeight::Zero();
372 distance->push_back(from_signed_convert(dist));
376 distance->resize(std::min(ins, distance->size()));
381 f::Adder<Weight> total_weight;
383 if (distance->size() > sfst.Start())
384 total_weight.Add((*distance)[sfst.Start()]);
386 for (
StateId s = 0; s < distance->size(); ++s) {
387 Weight final_weight = from_signed_convert(sfst.Final(s));
388 total_weight.Add(
Times((*distance)[s], final_weight));
392 return total_weight.Sum();
397 #endif // NLP_GRM2_SFST_SHORTEST_DISTANCE_H_
bool Empty() const override
bool IsNegative(fst::SignedLog64Weight w)
void Enqueue(StateId s) override
typename Arc::StateId StateId
SignedShortestDistance(const fst::Fst< Arc > &fst, float delta=fst::kShortestDelta)
typename Arc::StateId StateId
std::unordered_multimap< StateId, StateId > MMap
Arc::Weight ShortestDistance(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, typename Arc::Label phi_label=fst::kNoLabel, bool reverse=false, float delta=fst::kShortestDelta)
void RmPhi(const fst::Fst< IArc > &ifst, fst::MutableFst< OArc > *ofst, typename IArc::Label phi_label=fst::kNoLabel, fst::MatcherRewriteMode rewrite_mode=fst::MATCHER_REWRITE_AUTO, const WC &weight_convert=WC())
typename Arc::Label Label
SignedShortestDistance(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel, float delta=fst::kShortestDelta)
void Update(StateId s) override
bool ComputeDistance(std::vector< Weight > *distance, bool reverse=false)
SignedShortestDistanceQueue(const fst::Fst< Arc > &fst, const std::vector< Weight > &distance, const std::vector< StateId > &astates, size_t astart, bool reverse=false)
typename Arc::Label Label
typename Arc::Weight Weight
StateId Head() const override
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
typename Arc::Weight Weight