16 #ifndef NLP_GRM2_SFST_PERPLEXITY_H_ 17 #define NLP_GRM2_SFST_PERPLEXITY_H_ 25 #include <fst/expectation-weight.h> 26 #include <fst/signed-log-weight.h> 41 inline Real64Weight
Times(SignedLog64Weight w1, Real64Weight w2) {
// Mixed product of a signed-log weight and a real weight, returned as a
// Real64Weight. w1 is converted to the linear domain as
// Value1() * exp(-Value2()) and multiplied by w2's value.
42 using Limits = fst::FloatLimits<double>;
// When w1 is the signed-log Zero and w2 is +infinity, the result is
// short-circuited to Real64Weight::Zero() rather than computed.
// NOTE(review): presumably this avoids the IEEE-754 NaN of 0 * inf (assuming
// SignedLog64Weight::Zero() converts to a linear 0 below — confirm); a
// -infinity w2 is not special-cased and would still produce NaN.
43 if (w1 == SignedLog64Weight::Zero() && w2.Value() == Limits::PosInfinity())
44 return Real64Weight::Zero();
45 double s1 = w1.Value1().Value();
46 double l1 = w1.Value2().Value();
// s1 is used as a sign multiplier and l1 as a -log magnitude.
47 double p1 = s1 * exp(-l1);
48 double e = w2.Value();
49 return Real64Weight(p1 * e);
// Mixed product of a real weight and a signed-log weight (argument order
// reversed relative to the overload above), returned as a Real64Weight.
// w2 is converted to the linear domain as Value1() * exp(-Value2()) and
// multiplied by w1's value.
52 inline Real64Weight
Times(Real64Weight w1, SignedLog64Weight w2) {
53 using Limits = fst::FloatLimits<double>;
// When w2 is the signed-log Zero and w1 is +infinity, the result is
// short-circuited to Real64Weight::Zero() rather than computed.
// NOTE(review): presumably this avoids the IEEE-754 NaN of inf * 0; a
// -infinity w1 is not special-cased.
54 if (w2 == SignedLog64Weight::Zero() && w1.Value() == Limits::PosInfinity())
55 return Real64Weight::Zero();
56 double e = w1.Value();
57 double s2 = w2.Value1().Value();
58 double l2 = w2.Value2().Value();
// s2 is used as a sign multiplier and l2 as a -log magnitude.
59 double p2 = s2 * exp(-l2);
60 return Real64Weight(e * p2);
64 template <
class Weight>
81 if (w == Weight::Zero())
82 return EWeight::Zero();
85 RWeight rw = slw.Value2().Value();
90 case CROSS_ENTROPY_SOURCE:
91 return EWeight(slw, RWeight::Zero());
93 case CROSS_ENTROPY_TARGET:
94 return EWeight(SLWeight::One(), rw);
100 fst::WeightConvert<Weight, SLWeight> to_sl_;
110 return w.Value1().Value2();
124 return EWeight(slw, RWeight::Zero());
130 const SignedLog64Weight slz = SignedLog64Weight::Zero();
131 const Real64Weight rlz = Real64Weight::Zero();
133 Minus(rlz, w2.Value2())));
138 using SLWeight = fst::SignedLog64Weight;
139 using RWeight = fst::Real64Weight;
140 return sfst::Less(w.Value1(), SLWeight::Zero()) ||
155 return slw_equal_(w1.Value1(), w2.Value1()) &&
156 rw_equal_(w1.Value2(), w2.Value2());
161 WeightApproxEqual rw_equal_;
188 using EArc = fst::ExpectationArc<SLArc, RWeight>;
192 using WCM = fst::WeightConvertMapper<Arc, EArc, WT>;
193 using WCM1 = fst::WeightConvertMapper<EArc, fst::Log64Arc>;
197 Label phi_label = fst::kNoLabel,
198 Label unknown_label = fst::kNoLabel,
199 float delta = fst::kDelta,
200 float entropy_delta = kEntropyDelta)
201 : phi_label_(phi_label),
202 unknown_label_(unknown_label),
204 entropy_delta_(entropy_delta),
214 Label unknown_label = fst::kNoLabel,
215 float delta = fst::kDelta,
216 float entropy_delta = kEntropyDelta)
217 : phi_label_(phi_label),
218 unknown_label_(unknown_label),
220 entropy_delta_(entropy_delta),
// Sets the target model that subsequent Apply() calls intersect sources with.
// The definition logs an error if the target has no start state or is not a
// normalized stochastic FST; otherwise it maps the target into qfst_ using
// CROSS_ENTROPY_TARGET weight conversion, with extra handling when
// unknown_label_ is set. NOTE(review): the definition checks
// `ifst.Start() == f::kNoLabel`; the intended sentinel is presumably
// f::kNoStateId. State after an error path is not visible here — confirm.
228 void SetTarget(
const fst::Fst<Arc> &ifst);
// Scores one source FST and folds the result into this object's running
// totals (entropy_, oov_count_, state_count_). With no target set
// (IsSelfEntropy()), the prepared source itself is used; otherwise the
// prepared source is intersected with the target qfst_ before the
// shortest-distance computation. The entropy is accumulated only when it is a
// member of the weight set and non-Zero.
// NOTE(review): the meaning of the returned bool is not visible in this
// chunk — confirm against the definition.
231 bool Apply(
const fst::Fst<Arc> &
fst);
235 double te = GetTotalEntropy();
236 double sc = GetSourceCount();
242 double te = GetTotalEntropy();
243 double sc = GetTotalStateCount();
256 return sent_count_ - GetSourceCount();
266 state_count_.Reset();
273 void FindVocabSet(
const fst::Fst<Arc> &
sfst) {
277 for (f::StateIterator<f::Fst<Arc>> siter(sfst);
281 for (f::ArcIterator<f::Fst<Arc>> aiter(sfst, s);
284 const Arc &arc = aiter.Value();
285 if (arc.ilabel != phi_label_ && arc.ilabel != 0)
286 vocab_.insert(arc.ilabel);
// Converts source `ifst` into expectation-weighted `ofst`. It uses ENTROPY
// weight conversion when no target is set and CROSS_ENTROPY_SOURCE otherwise.
// When scoring against a target with unknown_label_ set, non-epsilon labels
// that equal unknown_label_ or are absent from vocab_ are relabeled (input and
// output) to unknown_label_. The definition logs an error if the source is not
// a normalized stochastic FST.
293 void PrepareSource(
const fst::Fst<Arc> &ifst,
294 fst::MutableFst<EArc> *ofst);
// Returns true when no target has been set, i.e. qfst_ has no start state.
// In that case Apply() scores each source on its own instead of intersecting
// it with a target.
297 bool IsSelfEntropy()
const {
299 return qfst_.Start() == f::kNoStateId;
// Returns the real-valued (second) component of the entropy_ sum, i.e. the
// total (cross-)entropy accumulated by Apply() across all sources.
304 double GetTotalEntropy()
const {
305 return entropy_.Sum().Value2().Value();
// Returns the signed-log (first) component of the entropy_ sum, converted to
// the linear domain as sign * exp(-magnitude).
// NOTE(review): other members use this as the number of sources counted (it
// is subtracted from sent_count_) — confirm that interpretation.
313 double GetSourceCount()
const {
314 double sign_count = entropy_.Sum().Value1().Value1().Value();
315 double mag_count = entropy_.Sum().Value1().Value2().Value();
316 return sign_count * exp(-mag_count);
// Returns the state_count_ sum converted to the linear domain as
// sign * exp(-magnitude). Apply() fills state_count_ with the first
// (signed-log) component of every state's shortest distance.
323 double GetTotalStateCount()
const {
324 double sign_count = state_count_.Sum().Value1().Value();
325 double mag_count = state_count_.Sum().Value2().Value();
326 return sign_count * exp(-mag_count);
330 Label unknown_label_;
332 float entropy_delta_;
334 fst::VectorFst<EArc> qfst_;
335 fst::Adder<EWeight> entropy_;
336 fst::Adder<SLWeight> state_count_;
339 size_t fst_oov_count_;
340 std::unordered_set<Label> vocab_;
350 if (ifst.Start() == f::kNoLabel) {
351 LOG(ERROR) <<
"Perplexity: target FST has no states";
357 LOG(ERROR) <<
"Perplexity: target is not a normalized stochastic FST";
362 WT to_e(WT::CROSS_ENTROPY_TARGET);
364 f::ArcMap(ifst, &qfst_, wc_mapper);
366 if (unknown_label_ != f::kNoLabel)
376 fst::MutableFst<EArc> *ofst) {
380 LOG(ERROR) <<
"Perplexity: source (" << sent_count_
381 <<
") is not a normalized stochastic FST";
387 (IsSelfEntropy() ? WT::ENTROPY : WT::CROSS_ENTROPY_SOURCE);
391 f::ArcMap(ifst, ofst, wc_mapper);
394 if (unknown_label_ == f::kNoLabel || IsSelfEntropy())
397 for (
StateId s = 0; s < ofst->NumStates(); ++s) {
398 for (f::MutableArcIterator<f::MutableFst<EArc> > aiter(ofst, s);
401 EArc arc = aiter.Value();
402 if (arc.ilabel != 0) {
403 if (arc.ilabel == unknown_label_ || vocab_.count(arc.ilabel) == 0) {
404 arc.ilabel = arc.olabel = unknown_label_;
419 f::VectorFst<EArc> pfst, plogq_fst;
420 if (IsSelfEntropy()) {
421 PrepareSource(fst, &plogq_fst);
423 PrepareSource(fst, &pfst);
424 Intersect(pfst, qfst_, &plogq_fst, phi_label_,
true);
428 std::vector<EWeight> distance;
429 EWeight entropy = ShortestDistance<EArc, EArc, WEq>(
430 plogq_fst, &distance, phi_label_,
false, entropy_delta_);
435 distance.resize(plogq_fst.NumStates(), EWeight::Zero());
437 if (entropy.Member() && entropy != EWeight::Zero()) {
438 entropy_.Add(entropy);
439 oov_count_ += fst_oov_count_;
445 for (
auto w : distance)
446 state_count_.Add(w.Value1());
453 #endif // NLP_GRM2_SFST_PERPLEXITY_H_
fst::Entropy64Weight EWeight
SignedLog64Weight SLWeight
bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
fst::Real64Weight RWeight
size_t NumSources() const
Entropy64WeightApproxEqual(float delta)
double GetPerplexity() const
void DiffStateWeights(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *weights, typename Arc::Label phi_label, bool fail_arc)
Perplexity(Label phi_label=fst::kNoLabel, Label unknown_label=fst::kNoLabel, float delta=fst::kDelta, float entropy_delta=kEntropyDelta)
fst::WeightConvertMapper< Arc, EArc, WT > WCM
SignedLog64Weight SLWeight
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
bool Intersect(const fst::Fst< Arc > &ifst1, const fst::Fst< Arc > &ifst2, fst::MutableFst< Arc > *ofst, typename Arc::Label phi_label=fst::kNoLabel, bool trim=true, TrimType trim_type=TRIM_NEEDED_FINAL)
Perplexity(const fst::Fst< Arc > &fst, Label phi_label=fst::kNoLabel, Label unknown_label=fst::kNoLabel, float delta=fst::kDelta, float entropy_delta=kEntropyDelta)
typename Arc::StateId StateId
fst::SignedLog64Weight SLWeight
EWeight operator()(const LWeight &w) const
typename Arc::Label Label
fst::SignedLog64Arc SLArc
bool IsNormalized(const fst::Fst< Arc > &fst, typename Arc::Label phi_label=fst::kNoLabel, float delta=fst::kDelta)
double GetEntropy() const
bool operator()(const EWeight &w1, const EWeight &w2) const
ExpectationWeight< SignedLog64Weight, Real64Weight > Entropy64Weight
constexpr float kEntropyDelta
LWeight operator()(const EWeight &w) const
bool IsNegative(Entropy64Weight w)
fst::WeightConvertMapper< EArc, fst::Log64Arc > WCM1
typename Arc::Weight Weight
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
fst::ExpectationArc< SLArc, RWeight > EArc