GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
randgen.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // The FST must be fully normalized.
15 
16 #ifndef NLP_GRM2_SFST_RANDGEN_H_
17 #define NLP_GRM2_SFST_RANDGEN_H_
18 
19 #include <sys/types.h>
20 #include <time.h>
21 
22 #include <climits>
23 #include <cmath>
24 #include <cstddef>
25 #include <cstdlib>
26 #include <limits>
27 #include <map>
28 #include <memory>
29 #include <random>
30 #include <set>
31 #include <utility>
32 #include <vector>
33 
34 #include <fst/accumulator.h>
35 #include <fst/arc-map.h>
36 #include <fst/arc.h>
37 #include <fst/arcsort.h>
38 #include <fst/determinize.h>
39 #include <fst/fst.h>
40 #include <fst/minimize.h>
41 #include <fst/mutable-fst.h>
42 #include <fst/randgen.h>
43 #include <fst/relabel.h>
44 #include <fst/rmepsilon.h>
45 #include <fst/weight.h>
46 
47 namespace sfst {
48 
49 // Same as FastLogProbArcSelector but handles failure transitions
50 // labeled with 'phi'. Assumes (but does not check) the input is a
51 // normalized (see normalize.h) stochastic FST. Note that
52 // (non-failure) epsilons are treated as regular symbols where each
53 // instance behaves as if it is uniquely labeled (i.e., they are not
54 // constrained by failure transitions).
55 template <class A>
57  public:
58  typedef typename A::StateId StateId;
59  typedef typename A::Label Label;
60  typedef typename A::Weight Weight;
61 
62  explicit SFstArcSelector(int seed = time(nullptr),
63  Label phi_label = fst::kNoLabel)
64  : seed_(seed), phi_label_(phi_label) { srand(seed); }
65 
66  // Samples one transition.
67  size_t operator()(const fst::Fst<A> &fst, StateId s,
68  double pre_failure_prob, double failure_prob,
69  double numer_prob, ssize_t failure_pos,
70  fst::CacheLogAccumulator<A> *accumulator) const {
71  double r = rand()/(RAND_MAX + 1.0); // NOLINT
72  accumulator->SetState(s);
73  fst::ArcIterator<fst::Fst<A>> aiter(fst, s);
74 
75  if (failure_pos == -1 || r <= pre_failure_prob) {
76  // Selects before the failure transition.
77  return accumulator->LowerBound(-log(r), &aiter);
78  }
79 
80  if (r <= pre_failure_prob + numer_prob) {
81  // Selects the failure transition.
82  return failure_pos;
83  }
84 
85  // Selects after the failure transition.
86  aiter.Seek(failure_pos + 1);
87  // Adjusts for the incorrect mass in the cumulative distribution
88  // at the failure transition
89  r += failure_prob - numer_prob;
90  if (r < 0.0) return failure_pos;
91  return accumulator->LowerBound(-log(r), &aiter);
92  }
93 
94  int Seed() const { return seed_; }
95 
96  Label PhiLabel() const { return phi_label_; }
97 
98  private:
99  int seed_;
100  Label phi_label_;
101  fst::WeightConvert<Weight, fst::LogWeight> to_log_weight_;
102 };
103 
104 } // namespace sfst
105 
106 namespace fst {
107 
108 // Specialization for SFstArcSelector.
109 template <class A>
110 class ArcSampler<A, sfst::SFstArcSelector<A> > {
111  public:
113  typedef typename A::StateId StateId;
114  typedef typename A::Weight Weight;
115  typedef typename A::Label Label;
116  typedef CacheLogAccumulator<A> C;
117 
118  ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX)
119  : fst_(fst),
120  arc_selector_(arc_selector),
121  max_length_(max_length),
122  phi_label_(arc_selector.PhiLabel()),
123  accumulator_(std::make_unique<C>()),
124  matcher_(fst_, MATCH_INPUT) {
125  accumulator_->Init(fst);
126  rng_.seed(arc_selector.Seed());
127  }
128 
129  ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = nullptr)
130  : fst_(fst ? *fst : sampler.fst_),
131  arc_selector_(sampler.arc_selector_),
132  max_length_(sampler.max_length_),
133  phi_label_(sampler.phi_label_),
134  matcher_(fst_, MATCH_INPUT) {
135  if (fst) {
136  accumulator_ = std::make_unique<C>();
137  accumulator_->Init(*fst);
138  } else { // shallow copy
139  accumulator_ = std::make_unique<C>(*sampler.accumulator_);
140  }
141  }
142 
143  bool Sample(const RandState<A> &rstate);
144 
145  bool Done() const { return sample_iter_ == sample_map_.end(); }
146  void Next() { ++sample_iter_; }
147  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
148  void Reset() { sample_iter_ = sample_map_.begin(); }
149 
150  bool Error() const { return false; }
151 
152  private:
153  typedef std::mt19937 RNG;
154 
155  double TotalProb(StateId s) {
156  // Gets cumulative weight at the state.
157  ArcIterator< Fst<A> > aiter(fst_, s);
158  accumulator_->SetState(s);
159  Weight total_weight = accumulator_->Sum(fst_.Final(s), &aiter, 0,
160  fst_.NumArcs(s));
161  return exp(-to_log_weight_(total_weight).Value());
162  }
163 
164  double PreFailureProb(StateId s, ssize_t failure_pos) {
165  // Gets cumulative weight up to (but not including) the
166  // failure transition at the state.
167  if (failure_pos < 1) return 0.0;
168  ArcIterator< Fst<A> > aiter(fst_, s);
169  accumulator_->SetState(s);
170  Weight cumul_weight = accumulator_->Sum(Weight::Zero(), &aiter, 0,
171  failure_pos);
172  return exp(-to_log_weight_(cumul_weight).Value());
173  }
174 
175  // Determines if a sample candidate at a given arc position
176  // is disallowed because its label was present in a parent state.
177  bool ForbiddenPosition(size_t p, const RandState<A> &rstate);
178 
179  // Determines if a sample candidate of a given label is disallowed
180  // because its label was present in a parent state.
181  bool ForbiddenLabel(Label l, const RandState<A> &rstate);
182 
183  void MultinomialSample(const RandState<A> &rstate, Weight fail_weight);
184 
185  const Fst<A> &fst_;
186  const S &arc_selector_;
187  int max_length_;
188  Label phi_label_;
189 
190  // Stores (N, K) as described for Value().
191  std::map<size_t, size_t> sample_map_;
192  std::map<size_t, size_t>::const_iterator sample_iter_;
193  std::unique_ptr<C> accumulator_;
194 
195  RNG rng_;
196  std::vector<double> pr_; // multinomial parameters
197  std::vector<unsigned int> pos_; // sample positions
198  std::vector<unsigned int> n_; // sample counts
199  WeightConvert<Weight, Log64Weight> to_log_weight_;
200  WeightConvert<Log64Weight, Weight> to_weight_;
201  std::set<size_t> forbidden_positions_; // arc pos forbidden fo failure arcs
202  std::set<Label> forbidden_labels_; // labels forbidden for failure arcs
203  ExplicitMatcher<SortedMatcher<Fst<A>>> matcher_;
204 };
205 
206 template <class A>
207 bool ArcSampler<A, sfst::SFstArcSelector<A> >::Sample(
208  const RandState<A> &rstate) {
209  sample_map_.clear();
210  forbidden_positions_.clear();
211  forbidden_labels_.clear();
212  bool is_final = fst_.Final(rstate.state_id) != Weight::Zero();
213  size_t narcs_and_final = fst_.NumArcs(rstate.state_id) + is_final;
214 
215  if (narcs_and_final == 0 || rstate.length == max_length_) {
216  Reset();
217  return false;
218  }
219 
220  double failure_prob = 0.0; // failure probability
221  double total_prob = 1.0; // total sum of state transitions
222  double numer_prob = 0.0; // numerator in failure probablity
223  ssize_t failure_pos = -1; // arc position of failure transition
224  matcher_.SetState(rstate.state_id);
225 
226  if (phi_label_ != kNoLabel && matcher_.Find(phi_label_)) {
227  const A &arc = matcher_.Value();
228  Log64Weight failure_weight = to_log_weight_(arc.weight);
229  failure_prob = exp(-failure_weight.Value());
230  failure_pos = matcher_.GetMatcher()->Position();
231  total_prob = TotalProb(rstate.state_id);
232  // total = failure - numer + 1
233  numer_prob = 1.0 + failure_prob - total_prob;
234  }
235 
236  if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
237  Weight numer_weight = to_weight_(-log(numer_prob));
238  MultinomialSample(rstate, numer_weight);
239  Reset();
240  return true;
241  }
242 
243  double pre_failure_prob = PreFailureProb(rstate.state_id, failure_pos);
244 
245  for (size_t i = 0; i < rstate.nsamples; ++i) {
246  size_t pos = 0;
247  do {
248  if (forbidden_positions_.size() == narcs_and_final)
249  return false; // non-coaccessible state
250  pos = arc_selector_(fst_, rstate.state_id, pre_failure_prob,
251  failure_prob, numer_prob, failure_pos,
252  accumulator_.get());
253  } while (ForbiddenPosition(pos, rstate));
254  ++sample_map_[pos];
255  }
256  Reset();
257  return true;
258 }
259 
260 template <class A>
261 bool ArcSampler<A, sfst::SFstArcSelector<A> >::ForbiddenPosition(
262  size_t pos, const RandState<A> &rstate) {
263  if (forbidden_positions_.count(pos) > 0)
264  return true;
265 
266  Label label = kNoLabel; // super-final label
267  if (pos < fst_.NumArcs(rstate.state_id)) {
268  ArcIterator<Fst<A>> aiter(fst_, rstate.state_id);
269  aiter.Seek(pos);
270  label = aiter.Value().ilabel;
271  }
272  bool forbidden_label = ForbiddenLabel(label, rstate);
273  if (forbidden_label)
274  forbidden_positions_.insert(pos);
275  return forbidden_label;
276 }
277 
278 template <class A>
279 bool ArcSampler<A, sfst::SFstArcSelector<A> >::ForbiddenLabel(
280  Label l, const RandState<A> &rstate) {
281  if (phi_label_ == kNoLabel || l == phi_label_ || l == 0)
282  return false;
283 
284  if (fst_.NumArcs(rstate.state_id) > rstate.nsamples) {
285  for (const RandState<A> *rs = &rstate;
286  rs->parent != nullptr;
287  rs = rs->parent) {
288  StateId parent_id = rs->parent->state_id;
289  ArcIterator < Fst<A> > aiter(fst_, parent_id);
290  aiter.Seek(rs->select);
291  if (aiter.Value().ilabel != phi_label_) // not failure transition
292  return false;
293 
294  if (l == kNoLabel) { // super-final label
295  if (fst_.Final(parent_id) != Weight::Zero())
296  return true;
297  } else {
298  matcher_.SetState(parent_id);
299  if (matcher_.Find(l))
300  return true;
301  }
302  }
303  return false;
304  } else {
305  if (forbidden_labels_.empty()) {
306  for (const RandState<A> *rs = &rstate;
307  rs->parent != nullptr;
308  rs = rs->parent) {
309  StateId parent_id = rs->parent->state_id;
310  ArcIterator < Fst<A> > aiter(fst_, parent_id);
311  aiter.Seek(rs->select);
312  if (aiter.Value().ilabel != phi_label_) // not failure transition
313  break;
314 
315  for (aiter.Reset(); !aiter.Done(); aiter.Next()) {
316  Label l = aiter.Value().ilabel;
317  if (l != phi_label_)
318  forbidden_labels_.insert(l);
319  }
320 
321  if (fst_.Final(parent_id) != Weight::Zero())
322  forbidden_labels_.insert(kNoLabel);
323  }
324  }
325  return forbidden_labels_.count(l) > 0;
326  }
327 }
328 
329 template <class A>
330 void ArcSampler<A, sfst::SFstArcSelector<A> >::MultinomialSample(
331  const RandState<A> &rstate, Weight fail_weight) {
332  pr_.clear();
333  pos_.clear();
334  n_.clear();
335  size_t pos = 0;
336  for (ArcIterator< Fst<A> > aiter(fst_, rstate.state_id);
337  !aiter.Done();
338  aiter.Next(), ++pos) {
339  const A &arc = aiter.Value();
340  if (!ForbiddenLabel(arc.ilabel, rstate)) {
341  pos_.push_back(pos);
342  Weight weight = arc.ilabel == phi_label_ ? fail_weight : arc.weight;
343  pr_.push_back(exp(-to_log_weight_(weight).Value()));
344  }
345  }
346  if (fst_.Final(rstate.state_id) != Weight::Zero()
347  && !ForbiddenLabel(kNoLabel, rstate)) {
348  pos_.push_back(pos);
349  pr_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
350  }
351 
352  if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
353  n_.resize(pr_.size());
354  OneMultinomialSample(pr_, rstate.nsamples, &n_, &rng_);
355  for (size_t i = 0; i < n_.size(); ++i)
356  if (n_[i] != 0) sample_map_[pos_[i]] = n_[i];
357  } else {
358  for (size_t i = 0; i < pr_.size(); ++i)
359  sample_map_[pos_[i]] = ceil(pr_[i] * rstate.nsamples);
360  }
361 }
362 
363 // Epsilons/phi removes and minimizes the result of weighted SFST randgen.
364 template <class A>
365  void RandMinimize(MutableFst<A> *fst, typename A::Label phi_label) {
366  using Label = typename A::Label;
367 
368  WeightConvertMapper<A, Log64Arc> to_log_mapper;
369  WeightConvertMapper<Log64Arc, A> from_log_mapper;
370  VectorFst<Log64Arc> lfst, dfst;
371  ArcMap(*fst, &lfst, to_log_mapper);
372  if (phi_label != kNoLabel && phi_label != 0) {
373  std::vector<std::pair<Label, Label>> relab = {{phi_label, 0}};
374  Relabel(&lfst, relab, relab);
375  }
376  RmEpsilon(&lfst);
377  Determinize(lfst, &dfst);
378  Minimize(&dfst);
379  ArcMap(dfst, fst, from_log_mapper);
380  ArcSort(fst, ILabelCompare<A>());
381 }
382 
383 } // namespace fst
384 
385 #endif // NLP_GRM2_SFST_RANDGEN_H_
A::StateId StateId
Definition: randgen.h:58
A::Weight Weight
Definition: randgen.h:60
Definition: perplexity.h:32
std::pair< size_t, size_t > Value() const
Definition: randgen.h:147
Definition: sfstinfo.cc:39
void RandMinimize(MutableFst< A > *fst, typename A::Label phi_label)
Definition: randgen.h:365
int Seed() const
Definition: randgen.h:94
size_t operator()(const fst::Fst< A > &fst, StateId s, double pre_failure_prob, double failure_prob, double numer_prob, ssize_t failure_pos, fst::CacheLogAccumulator< A > *accumulator) const
Definition: randgen.h:67
Label PhiLabel() const
Definition: randgen.h:96
ArcSampler(const Fst< A > &fst, const S &arc_selector, int max_length=INT_MAX)
Definition: randgen.h:118
SFstArcSelector(int seed=time(nullptr), Label phi_label=fst::kNoLabel)
Definition: randgen.h:62
ArcSampler(const ArcSampler< A, S > &sampler, const Fst< A > *fst=nullptr)
Definition: randgen.h:129