NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-randgen.h
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Classes to generate random sentences from an LM or more generally
17 // paths through any FST where epsilons are treated as failure transitions.
18 
19 #ifndef NGRAM_NGRAM_RANDGEN_H_
20 #define NGRAM_NGRAM_RANDGEN_H_
21 
22 #include <sys/types.h>
23 #include <unistd.h>
24 
25 #include <ctime>
26 #include <map>
27 #include <utility>
28 
29 #include <fst/accumulator.h>
30 
31 // Faster multinomial sampling possible if Gnu Scientific Library available.
32 #ifdef HAVE_GSL
33 #include <gsl/gsl_randist.h>
34 #include <gsl/gsl_rng.h>
35 #endif // HAVE_GSL
36 
37 #include <fst/fst.h>
38 #include <fst/randgen.h>
39 #include <ngram/util.h>
40 
41 namespace ngram {
42 
43 // Same as FastLogProbArcSelector but treats *all* epsilons as
44 // failure transitions that have a backoff weight. The LM must
45 // be fully normalized.
46 template <class A>
48  public:
49  typedef typename A::StateId StateId;
50  typedef typename A::Weight Weight;
51 
52  explicit NGramArcSelector(int seed = time(nullptr) + getpid()) : seed_(seed) {
53  srand(seed);
54  }
55 
56  // Samples one transition.
57  size_t operator()(const fst::Fst<A> &fst, StateId s, double total_prob,
58  fst::CacheLogAccumulator<A> *accumulator) const {
59  double r = rand() / (RAND_MAX + 1.0);
60  // In effect, subtract out excess mass from the cumulative distribution.
61  // Requires the backoff epsilon be the initial transition.
62  double z = r + total_prob - 1.0;
63  if (z <= 0.0) return 0;
64  fst::ArcIterator<fst::Fst<A> > aiter(fst, s);
65  return accumulator->LowerBound(-log(z), &aiter);
66  }
67 
68  int Seed() const { return seed_; }
69 
70  private:
71  int seed_;
72  fst::WeightConvert<Weight, fst::LogWeight> to_log_weight_;
73 };
74 
75 } // namespace ngram
76 
77 namespace fst {
78 
79 // Specialization for NGramArcSelector.
80 template <class A>
81 class ArcSampler<A, ngram::NGramArcSelector<A> > {
82  public:
84  typedef typename A::StateId StateId;
85  typedef typename A::Weight Weight;
86  typedef typename A::Label Label;
87  typedef CacheLogAccumulator<A> C;
88 
89  ArcSampler(const fst::Fst<A> &fst, const S &arc_selector,
90  int max_length = INT_MAX)
91  : fst_(fst),
92  arc_selector_(arc_selector),
93  max_length_(max_length),
94  matcher_(fst_, MATCH_INPUT) {
95  // Ensure the input FST has any epsilons as the initial transitions.
96  if (!fst_.Properties(kILabelSorted, true))
97  NGRAMERROR() << "ArcSampler: is not input-label sorted";
98  accumulator_.reset(new C());
99  accumulator_->Init(fst);
100 #ifdef HAVE_GSL
101  rng_ = gsl_rng_alloc(gsl_rng_taus);
102  gsl_rng_set(rng_, arc_selector.Seed());
103 #endif // HAVE_GSL
104  }
105 
106  ArcSampler(const ArcSampler<A, S> &sampler, const fst::Fst<A> *fst = 0)
107  : fst_(fst ? *fst : sampler.fst_),
108  arc_selector_(sampler.arc_selector_),
109  max_length_(sampler.max_length_),
110  matcher_(fst_, MATCH_INPUT) {
111  if (fst) {
112  accumulator_.reset(new C());
113  accumulator_->Init(*fst);
114  } else { // shallow copy
115  accumulator_.reset(new C(*sampler.accumulator_));
116  }
117  }
118 
120 #ifdef HAVE_GSL
121  gsl_rng_free(rng_);
122 #endif // HAVE_GSL
123  }
124 
125  bool Sample(const RandState<A> &rstate) {
126  sample_map_.clear();
127  forbidden_labels_.clear();
128 
129  if ((fst_.NumArcs(rstate.state_id) == 0 &&
130  fst_.Final(rstate.state_id) == Weight::Zero()) ||
131  rstate.length == max_length_) {
132  Reset();
133  return false;
134  }
135 
136  double total_prob = TotalProb(rstate.state_id);
137 
138 #ifdef HAVE_GSL
139  if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
140  Weight numer_weight, denom_weight;
141  BackoffWeight(rstate.state_id, total_prob, &numer_weight, &denom_weight);
142  MultinomialSample(rstate, numer_weight);
143  Reset();
144  return true;
145  }
146 #endif // HAVE_GSL
147 
148  fst::ArcIterator<fst::Fst<A> > aiter(fst_, rstate.state_id);
149 
150  for (size_t i = 0; i < rstate.nsamples; ++i) {
151  size_t pos = 0;
152  Label label = kNoLabel;
153  do {
154  pos = arc_selector_(fst_, rstate.state_id, total_prob,
155  accumulator_.get());
156  if (pos < fst_.NumArcs(rstate.state_id)) {
157  aiter.Seek(pos);
158  label = aiter.Value().ilabel;
159  } else {
160  label = kNoLabel;
161  }
162  } while (ForbiddenLabel(label, rstate));
163  ++sample_map_[pos];
164  }
165  Reset();
166  return true;
167  }
168 
169  bool Done() const { return sample_iter_ == sample_map_.end(); }
170  void Next() { ++sample_iter_; }
171  std::pair<size_t, size_t> Value() const { return *sample_iter_; }
172  void Reset() { sample_iter_ = sample_map_.begin(); }
173  bool Error() const { return false; }
174 
175  private:
176  double TotalProb(StateId s) {
177  // Get cumulative weight at the state.
178  fst::ArcIterator<fst::Fst<A> > aiter(fst_, s);
179  accumulator_->SetState(s);
180  Weight total_weight =
181  accumulator_->Sum(fst_.Final(s), &aiter, 0, fst_.NumArcs(s));
182  return exp(-to_log_weight_(total_weight).Value());
183  }
184 
185  void BackoffWeight(StateId s, double total_prob, Weight *numer_weight,
186  Weight *denom_weight);
187 
188 #ifdef HAVE_GSL
189  void MultinomialSample(const RandState<A> &rstate, Weight fail_weight);
190 #endif // HAVE_GSL
191 
192  bool ForbiddenLabel(Label l, const RandState<A> &rstate);
193 
194  const fst::Fst<A> &fst_;
195  const S &arc_selector_;
196  int max_length_;
197 
198  // Stores (N, K) as described for Value().
199  std::map<size_t, size_t> sample_map_;
200  std::map<size_t, size_t>::const_iterator sample_iter_;
201  std::unique_ptr<C> accumulator_;
202 
203 #ifdef HAVE_GSL
204  gsl_rng *rng_; // GNU Sci Lib random number generator
205  std::vector<double> pr_; // multinomial parameters
206  std::vector<unsigned int> pos_; // sample positions
207  std::vector<unsigned int> n_; // sample counts
208 #endif // HAVE_GSL
209 
210  WeightConvert<fst::Log64Weight, Weight> to_weight_;
211  WeightConvert<Weight, fst::Log64Weight> to_log_weight_;
212  std::set<Label>
213  forbidden_labels_; // labels forbidden for failure transitions
214  Matcher<fst::Fst<A> > matcher_;
215 };
216 
217 // Finds and decomposes the backoff probability into its numerator and
218 // denominator.
219 template <class A>
220 void ArcSampler<A, ngram::NGramArcSelector<A> >::BackoffWeight(
221  StateId s, double total, Weight *numer_weight, Weight *denom_weight) {
222  // Get backoff prob.
223  double backoff = 0.0;
224  matcher_.SetState(s);
225  matcher_.Find(0);
226  for (; !matcher_.Done(); matcher_.Next()) {
227  const A &arc = matcher_.Value();
228  if (arc.ilabel != kNoLabel) { // not an implicit epsilon loop
229  backoff = exp(-to_log_weight_(arc.weight).Value());
230  break;
231  }
232  }
233 
234  if (backoff == 0.0) { // no backoff transition
235  *numer_weight = Weight::Zero();
236  *denom_weight = Weight::Zero();
237  return;
238  }
239 
240  // total = 1 - numer + backoff
241  double numer = 1.0 + backoff - total;
242  *numer_weight = to_weight_(-log(numer));
243 
244  // backoff = numer/denom
245  double denom = numer / backoff;
246  *denom_weight = to_weight_(-log(denom));
247 }
248 
249 #ifdef HAVE_GSL
250 template <class A>
251 void ArcSampler<A, ngram::NGramArcSelector<A> >::MultinomialSample(
252  const RandState<A> &rstate, Weight fail_weight) {
253  pr_.clear();
254  pos_.clear();
255  n_.clear();
256  size_t pos = 0;
257  for (fst::ArcIterator<fst::Fst<A> > aiter(fst_, rstate.state_id);
258  !aiter.Done(); aiter.Next(), ++pos) {
259  const A &arc = aiter.Value();
260  if (!ForbiddenLabel(arc.ilabel, rstate)) {
261  pos_.push_back(pos);
262  Weight weight = arc.ilabel == 0 ? fail_weight : arc.weight;
263  pr_.push_back(exp(-to_log_weight_(weight).Value()));
264  }
265  }
266  if (fst_.Final(rstate.state_id) != Weight::Zero() &&
267  !ForbiddenLabel(kNoLabel, rstate)) {
268  pos_.push_back(pos);
269  pr_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
270  }
271 
272  if (rstate.nsamples < UINT_MAX) {
273  n_.resize(pr_.size());
274  gsl_ran_multinomial(rng_, pr_.size(), rstate.nsamples, &(pr_[0]), &(n_[0]));
275  for (size_t i = 0; i < n_.size(); ++i)
276  if (n_[i] != 0) sample_map_[pos_[i]] = n_[i];
277  } else {
278  for (size_t i = 0; i < pr_.size(); ++i)
279  sample_map_[pos_[i]] = ceil(pr_[i] * rstate.nsamples);
280  }
281 }
282 #endif // HAVE_GSL
283 
284 template <class A>
285 bool ArcSampler<A, ngram::NGramArcSelector<A> >::ForbiddenLabel(
286  Label l, const RandState<A> &rstate) {
287  if (l == 0) return false;
288 
289  if (fst_.NumArcs(rstate.state_id) > rstate.nsamples) {
290  for (const RandState<A> *rs = &rstate; rs->parent != nullptr;
291  rs = rs->parent) {
292  StateId parent_id = rs->parent->state_id;
293  fst::ArcIterator<fst::Fst<A> > aiter(fst_, parent_id);
294  aiter.Seek(rs->select);
295  if (aiter.Value().ilabel != 0) // not backoff transition
296  return false;
297 
298  if (l == kNoLabel) { // super-final label
299  return fst_.Final(parent_id) != Weight::Zero();
300  } else {
301  matcher_.SetState(parent_id);
302  if (matcher_.Find(l)) return true;
303  }
304  }
305  return false;
306  } else {
307  if (forbidden_labels_.empty()) {
308  for (const RandState<A> *rs = &rstate; rs->parent != nullptr;
309  rs = rs->parent) {
310  StateId parent_id = rs->parent->state_id;
311  fst::ArcIterator<fst::Fst<A> > aiter(fst_, parent_id);
312  aiter.Seek(rs->select);
313  if (aiter.Value().ilabel != 0) // not backoff transition
314  break;
315 
316  for (aiter.Reset(); !aiter.Done(); aiter.Next()) {
317  Label l = aiter.Value().ilabel;
318  if (l != 0) forbidden_labels_.insert(l);
319  }
320 
321  if (fst_.Final(parent_id) != Weight::Zero())
322  forbidden_labels_.insert(kNoLabel);
323  }
324  }
325  return forbidden_labels_.count(l) > 0;
326  }
327 }
328 
329 } // namespace fst
330 
331 #endif // NGRAM_NGRAM_RANDGEN_H_
bool Sample(const RandState< A > &rstate)
ArcSampler(const ArcSampler< A, S > &sampler, const fst::Fst< A > *fst=0)
size_t operator()(const fst::Fst< A > &fst, StateId s, double total_prob, fst::CacheLogAccumulator< A > *accumulator) const
Definition: ngram-randgen.h:57
NGramArcSelector(int seed=time(nullptr)+getpid())
Definition: ngram-randgen.h:52
#define NGRAMERROR()
Definition: util.h:26
ArcSampler(const fst::Fst< A > &fst, const S &arc_selector, int max_length=INT_MAX)
Definition: ngram-randgen.h:89
std::pair< size_t, size_t > Value() const