GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
shortest-distance.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Computes the shortest distance with failure transitions.
15 
16 #ifndef NLP_GRM2_SFST_SHORTEST_DISTANCE_H_
17 #define NLP_GRM2_SFST_SHORTEST_DISTANCE_H_
18 
19 #include <algorithm>
20 #include <cstddef>
21 #include <memory>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include <fst/arcfilter.h>
27 #include <fst/fst.h>
28 #include <fst/matcher.h>
29 #include <fst/properties.h>
30 #include <fst/queue.h>
31 #include <fst/shortest-distance.h>
32 #include <fst/vector-fst.h>
33 #include <fst/weight.h>
34 #include <sfst/rmphi.h>
35 #include <sfst/sfst.h>
36 
37 namespace sfst {
38 
39 namespace internal {
40 
41 // This queue is used with SignedShortestDistance to correctly compute the
42 // shortest distance when used with the output of RmPhi.
43 template <class Arc>
45  public fst::QueueBase<typename Arc::StateId> {
46  public:
47  using StateId = typename Arc::StateId;
48  using Label = typename Arc::Label;
49  using Weight = typename Arc::Weight;
50  using MMap = std::unordered_multimap<StateId, StateId>;
51 
52  // For each state q >= astart, q' = astates[q - astart] is its
53  // 'anti-state'; the q must be dequeued right before q'. This
54  // ensures that paths with negatively weighted transitions are
55  // matched up suitably with the corresponding paths of positive
56  // weight in the RmPhi construction. The astart value should be
57  // the number of states in the output of RmPhi.
58  SignedShortestDistanceQueue(const fst::Fst<Arc> &fst,
59  const std::vector<Weight> &distance,
60  const std::vector<StateId> &astates,
61  size_t astart, bool reverse = false)
62  : fst::QueueBase<StateId>(fst::OTHER_QUEUE),
63  astates_(astates),
64  astart_(reverse ? (astart + 1) : astart),
65  reverse_(reverse) {
66  namespace f = fst;
67  if (fst.Properties(f::kAcyclic, true)) {
68  f::AnyArcFilter<Arc> arc_filter;
69  lev1_queue_ =
70  std::make_unique<f::TopOrderQueue<StateId>>(fst, arc_filter);
71  } else {
72  lev1_queue_ = std::make_unique<f::FifoQueue<StateId>>();
73  }
74  }
75 
76  StateId Head() const override {
77  if (lev2_queue_.Empty())
78  FillLev2Queue();
79  return lev2_queue_.Head();
80  }
81 
82  void Enqueue(StateId s) override {
83  // The level1 queue is the base queue for states less than astart_.
84  if (s < astart_) {
85  lev1_queue_->Enqueue(s);
86  } else {
87  StateId as = astates_[s - astart_];
88  if (reverse_) ++as; // reverse FST has super-initial state 0
89  map_.insert(std::make_pair(as, s));
90  }
91  }
92 
93  void Dequeue() override {
94  // Top queue is level2 queue.
95  if (lev2_queue_.Empty())
96  FillLev2Queue();
97  lev2_queue_.Dequeue();
98  }
99 
100  void Update(StateId s) override {
101  if (s < astart_) lev1_queue_->Update(s);
102  }
103 
104  bool Empty() const override {
105  return lev1_queue_->Empty() && lev2_queue_.Empty();
106  }
107 
108  void Clear() override {
109  lev1_queue_->Clear();
110  lev2_queue_.Clear();
111  map_.clear();
112  }
113 
114  private:
115  // This dequeues a state 'as' from the level1 queue and enqueues it in
116  // the level2 queue but only after enqueuing any states s > astart_
117  // specified by the map_ that must be dequeued with 'as'.
118  void FillLev2Queue() const {
119  StateId as = lev1_queue_->Head();
120  auto iter = map_.find(as);
121  while (iter != map_.end() && iter->first == as) {
122  StateId s = iter->second;
123  // Enqueues in lev2_queue and dequeues from map_[as] (we can't use
124  // map_ directly as an active queue since an iterator to it could
125  // be invalidated by this->Enqueue())
126  lev2_queue_.Enqueue(s);
127  map_.erase(iter++);
128  }
129  // Finally enqueues 'as' in level2 queue.
130  lev1_queue_->Dequeue();
131  lev2_queue_.Enqueue(as);
132  }
133 
134  const std::vector<StateId> &astates_; // s -> anti-s
135  size_t astart_; // astates offset
136  bool reverse_;
137 
138  mutable MMap map_; // anti-s -> s
139  // For when s < astart_ queue; this queue can be changed to any discipline.
140  std::unique_ptr<fst::QueueBase<StateId>> lev1_queue_;
141  // For when s >= astart_ queue; this must be FIFO.
142  mutable fst::FifoQueue<StateId> lev2_queue_;
143 
144 
147  &operator=(const SignedShortestDistanceQueue &) = delete;
148 };
149 
150 // This version of shortest distance computes the shortest distance on
151 // a ring. The 'Arc' weight (e.g. SignedLog64Arc) must have a Minus()
152 // operation (forming a ring).
153 template <class Arc, class WeightEqual = fst::WeightApproxEqual>
155  public:
156  using StateId = typename Arc::StateId;
157  using Label = typename Arc::Label;
158  using Weight = typename Arc::Weight;
159 
160  // For cyclic input and 'negative' weights, any convergence is, in
161  // general, conditional and not absolute, so it will depend on the
162  // specific input.
     //
     // This constructor binds fst_ to 'fst' by reference (so 'fst' has to
     // outlive this object), sets phi_label_ to kNoLabel and does not modify
     // the FST. astart_ becomes the state count of 'fst'.
     // NOTE(review): the listing jumps from line 162 to 164; the line that
     // names this constructor is not present here.
164  const fst::Fst<Arc> &fst,
165  float delta = fst::kShortestDelta)
166  : fst_(fst),
167  phi_label_(fst::kNoLabel),
168  delta_(delta) {
169  namespace f = fst;
170  astart_ = f::CountStates(fst);
171  }
172 
173  // This version is designed to work with the output of RmPhi
174  // (called with MATCHER_REWRITE_NEVER). It may have to add states
175  // and epsilon transitions, but this ensures convergence and a
176  // correct result when the shortest distance is defined and finite.
177  // The phi_label is passed since it is kept on the output label by
178  // RmPhi in this case.
     //
     // fst_ is bound to '*fst' by reference, so the FST has to outlive this
     // object.
     // NOTE(review): the listing jumps from line 178 to 180; the line that
     // names this constructor is not present here.
180  fst::MutableFst<Arc> *fst,
181  typename Arc::Label phi_label = fst::kNoLabel,
182  float delta = fst::kShortestDelta)
183  : fst_(*fst),
184  phi_label_(phi_label),
185  delta_(delta) {
     // Recorded before BalancePaths() runs, so any state it adds gets an id
     // >= astart_.
186  astart_ = fst->NumStates();
187  BalancePaths(fst);
188  }
189 
190  // This computes the shortest distance to the final states when 'reverse =
191  // true', o.w. computes it from the initial state. Convergence within 'delta'
192  // w.r.t. weights ('-log probs') unless 'cmp_exp_weights = true' ('probs').
193  // The former is more generally applicable (e.g. for normalization), the
194  // latter can be faster (e.g. for counting). Returns false on error. An
195  // unvisited state S has distance Zero(), which will be stored in the
196  // 'distance' vector if S is less than the maximum visited state. Additional
197  // states may have been added if constructed with the second constructor above.
198  bool ComputeDistance(
199  std::vector<Weight> *distance, bool reverse = false);
200 
201  private:
202  struct AMapHash {
203  size_t operator()(const std::pair<StateId, StateId> &p) const {
204  static constexpr auto prime = 7853;
205  return p.first + p.second * prime;
206  }
207  };
208 
209  // This transforms a ring-weighted FST, generated
210  // by RmPhi, so that when used with the appropriate
211  // queue, the shortest distance will be correctly
212  // computed. This construction may add states and epsilon
213  // transitions.
214  void BalancePaths(fst::MutableFst<Arc> *fst);
215 
216  const fst::Fst<Arc> &fst_;
217  Label phi_label_;
218  float delta_;
219  // Any states >= this value are newly added.
220  size_t astart_;
221  // This vector is used as an argument to the shortest-path queue. It is used
222  // to ensure oppositely signed 'corresponding' paths are dequeued adjacent. In
223  // particular, q' = astates_[q - astart_] is the 'anti-state' for added state
224  // q; q must be dequeued right before the q'.
225  std::vector<StateId> astates_;
226 
228  SignedShortestDistance &operator=(const SignedShortestDistance &) = delete;
229 };
230 
231 template <class Arc, class WeightEqual>
232  void SignedShortestDistance<Arc,
233  WeightEqual>::BalancePaths(fst::MutableFst<Arc> *fst) {
234  namespace f = fst;
235 
236  if (!astates_.empty()) return;
237 
238  if (phi_label_ == f::kNoLabel || fst->Properties(f::kAcyclic, true))
239  return;
240 
241  std::unordered_map<std::pair<StateId, StateId>, StateId, AMapHash> amap;
242 
243  for (StateId s = 0; s < astart_; ++s) {
244  StateId as = f::kNoStateId;
245  std::unordered_map<StateId, Weight> ns_weight;
246  // Finds negative multiarcs
247  for (f::ArcIterator<f::MutableFst<Arc>> aiter(*fst, s);
248  !aiter.Done(); aiter.Next()) {
249  const Arc &arc = aiter.Value();
250  // TODO(riley): 'as' could be on a phi PATH
251  if (arc.ilabel == 0 && arc.olabel == phi_label_)
252  as = arc.nextstate; // the 'anti-state' for any added states
253  auto it = ns_weight.find(arc.nextstate);
254  if (it != ns_weight.end()) {
255  it->second = Plus(it->second, arc.weight);
256  } else {
257  ns_weight[arc.nextstate] = arc.weight;
258  }
259  }
260 
261  if (as == f::kNoStateId) continue;
262  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst, s);
263  !aiter.Done(); aiter.Next()) {
264  Arc arc = aiter.Value();
265  // Negative arc and multiarc
266  if (IsNegative(arc.weight) &&
267  (IsNegative(ns_weight[arc.nextstate]) ||
268  ns_weight[arc.nextstate] == Weight::Zero())) {
269  // Creates/reuses a shared state and epsilon arc that
270  // lengthens any negative arc that goes to arc.nextstate and
271  // has 'anti-state' as. This 'balances' the oppositely signed
272  // path lengths which facilitates the queue management.
273  std::pair<StateId, StateId> p(as, arc.nextstate);
274  auto it = amap.find(p);
275  if (it == amap.end()) {
276  StateId t = fst->AddState();
277  fst->AddArc(t, Arc(0, 0, Weight::One(), arc.nextstate));
278  amap[p] = t;
279  astates_.push_back(as);
280  arc.nextstate = t;
281  } else {
282  arc.nextstate = it->second;
283  }
284  aiter.SetValue(arc);
285  }
286  }
287  }
288 }
289 
// Fills '*distance' by running the OpenFst shortest-distance state machine
// with a SignedShortestDistanceQueue built from astates_ and astart_.
// Returns false if that computation reports an error, true otherwise.
// NOTE(review): the listing skips line 291, so the line carrying the
// qualified function name is not present here.
290 template <class Arc, class WeightEqual>
292  std::vector<Weight> *distance, bool reverse) {
293  namespace f = fst;
294 
295  using ArcFilter = f::AnyArcFilter<Arc>;
296  using Queue = f::QueueBase<StateId>;
297 
298  distance->clear();
299  ArcFilter arc_filter;
300  std::unique_ptr<f::Fst<Arc>> fst;
301 
     // Works on a private FST: the reversal of fst_ when 'reverse' is set,
     // otherwise whatever fst_.Copy() returns.
302  if (reverse) {
303  auto rfst = std::make_unique<f::VectorFst<Arc>>();
304  f::Reverse(fst_, rfst.get());
305  fst = std::move(rfst);
306  } else {
307  fst.reset(fst_.Copy());
308  }
309 
310  internal::SignedShortestDistanceQueue<Arc> queue(*fst, *distance,
311  astates_, astart_, reverse);
312  f::ShortestDistanceOptions<Arc, Queue, ArcFilter> opts(&queue, arc_filter);
313  opts.delta = delta_;
314 
315 
316  f::internal::ShortestDistanceState<Arc, Queue, ArcFilter, WeightEqual>
317  sd_state(*fst, distance, opts, false);
318  sd_state.ShortestDistance(opts.source);
319  if (sd_state.Error())
320  return false;
321 
     // Drops element 0 and shifts the rest down by one, so that indices
     // match the unreversed FST again.
     // NOTE(review): 'distance->size() - 1' is unsigned; if '*distance' could
     // ever be empty here the subtraction would wrap around -- confirm that
     // the reversed FST always yields at least one entry.
322  if (reverse) {
323  std::vector<Weight> rdistance;
324  while (rdistance.size() < distance->size() - 1)
325  rdistance.push_back((*distance)[rdistance.size() + 1]);
326  *distance = rdistance;
327  }
328 
329  return true;
330 }
331 
332 } // namespace internal
333 
334 
335 // This version of shortest distance computes the shortest distance if failure
336 // transitions may be present. It computes the shortest distance to the final
337 // states when reverse = true, o.w. computes it from the initial state. Returns
338 // false on error. An unvisited state S has distance Zero(), which will be
339 // stored in the 'distance' vector if S is less than the maximum visited state.
340 // Assumes (but does not check) that the input has the canonical topology (see
341 // canonical.h). Also assumes input has no (non-phi) epsilons (or treats such
342 // epsilons w.r.t. the failure semantics as if they were regular,
343 // uniquely-labeled symbols). The 'SignedArc' weight must have a Minus()
344 // operation (forming a ring) and a WeightConvert method from 'Arc'.
345 template <class Arc, class SignedArc = fst::SignedLog64Arc,
346  class WeightEqual = fst::WeightApproxEqual>
347 typename Arc::Weight ShortestDistance(
348  const fst::Fst<Arc> &fst,
349  std::vector<typename Arc::Weight> *distance,
350  typename Arc::Label phi_label = fst::kNoLabel,
351  bool reverse = false, float delta = fst::kShortestDelta) {
352  namespace f = fst;
353  using StateId = typename Arc::StateId;
354  using Weight = typename Arc::Weight;
     // NOTE(review): SignedStateId is not referenced in the lines visible here.
355  using SignedStateId = typename SignedArc::StateId;
356  using SignedWeight = typename SignedArc::Weight;
357 
     // Phi removal into a signed-weight FST; the phi label stays on the
     // output side (MATCHER_REWRITE_NEVER).
358  f::VectorFst<SignedArc> sfst;
359  internal::RmPhi(fst, &sfst, phi_label, fst::MATCHER_REWRITE_NEVER);
     // State count before 'sdist' is constructed; its constructor may add
     // states to 'sfst', and 'ins' is used below to cut those off again.
360  size_t ins = sfst.NumStates();
     // NOTE(review): the listing skips line 361, which presumably carries the
     // type of 'sdist' -- confirm against the original header.
362  sdist(&sfst, phi_label, delta);
363  std::vector<SignedWeight> sdistance;
364  if (!sdist.ComputeDistance(&sdistance, reverse))
365  return Weight::NoWeight();
366  f::WeightConvert<SignedWeight, Weight> from_signed_convert;
367  distance->clear();
     // Converts back to the caller's weight type; a negative signed distance
     // is replaced by Zero() before conversion.
368  for (size_t i = 0; i < sdistance.size(); ++i) {
369  auto dist = sdistance[i];
370  if (IsNegative(dist))
371  dist = SignedWeight::Zero();
372  distance->push_back(from_signed_convert(dist));
373  }
374 
375  // Removes any added states in the construction
376  distance->resize(std::min(ins, distance->size()));
377 
378  // Computes total weight. Note this is non-trivial from the distance vector
379  // w/o 'sfst' since it may have different final states than 'fst' due to
380  // phi-accessibility.
381  f::Adder<Weight> total_weight;
382  if (reverse) {
     // Reverse distances are distances to the final states, so the total is
     // the entry for the start state (if it was reached at all).
383  if (distance->size() > sfst.Start())
384  total_weight.Add((*distance)[sfst.Start()]);
385  } else {
     // Forward case: sums distance(s) * final(s) over all retained states.
386  for (StateId s = 0; s < distance->size(); ++s) {
387  Weight final_weight = from_signed_convert(sfst.Final(s));
388  total_weight.Add(Times((*distance)[s], final_weight));
389  }
390  }
391 
392  return total_weight.Sum();
393 }
394 
395 } // namespace sfst
396 
397 #endif // NLP_GRM2_SFST_SHORTEST_DISTANCE_H_
Definition: perplexity.h:32
bool IsNegative(fst::SignedLog64Weight w)
Definition: sfst.h:79
Definition: sfstinfo.cc:39
SignedShortestDistance(const fst::Fst< Arc > &fst, float delta=fst::kShortestDelta)
std::unordered_multimap< StateId, StateId > MMap
Arc::Weight ShortestDistance(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *distance, typename Arc::Label phi_label=fst::kNoLabel, bool reverse=false, float delta=fst::kShortestDelta)
void RmPhi(const fst::Fst< IArc > &ifst, fst::MutableFst< OArc > *ofst, typename IArc::Label phi_label=fst::kNoLabel, fst::MatcherRewriteMode rewrite_mode=fst::MATCHER_REWRITE_AUTO, const WC &weight_convert=WC())
Definition: rmphi.h:234
SignedShortestDistance(fst::MutableFst< Arc > *fst, typename Arc::Label phi_label=fst::kNoLabel, float delta=fst::kShortestDelta)
bool ComputeDistance(std::vector< Weight > *distance, bool reverse=false)
SignedShortestDistanceQueue(const fst::Fst< Arc > &fst, const std::vector< Weight > &distance, const std::vector< StateId > &astates, size_t astart, bool reverse=false)
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
Definition: perplexity.h:41