19 #ifndef NGRAM_NGRAM_RANDGEN_H_ 20 #define NGRAM_NGRAM_RANDGEN_H_ 22 #include <sys/types.h> 29 #include <fst/accumulator.h> 33 #include <gsl/gsl_randist.h> 34 #include <gsl/gsl_rng.h> 38 #include <fst/randgen.h> 58 fst::CacheLogAccumulator<A> *accumulator)
const {
// Body of the arc selector's call operator: picks one position at state s.
// NOTE(review): the start of the signature and the closing brace are not
// present in this listing.
//
// r is rand() scaled by (RAND_MAX + 1.0), so it lies in [0, 1).
59 double r = rand() / (RAND_MAX + 1.0);
// Shift the draw by (total_prob - 1.0); a non-positive z selects position 0.
62 double z = r + total_prob - 1.0;
63 if (z <= 0.0)
return 0;
// Otherwise pass -log(z) and an iterator over the arcs of s to the
// accumulator's LowerBound() and return whatever position it reports.
// NOTE(review): presumably the first arc position whose cumulative weight
// reaches -log(z) -- confirm against fst/accumulator.h.
64 fst::ArcIterator<fst::Fst<A> > aiter(fst, s);
65 return accumulator->LowerBound(-log(z), &aiter);
68 int Seed()
const {
return seed_; }
// Converter from the arc Weight type to fst::LogWeight.
72 fst::WeightConvert<Weight, fst::LogWeight> to_log_weight_;
// Label type taken from the arc type A.
86 typedef typename A::Label
Label;
// Short alias for CacheLogAccumulator<A>.
87 typedef CacheLogAccumulator<A>
C;
// Constructor fragment. NOTE(review): the start of the signature and the
// first member initializer are not present in this listing.
// max_length defaults to INT_MAX. The selector and max_length are stored, and
// matcher_ is built over fst_ with MATCH_INPUT.
90 int max_length = INT_MAX)
92 arc_selector_(arc_selector),
93 max_length_(max_length),
94 matcher_(fst_, MATCH_INPUT) {
// Report through NGRAMERROR() when Properties(kILabelSorted, true) is not set
// on fst_.
96 if (!fst_.Properties(kILabelSorted,
true))
97 NGRAMERROR() <<
"ArcSampler: is not input-label sorted";
// Create a new accumulator and Init() it with the FST.
98 accumulator_.reset(
new C());
99 accumulator_->Init(fst);
// Allocate a GSL generator of type gsl_rng_taus and seed it with the
// selector's Seed().
101 rng_ = gsl_rng_alloc(gsl_rng_taus);
102 gsl_rng_set(rng_, arc_selector.
Seed());
// Copy-constructor fragment. NOTE(review): the signature line and the lines
// that choose between the two accumulator set-ups are not present in this
// listing.
// fst_ binds to *fst when fst is non-null, otherwise to the source sampler's
// FST; the selector and max length are taken from the source sampler.
107 : fst_(
fst ? *
fst : sampler.fst_),
108 arc_selector_(sampler.arc_selector_),
109 max_length_(sampler.max_length_),
110 matcher_(fst_, MATCH_INPUT) {
// Set-up 1: a new accumulator, Init()-ed with *fst.
112 accumulator_.reset(
new C());
113 accumulator_->Init(*
fst);
// Set-up 2: an accumulator copy-constructed from the source sampler's one.
115 accumulator_.reset(
new C(*sampler.accumulator_));
// Sampling entry point for the state described by rstate.
// NOTE(review): fragment -- several body lines (including the returns and the
// `do` matching the trailing `while`) are not present in this listing.
125 bool Sample(
const RandState<A> &rstate) {
// Empty forbidden_labels_ before working on this state.
127 forbidden_labels_.clear();
// Condition: the state has no arcs and a Zero final weight, or the path
// length equals max_length_.
129 if ((fst_.NumArcs(rstate.state_id) == 0 &&
130 fst_.Final(rstate.state_id) == Weight::Zero()) ||
131 rstate.length == max_length_) {
136 double total_prob = TotalProb(rstate.state_id);
// When nsamples exceeds NumArcs + 1, compute the backoff weights and call
// MultinomialSample with the numerator weight.
139 if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
140 Weight numer_weight, denom_weight;
141 BackoffWeight(rstate.state_id, total_prob, &numer_weight, &denom_weight);
142 MultinomialSample(rstate, numer_weight);
// Per-sample loop: one iteration for each of the nsamples requested.
148 fst::ArcIterator<fst::Fst<A> > aiter(fst_, rstate.state_id);
150 for (
size_t i = 0; i < rstate.nsamples; ++i) {
152 Label label = kNoLabel;
// Ask the arc selector for a position at this state.
154 pos = arc_selector_(fst_, rstate.state_id, total_prob,
// A position inside the arc range yields an arc input label.
156 if (pos < fst_.NumArcs(rstate.state_id)) {
158 label = aiter.Value().ilabel;
// The enclosing loop continues while ForbiddenLabel(label, rstate) is true.
162 }
while (ForbiddenLabel(label, rstate));
169 bool Done()
const {
return sample_iter_ == sample_map_.end(); }
170 void Next() { ++sample_iter_; }
171 std::pair<size_t, size_t>
Value()
const {
return *sample_iter_; }
172 void Reset() { sample_iter_ = sample_map_.begin(); }
173 bool Error()
const {
return false; }
// Returns exp(-v), where v is the Log64Weight value of the accumulator's Sum
// at state s.
// NOTE(review): the closing brace is not present in this listing.
176 double TotalProb(StateId s) {
178 fst::ArcIterator<fst::Fst<A> > aiter(fst_, s);
// Point the accumulator at s before summing.
179 accumulator_->SetState(s);
// Sum is called with Final(s) as its first argument and the arc range
// 0 .. NumArcs(s).
180 Weight total_weight =
181 accumulator_->Sum(fst_.Final(s), &aiter, 0, fst_.NumArcs(s));
182 return exp(-to_log_weight_(total_weight).Value());
// Computes two weights for state s from total_prob and writes them through
// numer_weight and denom_weight.
185 void BackoffWeight(StateId s,
double total_prob, Weight *numer_weight,
186 Weight *denom_weight);
// Fills sample_map_ for rstate; fail_weight is the weight used for an arc
// whose input label is 0.
189 void MultinomialSample(
const RandState<A> &rstate, Weight fail_weight);
// Returns whether label l is forbidden for rstate; label 0 never is.
192 bool ForbiddenLabel(Label l,
const RandState<A> &rstate);
// FST being sampled, held by const reference.
194 const fst::Fst<A> &fst_;
// Arc selector, held by const reference.
195 const S &arc_selector_;
// Sample results: keys come from pos_, values are sample counts.
199 std::map<size_t, size_t> sample_map_;
// Cursor over sample_map_ used by Done/Next/Value/Reset.
200 std::map<size_t, size_t>::const_iterator sample_iter_;
// Owned log accumulator.
201 std::unique_ptr<C> accumulator_;
// Working vectors for MultinomialSample: probabilities, the sample_map_ key
// for each probability, and the counts drawn for each probability.
205 std::vector<double> pr_;
206 std::vector<unsigned int> pos_;
207 std::vector<unsigned int> n_;
// Converters between Weight and fst::Log64Weight, one per direction.
210 WeightConvert<fst::Log64Weight, Weight> to_weight_;
211 WeightConvert<Weight, fst::Log64Weight> to_log_weight_;
// Matcher over fst_, constructed with MATCH_INPUT.
214 Matcher<fst::Fst<A> > matcher_;
// Definition of BackoffWeight for the NGramArcSelector specialization.
// NOTE(review): fragment -- the template header, the matcher lookup call and
// the branch/brace lines are not present in this listing.
//
// With backoff = exp(-w) for the weight w of an arc obtained through matcher_
// at state s:
//   backoff == 0 : both outputs are Weight::Zero()
//   otherwise    : *numer_weight = to_weight_(-log(1 + backoff - total))
//                  *denom_weight = to_weight_(-log(numer / backoff))
// (presumably the second case is the else-branch; confirm against the full
// source).
220 void ArcSampler<A, ngram::NGramArcSelector<A> >::BackoffWeight(
221 StateId s,
double total, Weight *numer_weight, Weight *denom_weight) {
223 double backoff = 0.0;
224 matcher_.SetState(s);
// Take the weight of a matched arc whose input label is not kNoLabel.
226 for (; !matcher_.Done(); matcher_.Next()) {
227 const A &arc = matcher_.Value();
228 if (arc.ilabel != kNoLabel) {
229 backoff = exp(-to_log_weight_(arc.weight).Value());
234 if (backoff == 0.0) {
235 *numer_weight = Weight::Zero();
236 *denom_weight = Weight::Zero();
241 double numer = 1.0 + backoff - total;
242 *numer_weight = to_weight_(-log(numer));
245 double denom = numer / backoff;
246 *denom_weight = to_weight_(-log(denom));
// Definition of MultinomialSample for the NGramArcSelector specialization:
// collects a probability per allowed outcome in pr_ and records sample counts
// in sample_map_.
// NOTE(review): fragment -- the template header, the set-up of pos/pos_ and
// the branch/brace lines are not present in this listing.
251 void ArcSampler<A, ngram::NGramArcSelector<A> >::MultinomialSample(
252 const RandState<A> &rstate, Weight fail_weight) {
// Walk the arcs of the state; pos advances together with the iterator.
257 for (fst::ArcIterator<fst::Fst<A> > aiter(fst_, rstate.state_id);
258 !aiter.Done(); aiter.Next(), ++pos) {
259 const A &arc = aiter.Value();
// Only arcs whose label is not forbidden contribute.
260 if (!ForbiddenLabel(arc.ilabel, rstate)) {
// An arc with input label 0 contributes fail_weight instead of its own weight.
262 Weight weight = arc.ilabel == 0 ? fail_weight : arc.weight;
263 pr_.push_back(exp(-to_log_weight_(weight).Value()));
// A non-Zero final weight also contributes, unless kNoLabel is forbidden.
266 if (fst_.Final(rstate.state_id) != Weight::Zero() &&
267 !ForbiddenLabel(kNoLabel, rstate)) {
269 pr_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
// nsamples below UINT_MAX: draw counts with gsl_ran_multinomial and store the
// non-zero ones under their pos_ key.
272 if (rstate.nsamples < UINT_MAX) {
273 n_.resize(pr_.size());
274 gsl_ran_multinomial(rng_, pr_.size(), rstate.nsamples, &(pr_[0]), &(n_[0]));
275 for (
size_t i = 0; i < n_.size(); ++i)
276 if (n_[i] != 0) sample_map_[pos_[i]] = n_[i];
// NOTE(review): presumably the else-branch -- store ceil(pr * nsamples) for
// every outcome instead of drawing.
278 for (
size_t i = 0; i < pr_.size(); ++i)
279 sample_map_[pos_[i]] = ceil(pr_[i] * rstate.nsamples);
// Definition of ForbiddenLabel for the NGramArcSelector specialization.
// NOTE(review): fragment -- the template header, the loop increments and
// several branch/brace lines are not present in this listing, so the exact
// nesting of the statements below cannot be read off here.
285 bool ArcSampler<A, ngram::NGramArcSelector<A> >::ForbiddenLabel(
286 Label l,
const RandState<A> &rstate) {
// Label 0 is never forbidden.
287 if (l == 0)
return false;
// First strategy, used when the state has more arcs than nsamples: walk the
// chain of rstate parents and query matcher_ for l at each parent state.
289 if (fst_.NumArcs(rstate.state_id) > rstate.nsamples) {
290 for (
const RandState<A> *rs = &rstate; rs->parent !=
nullptr;
292 StateId parent_id = rs->parent->state_id;
293 fst::ArcIterator<fst::Fst<A> > aiter(fst_, parent_id);
// Inspect the arc at position rs->select of the parent state.
294 aiter.Seek(rs->select);
295 if (aiter.Value().ilabel != 0)
299 return fst_.Final(parent_id) != Weight::Zero();
301 matcher_.SetState(parent_id);
302 if (matcher_.Find(l))
return true;
// Second strategy: fill forbidden_labels_ only while it is empty, using the
// non-zero arc labels of parent states and kNoLabel for a parent with a
// non-Zero final weight; then answer by set lookup.
307 if (forbidden_labels_.empty()) {
308 for (
const RandState<A> *rs = &rstate; rs->parent !=
nullptr;
310 StateId parent_id = rs->parent->state_id;
311 fst::ArcIterator<fst::Fst<A> > aiter(fst_, parent_id);
312 aiter.Seek(rs->select);
313 if (aiter.Value().ilabel != 0)
316 for (aiter.Reset(); !aiter.Done(); aiter.Next()) {
317 Label l = aiter.Value().ilabel;
318 if (l != 0) forbidden_labels_.insert(l);
321 if (fst_.Final(parent_id) != Weight::Zero())
322 forbidden_labels_.insert(kNoLabel);
325 return forbidden_labels_.count(l) > 0;
331 #endif // NGRAM_NGRAM_RANDGEN_H_
CacheLogAccumulator< A > C
bool Sample(const RandState< A > &rstate)
ArcSampler(const ArcSampler< A, S > &sampler, const fst::Fst< A > *fst=0)
size_t operator()(const fst::Fst< A > &fst, StateId s, double total_prob, fst::CacheLogAccumulator< A > *accumulator) const
NGramArcSelector(int seed=time(nullptr)+getpid())
ngram::NGramArcSelector< A > S
ArcSampler(const fst::Fst< A > &fst, const S &arc_selector, int max_length=INT_MAX)
std::pair< size_t, size_t > Value() const