16 #ifndef NLP_GRM2_SFST_RANDGEN_H_ 17 #define NLP_GRM2_SFST_RANDGEN_H_ 19 #include <sys/types.h> 34 #include <fst/accumulator.h> 35 #include <fst/arc-map.h> 37 #include <fst/arcsort.h> 38 #include <fst/determinize.h> 40 #include <fst/minimize.h> 41 #include <fst/mutable-fst.h> 43 #include <fst/relabel.h> 44 #include <fst/rmepsilon.h> 45 #include <fst/weight.h> 59 typedef typename A::Label
Label;
63 Label phi_label = fst::kNoLabel)
// Stores the seed and the phi (failure) label, and seeds the C library's
// rand(), which operator() draws from.
// NOTE(review): srand() reseeds process-global state, so constructing another
// selector (or any other srand() call) changes the draws of every selector.
64 : seed_(seed), phi_label_(phi_label) { srand(seed); }
68 double pre_failure_prob,
double failure_prob,
69 double numer_prob, ssize_t failure_pos,
70 fst::CacheLogAccumulator<A> *accumulator)
const {
// Selects the position of one outgoing arc of state `s` (or of its final
// weight) by drawing r uniformly from [0, 1) and mapping it through the
// accumulator's cumulative weights (LowerBound takes r in -log space).
// Arguments, as computed by ArcSampler::Sample():
//   pre_failure_prob: mass of the arcs preceding the failure arc (from
//                     PreFailureProb(); 0 when failure_pos < 1).
//   failure_prob:     probability given by the failure (phi) arc's own weight.
//   numer_prob:       mass given to the failure arc for sampling
//                     (Sample() sets it to 1 + failure_prob - total_prob).
//   failure_pos:      position of the failure arc, or -1 if there is none.
71 double r = rand()/(RAND_MAX + 1.0);
72 accumulator->SetState(s);
73 fst::ArcIterator<fst::Fst<A>> aiter(fst, s);
// No failure arc, or r lands before it: plain cumulative lookup.
75 if (failure_pos == -1 || r <= pre_failure_prob) {
77 return accumulator->LowerBound(-log(r), &aiter);
// r lands in the failure arc's sampling mass.
// NOTE(review): the body of this branch is not visible in this excerpt.
80 if (r <= pre_failure_prob + numer_prob) {
// Search only among arcs after the failure arc.
86 aiter.Seek(failure_pos + 1);
// The accumulator's cumulative weights contain the failure arc with its own
// probability failure_prob, while r was drawn with numer_prob in that slot;
// shift r so it lines up with the accumulator's cumulative sums.
89 r += failure_prob - numer_prob;
// Defensive guard: a negative shifted value would make -log(r) undefined,
// so fall back to the failure arc's position.
90 if (r < 0.0)
return failure_pos;
91 return accumulator->LowerBound(-log(r), &aiter);
// Returns the seed given to the constructor. ArcSampler uses it to seed its
// own std::mt19937 (rng_) for multinomial sampling.
94 int Seed()
const {
return seed_; }
101 fst::WeightConvert<Weight, fst::LogWeight> to_log_weight_;
116 typedef CacheLogAccumulator<A>
C;
// Builds a sampler over `fst` that picks arcs with `arc_selector`.
// `arc_selector` is stored by reference (arc_selector_), so it must outlive
// this sampler. `max_length` is compared against rstate.length in Sample().
118 ArcSampler(
const Fst<A> &
fst,
const S &arc_selector,
int max_length = INT_MAX)
120 arc_selector_(arc_selector),
121 max_length_(max_length),
122 phi_label_(arc_selector.
PhiLabel()),
123 accumulator_(std::make_unique<C>()),
// Input-label matcher, used to find the phi arc in Sample() and to look up
// labels at ancestor states in ForbiddenLabel().
124 matcher_(fst_, MATCH_INPUT) {
// Cumulative-weight cache used for arc lookup and partial sums.
125 accumulator_->Init(fst);
// Separate RNG for multinomial sampling, seeded from the selector's seed.
126 rng_.seed(arc_selector.
Seed());
// Copy constructor (its declaration line is not shown in this excerpt):
// copies `sampler`'s settings and optionally rebinds to a different `fst`.
130 : fst_(
fst ? *
fst : sampler.fst_),
131 arc_selector_(sampler.arc_selector_),
132 max_length_(sampler.max_length_),
133 phi_label_(sampler.phi_label_),
134 matcher_(fst_, MATCH_INPUT) {
// When a new `fst` is supplied (guard not visible in this excerpt), build a
// fresh accumulator over it...
136 accumulator_ = std::make_unique<C>();
137 accumulator_->Init(*
fst);
// ...otherwise copy the original sampler's accumulator.
// NOTE(review): no rng_.seed() call is visible in this constructor, unlike
// the primary one; confirm whether lines not shown seed rng_.
139 accumulator_ = std::make_unique<C>(*sampler.accumulator_);
// Draws rstate.nsamples samples from the outgoing arcs (and final weight) of
// rstate.state_id into sample_map_; read them back with Done/Value/Next.
143 bool Sample(
const RandState<A> &rstate);
// Iteration over the sampled (position, count) pairs stored in sample_map_.
// Positions at or beyond NumArcs(state) presumably denote the final weight.
145 bool Done()
const {
return sample_iter_ == sample_map_.end(); }
146 void Next() { ++sample_iter_; }
147 std::pair<size_t, size_t>
Value()
const {
return *sample_iter_; }
148 void Reset() { sample_iter_ = sample_map_.begin(); }
// This sampler never reports an error.
150 bool Error()
const {
return false; }
153 typedef std::mt19937 RNG;
// Returns the probability mass leaving state `s`: its final weight plus the
// accumulated weights of its arcs starting at position 0, converted from -log
// space with exp(-w). NOTE(review): the end bound of the Sum() call is on a
// line not shown here; presumably it covers all arcs.
155 double TotalProb(StateId s) {
157 ArcIterator< Fst<A> > aiter(fst_, s);
158 accumulator_->SetState(s);
159 Weight total_weight = accumulator_->Sum(fst_.Final(s), &aiter, 0,
161 return exp(-to_log_weight_(total_weight).Value());
// Returns the probability mass of the arcs of `s` preceding the failure arc.
// It is 0 when failure_pos < 1, i.e. there is no failure arc (-1) or it is
// the first arc. NOTE(review): the end bound of the Sum() call is on a line
// not shown; presumably it is failure_pos.
164 double PreFailureProb(StateId s, ssize_t failure_pos) {
167 if (failure_pos < 1)
return 0.0;
168 ArcIterator< Fst<A> > aiter(fst_, s);
169 accumulator_->SetState(s);
170 Weight cumul_weight = accumulator_->Sum(Weight::Zero(), &aiter, 0,
172 return exp(-to_log_weight_(cumul_weight).Value());
// Records position `p` in forbidden_positions_ and reports whether the label
// at `p` (kNoLabel for the final weight) is forbidden for `rstate`.
177 bool ForbiddenPosition(
size_t p,
const RandState<A> &rstate);
// Reports whether label `l` must not be sampled from `rstate`, based on the
// labels available at ancestor states reached through phi arcs.
181 bool ForbiddenLabel(Label l,
const RandState<A> &rstate);
// Samples all rstate.nsamples draws at once into sample_map_, using
// `fail_weight` in place of the phi arc's own weight.
183 void MultinomialSample(
const RandState<A> &rstate, Weight fail_weight);
// Held by reference: the selector must outlive this sampler.
186 const S &arc_selector_;
// Sampled arc position -> number of times it was drawn.
191 std::map<size_t, size_t> sample_map_;
// Iteration cursor over sample_map_ (see Done/Next/Value/Reset).
192 std::map<size_t, size_t>::const_iterator sample_iter_;
// Owned cumulative-weight cache over fst_.
193 std::unique_ptr<C> accumulator_;
// Scratch buffers for MultinomialSample(): candidate probabilities, their
// arc positions, and the sampled counts.
196 std::vector<double> pr_;
197 std::vector<unsigned int> pos_;
198 std::vector<unsigned int> n_;
// Conversions between the arc weight type and Log64Weight.
199 WeightConvert<Weight, Log64Weight> to_log_weight_;
200 WeightConvert<Log64Weight, Weight> to_weight_;
// Per-Sample() exclusion state; cleared at the start of Sample().
201 std::set<size_t> forbidden_positions_;
202 std::set<Label> forbidden_labels_;
// Input-label matcher over fst_.
203 ExplicitMatcher<SortedMatcher<Fst<A>>> matcher_;
// Fills sample_map_ with rstate.nsamples draws over the arcs and final weight
// of rstate.state_id, honoring failure (phi) arc semantics. Several lines of
// this definition are not visible in this excerpt; comments below mark where.
207 bool ArcSampler<A, sfst::SFstArcSelector<A> >::Sample(
208 const RandState<A> &rstate) {
// Reset per-call exclusion state.
210 forbidden_positions_.clear();
211 forbidden_labels_.clear();
// The final weight counts as one extra sampling outcome.
212 bool is_final = fst_.Final(rstate.state_id) != Weight::Zero();
213 size_t narcs_and_final = fst_.NumArcs(rstate.state_id) + is_final;
// No outcomes, or the path reached max_length_: handled on lines not shown.
215 if (narcs_and_final == 0 || rstate.length == max_length_) {
220 double failure_prob = 0.0;
221 double total_prob = 1.0;
222 double numer_prob = 0.0;
223 ssize_t failure_pos = -1;
// Locate the phi arc, if any, among this state's input labels.
224 matcher_.SetState(rstate.state_id);
226 if (phi_label_ != kNoLabel && matcher_.Find(phi_label_)) {
227 const A &arc = matcher_.Value();
228 Log64Weight failure_weight = to_log_weight_(arc.weight);
229 failure_prob = exp(-failure_weight.Value());
230 failure_pos = matcher_.GetMatcher()->Position();
231 total_prob = TotalProb(rstate.state_id);
// Mass left for the failure transition: 1 minus the mass of every other
// outcome (assuming total_prob includes the phi arc's own failure_prob).
233 numer_prob = 1.0 + failure_prob - total_prob;
// More samples than outcomes: draw all counts at once.
// NOTE(review): if rounding makes numer_prob <= 0, -log(numer_prob) is +inf
// or NaN; consider clamping before the conversion.
236 if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) {
237 Weight numer_weight = to_weight_(-log(numer_prob));
238 MultinomialSample(rstate, numer_weight);
// Otherwise draw samples one at a time with the arc selector.
243 double pre_failure_prob = PreFailureProb(rstate.state_id, failure_pos);
245 for (
size_t i = 0; i < rstate.nsamples; ++i) {
// Every outcome has been examined: handled on a line not shown.
248 if (forbidden_positions_.size() == narcs_and_final)
// Redraw until the selected position is not forbidden.
250 pos = arc_selector_(fst_, rstate.state_id, pre_failure_prob,
251 failure_prob, numer_prob, failure_pos,
253 }
while (ForbiddenPosition(pos, rstate));
// Returns whether position `pos` of rstate.state_id must be rejected because
// its label is forbidden (see ForbiddenLabel()). Every examined position is
// recorded in forbidden_positions_, which Sample() compares against the
// number of outcomes to detect exhaustion.
261 bool ArcSampler<A, sfst::SFstArcSelector<A> >::ForbiddenPosition(
262 size_t pos,
const RandState<A> &rstate) {
// Already-examined position: result handled on lines not shown.
263 if (forbidden_positions_.count(pos) > 0)
// Positions past the last arc stand for the final weight (kNoLabel).
266 Label label = kNoLabel;
267 if (pos < fst_.NumArcs(rstate.state_id)) {
268 ArcIterator<Fst<A>> aiter(fst_, rstate.state_id);
// NOTE(review): the line that positions aiter at `pos` is not shown here.
270 label = aiter.Value().ilabel;
272 bool forbidden_label = ForbiddenLabel(label, rstate);
274 forbidden_positions_.insert(pos);
275 return forbidden_label;
// Returns whether label `l` (kNoLabel = final weight) must not be sampled from
// `rstate`. The function walks up the chain of ancestor states reached
// through phi arcs; labels those ancestors could match directly are excluded.
// Several lines are missing from this excerpt.
279 bool ArcSampler<A, sfst::SFstArcSelector<A> >::ForbiddenLabel(
280 Label l,
const RandState<A> &rstate) {
// No phi label, the phi label itself, or epsilon: early return (value on a
// line not shown).
281 if (phi_label_ == kNoLabel || l == phi_label_ || l == 0)
// Many arcs relative to samples: query each ancestor with the matcher rather
// than enumerating all of its labels.
284 if (fst_.NumArcs(rstate.state_id) > rstate.nsamples) {
285 for (
const RandState<A> *rs = &rstate;
286 rs->parent !=
nullptr;
288 StateId parent_id = rs->parent->state_id;
289 ArcIterator < Fst<A> > aiter(fst_, parent_id);
// Inspect the arc the parent took to reach this state; presumably the walk
// stops at the first non-phi transition (handling not shown).
290 aiter.Seek(rs->select);
291 if (aiter.Value().ilabel != phi_label_)
295 if (fst_.Final(parent_id) != Weight::Zero())
298 matcher_.SetState(parent_id);
299 if (matcher_.Find(l))
// Otherwise build forbidden_labels_ once per Sample() call (Sample() clears
// it), then answer by set lookup.
305 if (forbidden_labels_.empty()) {
306 for (
const RandState<A> *rs = &rstate;
307 rs->parent !=
nullptr;
309 StateId parent_id = rs->parent->state_id;
310 ArcIterator < Fst<A> > aiter(fst_, parent_id);
311 aiter.Seek(rs->select);
312 if (aiter.Value().ilabel != phi_label_)
315 for (aiter.Reset(); !aiter.Done(); aiter.Next()) {
// NOTE(review): this `l` shadows the parameter `l`; the final return below
// uses the parameter. Consider renaming the loop variable.
316 Label l = aiter.Value().ilabel;
318 forbidden_labels_.insert(l);
// A final ancestor forbids the final weight (kNoLabel) here.
321 if (fst_.Final(parent_id) != Weight::Zero())
322 forbidden_labels_.insert(kNoLabel);
325 return forbidden_labels_.count(l) > 0;
// Draws all rstate.nsamples samples at once. It collects probabilities in
// pr_ for every non-forbidden arc (the phi arc weighted by `fail_weight`
// instead of its own weight) and for the final weight if allowed, then
// stores per-position counts in sample_map_. The lines recording positions in
// pos_ are not visible in this excerpt.
330 void ArcSampler<A, sfst::SFstArcSelector<A> >::MultinomialSample(
331 const RandState<A> &rstate,
Weight fail_weight) {
336 for (ArcIterator< Fst<A> > aiter(fst_, rstate.state_id);
338 aiter.Next(), ++pos) {
339 const A &arc = aiter.Value();
340 if (!ForbiddenLabel(arc.ilabel, rstate)) {
342 Weight weight = arc.ilabel == phi_label_ ? fail_weight : arc.weight;
343 pr_.push_back(exp(-to_log_weight_(weight).Value()));
// The final weight is one more candidate, keyed by kNoLabel for exclusion.
346 if (fst_.Final(rstate.state_id) != Weight::Zero()
347 && !ForbiddenLabel(kNoLabel, rstate)) {
349 pr_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value()));
// Sample counts with rng_ when nsamples is below the RNG result type's max;
// otherwise (else line not shown) use ceil(p * nsamples) as counts.
352 if (rstate.nsamples < std::numeric_limits<RNG::result_type>::max()) {
353 n_.resize(pr_.size());
354 OneMultinomialSample(pr_, rstate.nsamples, &n_, &rng_);
355 for (
size_t i = 0; i < n_.size(); ++i)
356 if (n_[i] != 0) sample_map_[pos_[i]] = n_[i];
358 for (
size_t i = 0; i < pr_.size(); ++i)
359 sample_map_[pos_[i]] = ceil(pr_[i] * rstate.nsamples);
// Body of RandMinimize(fst, phi_label): rewrites *fst in place. It converts
// weights to Log64, relabels a distinct phi label to epsilon (0) on both
// input and output sides, determinizes in the log semiring, converts back to
// A's weights and sorts arcs by input label. NOTE(review): lines between the
// relabel and Determinize steps, and after Determinize, are not visible here.
366 using Label =
typename A::Label;
368 WeightConvertMapper<A, Log64Arc> to_log_mapper;
369 WeightConvertMapper<Log64Arc, A> from_log_mapper;
370 VectorFst<Log64Arc> lfst, dfst;
371 ArcMap(*fst, &lfst, to_log_mapper);
// Treat failure transitions as epsilons for the log-semiring construction.
372 if (phi_label != kNoLabel && phi_label != 0) {
373 std::vector<std::pair<Label, Label>> relab = {{phi_label, 0}};
374 Relabel(&lfst, relab, relab);
377 Determinize(lfst, &dfst);
379 ArcMap(dfst, fst, from_log_mapper);
// Input-label sorting is required by SortedMatcher with MATCH_INPUT (as used
// by ArcSampler), presumably the intended consumer.
380 ArcSort(fst, ILabelCompare<A>());
385 #endif // NLP_GRM2_SFST_RANDGEN_H_
std::pair< size_t, size_t > Value() const
CacheLogAccumulator< A > C
void RandMinimize(MutableFst< A > *fst, typename A::Label phi_label)
size_t operator()(const fst::Fst< A > &fst, StateId s, double pre_failure_prob, double failure_prob, double numer_prob, ssize_t failure_pos, fst::CacheLogAccumulator< A > *accumulator) const
sfst::SFstArcSelector< A > S
ArcSampler(const Fst< A > &fst, const S &arc_selector, int max_length=INT_MAX)
SFstArcSelector(int seed=time(nullptr), Label phi_label=fst::kNoLabel)
ArcSampler(const ArcSampler< A, S > &sampler, const Fst< A > *fst=nullptr)