GRM-SFST  sfst-1.1.0
OpenGrm SFst Library
shortest-distance.cc
Go to the documentation of this file.
1 
2 // Licensed under the Apache License, Version 2.0 (the 'License');
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 //
6 // http://www.apache.org/licenses/LICENSE-2.0
7 //
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an 'AS IS' BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
14 // Copyright 2018 Google, Inc.
15 // shortest-distance.cc
16 //
17 // Computes the shortest distance with failure transitions.
18 
19 #include <sfst/shortest-distance.h>
20 
21 #include <memory>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include <fst/log.h>
27 #include <fst/fst.h>
28 #include <fst/shortest-distance.h>
29 #include <fst/signed-log-weight.h>
30 #include <fst/vector-fst.h>
31 #include <fst/weight.h>
32 #include <sfst/sfst.h>
33 
34 namespace sfst {
35 
36 // This queue is used with SLShortestDistance to correctly compute the
37 // shortest distance on the signed-log semiring when used with
38 // the output of SLRmPhi.
40  public fst::QueueBase<fst::SignedLog64Arc::StateId> {
41  public:
42  using Arc = fst::SignedLog64Arc;
43  using StateId = Arc::StateId;
44  using Label = Arc::Label;
45  using Weight = Arc::Weight;
46  using MMap = std::unordered_multimap<StateId, StateId>;
47 
48  // For each state q >= astart, q' = astates[q - astart] is its
49  // 'anti-state'; the q must be dequeued right before q'. This
50  // ensures that paths with negatively weighted transitions are
51  // matched up suitably with the corresponding paths of positive
52  // weight in the SLRmPhi construction. The astart value should be
53  // the number of states in the output of SLRmPhi.
54  SLShortestDistanceQueue(const fst::Fst<Arc> &fst,
55  const std::vector<Weight> &distance,
56  const std::vector<StateId> &astates,
57  size_t astart, bool reverse = false)
58  : QueueBase<StateId>(fst::OTHER_QUEUE),
59  astates_(astates),
60  astart_(reverse ? (astart + 1) : astart),
61  reverse_(reverse) {
62  namespace f = fst;
63  if (fst.Properties(f::kAcyclic, true)) {
64  f::AnyArcFilter<Arc> arc_filter;
65  lev1_queue_.reset(new f::TopOrderQueue<StateId>(fst, arc_filter));
66  } else {
67  lev1_queue_.reset(new f::FifoQueue<StateId>());
68  }
69  }
70 
71  StateId Head() const override {
72  if (lev2_queue_.Empty())
73  FillLev2Queue();
74  return lev2_queue_.Head();
75  }
76 
77  void Enqueue(StateId s) override {
78  // The level1 queue is the base queue for states less than astart_.
79  if (s < astart_) {
80  lev1_queue_->Enqueue(s);
81  } else {
82  StateId as = astates_[s - astart_];
83  if (reverse_) ++as; // reverse FST has super-initial state 0
84  map_.insert(std::make_pair(as, s));
85  }
86  }
87 
88  void Dequeue() override {
89  // Top queue is level2 queue.
90  if (lev2_queue_.Empty())
91  FillLev2Queue();
92  lev2_queue_.Dequeue();
93  }
94 
95  void Update(StateId s) override {
96  if (s < astart_) lev1_queue_->Update(s);
97  }
98 
99  bool Empty() const override {
100  return lev1_queue_->Empty() && lev2_queue_.Empty();
101  }
102 
103  void Clear() override {
104  lev1_queue_->Clear();
105  lev2_queue_.Clear();
106  map_.clear();
107  }
108 
109  private:
110  // This dequeues a state 'as' from the level1 queue and enqueues it in
111  // the level2 queue but only after enqueuing any states s > astart_
112  // specified by the map_ that must be dequeued with 'as'.
113  void FillLev2Queue() const {
114  StateId as = lev1_queue_->Head();
115  auto iter = map_.find(as);
116  while (iter != map_.end() && iter->first == as) {
117  StateId s = iter->second;
118  // Enqueues in lev2_queue and dequeues from map_[as] (we can't use
119  // map_ directly as an active queue since an iterator to it could
120  // be invalidated by this->Enqueue())
121  lev2_queue_.Enqueue(s);
122  map_.erase(iter++);
123  }
124  // Finally enqueues 'as' in level2 queue.
125  lev1_queue_->Dequeue();
126  lev2_queue_.Enqueue(as);
127  }
128 
129  const std::vector<StateId> &astates_; // s -> anti-s
130  size_t astart_; // astates offset
131  bool reverse_;
132 
133  mutable MMap map_; // anti-s -> s
134  // For when s < astart_ queue; this queue can be changed to any discipline.
135  std::unique_ptr<fst::QueueBase<StateId>> lev1_queue_;
136  // For when s >= astart_ queue; this must be FIFO.
137  mutable fst::FifoQueue<StateId> lev2_queue_;
138 
139 
141  SLShortestDistanceQueue &operator=(const SLShortestDistanceQueue &) = delete;
142 };
143 
144 struct AMapHash {
145  using StateId = fst::SignedLog64Arc::StateId;
146  size_t operator()(const std::pair<StateId, StateId> &p) const {
147  static constexpr auto prime = 7853;
148  return p.first + p.second * prime;
149  }
150 };
151 
152 void SLShortestDistance::BalancePaths(fst::MutableFst<Arc> *fst) {
153  namespace f = fst;
154 
155  if (!astates_.empty()) return;
156 
157  if (phi_label_ == f::kNoLabel || fst->Properties(f::kAcyclic, true))
158  return;
159 
160  std::unordered_map<std::pair<StateId, StateId>, StateId, AMapHash> amap;
161 
162  for (StateId s = 0; s < astart_; ++s) {
163  StateId as = f::kNoStateId;
164  std::unordered_map<StateId, Weight> ns_weight;
165  // Finds negative multiarcs
166  for (f::ArcIterator<f::MutableFst<Arc>> aiter(*fst, s);
167  !aiter.Done(); aiter.Next()) {
168  const Arc &arc = aiter.Value();
169  // TODO(riley): 'as' could be on a phi PATH
170  if (arc.ilabel == 0 && arc.olabel == phi_label_)
171  as = arc.nextstate; // the 'anti-state' for any added states
172  auto it = ns_weight.find(arc.nextstate);
173  if (it != ns_weight.end()) {
174  it->second = Plus(it->second, arc.weight);
175  } else {
176  ns_weight[arc.nextstate] = arc.weight;
177  }
178  }
179 
180  if (as == f::kNoStateId) continue;
181  for (f::MutableArcIterator<f::MutableFst<Arc>> aiter(fst, s);
182  !aiter.Done(); aiter.Next()) {
183  Arc arc = aiter.Value();
184  // Negative arc and multiarc
185  if (Less(arc.weight, Weight::Zero()) &&
186  (!Less(Weight::Zero(), ns_weight[arc.nextstate]))) {
187  // Creates/reuses a shared state and epsilon arc that
188  // lengthens any negative arc that goes to arc.nextstate and
189  // has 'anti-state' as. This 'balances' the oppositely signed
190  // path lengths which facilitates the queue management.
191  std::pair<StateId, StateId> p(as, arc.nextstate);
192  auto it = amap.find(p);
193  if (it == amap.end()) {
194  StateId t = fst->AddState();
195  fst->AddArc(t, Arc(0, 0, Weight::One(), arc.nextstate));
196  amap[p] = t;
197  astates_.push_back(as);
198  arc.nextstate = t;
199  } else {
200  arc.nextstate = it->second;
201  }
202  aiter.SetValue(arc);
203  }
204  }
205  }
206 }
207 
209  std::vector<fst::SignedLog64Weight> *distance, bool reverse) {
210  namespace f = fst;
211 
212  distance->clear();
213  f::AnyArcFilter<Arc> arc_filter;
214 
215  if (reverse) {
216  f::VectorFst<Arc> rfst;
217  f::Reverse(fst_, &rfst);
218  std::vector<Weight> rdistance;
219  SLShortestDistanceQueue queue(rfst, rdistance,
220  astates_, astart_, true);
221  f::ShortestDistanceOptions<Arc, f::QueueBase<StateId>,
222  f::AnyArcFilter<Arc>>
223  opts(&queue, arc_filter);
224  opts.delta = delta_;
225  ShortestDistance(rfst, &rdistance, opts);
226 
227  while (distance->size() < rdistance.size() - 1)
228  distance->push_back(rdistance[distance->size() + 1]);
229  } else {
230  SLShortestDistanceQueue queue(fst_, *distance,
231  astates_, astart_, false);
232  f::ShortestDistanceOptions<Arc, f::QueueBase<StateId>,
233  f::AnyArcFilter<Arc>>
234  opts(&queue, arc_filter);
235  opts.delta = delta_;
236  ShortestDistance(fst_, distance, opts);
237  }
238 
239  if (!distance->empty() && !(*distance)[0].Member()) {
240  LOG(ERROR) << "SLShortestDistance: shortest distance computation failed";
241  return false;
242  }
243  return true;
244 }
245 
246 } // namespace sfst
std::unordered_multimap< StateId, StateId > MMap
Definition: randgen.h:99
Definition: sfstinfo.cc:40
StateId Head() const override
bool Less(fst::Log64Weight weight1, fst::Log64Weight weight2)
Definition: sfst.h:40
fst::SignedLog64Arc::StateId StateId
void Update(StateId s) override
size_t operator()(const std::pair< StateId, StateId > &p) const
void Enqueue(StateId s) override
bool ComputeDistance(std::vector< fst::SignedLog64Weight > *distance, bool reverse=false)
SLShortestDistanceQueue(const fst::Fst< Arc > &fst, const std::vector< Weight > &distance, const std::vector< StateId > &astates, size_t astart, bool reverse=false)