16 #ifndef NLP_GRM2_SFST_STATE_WEIGHTS_H_ 17 #define NLP_GRM2_SFST_STATE_WEIGHTS_H_ 24 #include <fst/float-weight.h> 26 #include <fst/matcher.h> 27 #include <fst/signed-log-weight.h> 28 #include <fst/weight.h> 42 template <
class Weight>
// Binary predicate over Weight used for approximate comparison.
//
// NOTE(review): this listing is incomplete -- the class header, the first
// line of the constructor and the delta_ member declaration are not visible
// here, so only the visible pieces are described.
//
// Constructor argument approx_zero defaults to Weight::NoWeight().
46 Weight approx_zero = Weight::NoWeight())
// When the supplied threshold is a valid weight (Member()), it is converted
// to Log64Weight and stored; the fallback used otherwise is on a line that
// is not visible here (presumably a file-level default -- TODO confirm).
48 approx_zero_(approx_zero.Member() ? to_log_(approx_zero) :
// Returns true when w1 and w2 are ApproxEqual within delta_, or when both,
// after conversion to Log64Weight, satisfy Less(., approx_zero_).
// NOTE(review): Less() is defined elsewhere; it presumably means "smaller
// than the near-zero threshold" -- confirm its ordering before relying on it.
51 bool operator()(
const Weight &w1,
const Weight &w2)
const {
52 return fst::ApproxEqual(w1, w2, delta_) ||
53 (
Less(to_log_(w1), approx_zero_) &&
54 Less(to_log_(w2), approx_zero_));
// Converter from Weight to Log64Weight, used for the threshold comparison.
58 fst::WeightConvert<Weight, fst::Log64Weight> to_log_;
// Threshold against which both operands are tested in operator().
60 fst::Log64Weight approx_zero_;
// Variant of the approximate-equality predicate for signed log weights.
//
// NOTE(review): the template/class header, constructor signature and
// operator() signature lines are not visible in this listing.
67 using Weight = fst::SignedLogWeightTpl<T>;
// Constructor argument approx_zero defaults to Weight::NoWeight().
69 Weight approx_zero = Weight::NoWeight())
// The threshold is kept as a signed log weight (no conversion); the fallback
// for a non-Member() argument is on a line that is not visible here.
71 approx_zero_(approx_zero.Member() ? approx_zero :
// True when the operands are ApproxEqual within delta_, or when the Value2()
// component of each is Less than the threshold's Value2(). Only Value2() is
// consulted in the threshold test.
// NOTE(review): Value2() is presumably the log-magnitude part of the signed
// weight (sign being Value1()) -- confirm against the OpenFst headers.
75 return fst::ApproxEqual(w1, w2, delta_) ||
76 (
Less(w1.Value2(), approx_zero_.Value2()) &&
77 Less(w2.Value2(), approx_zero_.Value2()));
// Normalizes a vector of weights in place: every element is divided by the
// sum of all elements, with the sum accumulated in Log64Weight.
//
// NOTE(review): the signature line and the definition of the alias `f` are
// not visible in this listing.
88 template <
class Weight>
// Converters between Weight and the Log64Weight used for accumulation.
91 f::WeightConvert<f::Log64Weight, Weight> from_log;
92 f::WeightConvert<Weight, f::Log64Weight> to_log;
// First pass: accumulate the total of all weights.
94 f::Adder<f::Log64Weight> sum;
95 for (
size_t i = 0; i < weights->size(); ++i)
96 sum.Add(to_log((*weights)[i]));
// Second pass: divide each weight by the total.
// NOTE(review): from_log(sum.Sum()) does not depend on i and could be
// computed once before the loop. Behaviour when the total is Zero() (e.g.
// all-zero input) depends on Divide for this Weight type -- TODO confirm.
97 for (
size_t i = 0; i < weights->size(); ++i)
98 (*weights)[i] = Divide((*weights)[i], from_log(sum.Sum()));
// Compares two weight vectors element-wise with an approximate-equality
// predicate. Returns false immediately when the sizes differ.
//
// delta defaults to fst::kDelta and approx_zero to Weight::NoWeight().
//
// NOTE(review): the first signature line and the line that constructs
// `aepred` are not visible in this listing; aepred is presumably an
// ApproxEqualPred<Weight> built from delta and approx_zero -- TODO confirm.
103 template <
class Weight>
105 const std::vector<Weight> &weights2,
106 float delta = fst::kDelta,
107 Weight approx_zero = Weight::NoWeight()) {
// The size check is required: std::equal with a single end iterator would
// otherwise read past the end of a shorter weights2.
109 if (weights1.size() != weights2.size())
return false;
110 return std::equal(weights1.begin(), weights1.end(),
111 weights2.begin(), aepred);
// Writes a weight vector to an output stream, one entry per line, in the
// form "<index>\t<weight>\n".
//
// NOTE(review): the first signature line (the stream parameter `strm`) and
// one line preceding the loop are not visible in this listing.
115 template <
class Weight>
117 const std::vector<Weight> &weights) {
// The index written is the element's position in the vector.
119 for (
size_t s = 0; s < weights.size(); ++s)
120 strm << s <<
"\t" << weights[s] <<
"\n";
// Writes a weight vector to the named file, or to standard output when
// `filename` is empty.
//
// NOTE(review): several lines of this function are not visible in this
// listing (first signature line, the declaration of `ostrm`, the
// stream-state checks and the return statements); only the visible lines are
// kept below.
124 template <
class Weight>
126 const std::vector<Weight> &weights) {
// A file is opened only when a name was given.
128 if (!filename.empty()) {
129 ostrm.open(filename);
131 LOG(ERROR) <<
"WriteWeights: Can't open file: " << filename;
// Fall back to std::cout when no file was opened.
135 std::ostream &strm = ostrm.is_open() ? ostrm : std::cout;
// The message prefix now names this function (it previously said
// "WritePotentials"), matching the open-failure message above.
138 LOG(ERROR) <<
"WriteWeights: Write failed: " 139 << (filename.empty() ?
"standard output" : filename);
// Propagates state weights forward along phi_label transitions: visiting
// states in the order returned by PhiTopOrder, the weight of each state that
// has a phi_label transition is added (in Log64Weight) into the weight stored
// for that transition's destination state. `weights` is updated in place.
//
// NOTE(review): the head of the signature, the bodies of both guards, the
// statement between the `s = top_order[i]` line and `matcher.Find` and the
// false branch of the fail_arc ternary are not visible in this listing.
158 std::vector<typename Arc::Weight> *weights,
159 typename Arc::Label phi_label,
bool fail_arc) {
161 using StateId =
typename Arc::StateId;
162 using Weight =
typename Arc::Weight;
163 using Matr = f::ExplicitMatcher<f::Matcher<f::Fst<Arc>>>;
// Converters between the arc Weight and the Log64Weight used for summing.
165 f::WeightConvert<f::Log64Weight, Weight> from_log;
166 f::WeightConvert<Weight, f::Log64Weight> to_log;
// Guard for "no phi label"; its action is not visible here.
168 if (phi_label == f::kNoLabel)
171 std::vector<StateId> top_order;
172 bool acyclic =
PhiTopOrder(fst, phi_label, &top_order);
// Error reported for a phi-cyclic FST (the enclosing condition, presumably
// on `acyclic`, is not visible).
174 FSTERROR() <<
"SumStateWeights: FST not canonical (phi-cyclic)";
177 Matr matcher(fst, f::MATCH_INPUT);
// Forward pass over top_order.
// NOTE(review): `i < top_order.size()` compares StateId with size_t; if
// StateId is a signed type this is a signed/unsigned comparison.
178 for (StateId i = 0; i < top_order.size(); ++i) {
179 StateId s = top_order[i];
181 if (matcher.Find(phi_label)) {
182 const Arc &arc = matcher.Value();
// w1: weight currently stored for the destination state.
183 f::Log64Weight w1 = to_log((*weights)[arc.nextstate]);
// w2: contribution from s; with fail_arc set, the arc weight is
// multiplied in (the alternative branch is not visible here).
184 f::Log64Weight w2 = to_log(fail_arc ?
Times((*weights)[s], arc.weight) :
186 (*weights)[arc.nextstate] = from_log(Plus(w1, w2));
// Counterpart of the summing pass: visiting states in reverse PhiTopOrder
// order, the weight of each state that has a phi_label transition is
// subtracted (in Log64Weight) from the weight stored for that transition's
// destination state. `weights` is updated in place.
//
// NOTE(review): the head of the signature, the bodies of both guards, the
// statement between the `s = top_order[i]` line and `matcher.Find` and the
// false branch of the fail_arc ternary are not visible in this listing.
199 std::vector<typename Arc::Weight> *weights,
200 typename Arc::Label phi_label,
bool fail_arc) {
202 using StateId =
typename Arc::StateId;
203 using Weight =
typename Arc::Weight;
204 using Matr = f::ExplicitMatcher<f::Matcher<f::Fst<Arc>>>;
// Converters between the arc Weight and the Log64Weight used for arithmetic.
206 f::WeightConvert<f::Log64Weight, Weight> from_log;
207 f::WeightConvert<Weight, f::Log64Weight> to_log;
// Guard for "no phi label"; its action is not visible here.
209 if (phi_label == f::kNoLabel)
212 std::vector<StateId> top_order;
213 bool acyclic =
PhiTopOrder(fst, phi_label, &top_order);
// Error reported for a phi-cyclic FST (the enclosing condition, presumably
// on `acyclic`, is not visible).
215 FSTERROR() <<
"DiffStateWeights: FST not canonical (phi-cyclic)";
218 Matr matcher(fst, f::MATCH_INPUT);
// Reverse pass over top_order.
// NOTE(review): the `i >= 0` exit condition only terminates when StateId is
// a signed type; the size_t -> StateId conversion also narrows.
219 for (StateId i = top_order.size() - 1 ; i >= 0; --i) {
220 StateId s = top_order[i];
222 if (matcher.Find(phi_label)) {
223 const Arc &arc = matcher.Value();
// w1: weight currently stored for the destination state.
224 f::Log64Weight w1 = to_log((*weights)[arc.nextstate]);
// w2: amount to remove; with fail_arc set, the arc weight is
// multiplied in (the alternative branch is not visible here).
225 f::Log64Weight w2 = to_log(fail_arc ?
Times((*weights)[s], arc.weight) :
// Subtract only when Less(w2, w1) holds; otherwise store Weight::Zero()
// rather than calling Minus.
227 (*weights)[arc.nextstate] =
Less(w2, w1) ?
228 from_log(
Minus(w1, w2)) : Weight::Zero();
// For every state s listed by PhiTopOrder, stores in (*weights)[s] the sum,
// accumulated in Log64Weight, of the state's final weight and the weights of
// all arcs leaving s. phi_label is passed only to PhiTopOrder in the visible
// code.
//
// NOTE(review): the `template` line, the definition of the alias `f`, the
// acyclicity guard and the closing braces are not visible in this listing.
237 void SumStates(
const fst::Fst<Arc> &
fst,
typename Arc::Label phi_label,
238 std::vector<typename Arc::Weight> *weights) {
240 using ArcItr = f::ArcIterator<f::Fst<Arc>>;
241 using StateId =
typename Arc::StateId;
242 using Weight =
typename Arc::Weight;
// Converters between the arc Weight and the Log64Weight used for summing.
243 f::WeightConvert<f::Log64Weight, Weight> from_log;
244 f::WeightConvert<Weight, f::Log64Weight> to_log;
247 std::vector<StateId> top_order;
248 bool acyclic =
PhiTopOrder(fst, phi_label, &top_order);
// Error reported for a phi-cyclic FST (the enclosing condition, presumably
// on `acyclic`, is not visible).
250 FSTERROR() <<
"SumStates: FST not canonical (phi-cyclic)";
// One slot per ordered state; newly added slots start at Weight::Zero().
253 weights->resize(top_order.size(), Weight::Zero());
// Reverse pass over top_order.
// NOTE(review): the `i >= 0` exit condition only terminates when StateId is
// a signed type.
255 for (StateId i = top_order.size() - 1 ; i >= 0; --i) {
256 StateId s = top_order[i];
// Start from the final weight, then add every outgoing arc weight.
257 f::Adder<f::Log64Weight> ssum(to_log(fst.Final(s)));
258 for (ArcItr aiter(fst, s); !aiter.Done(); aiter.Next()) {
259 const Arc &arc = aiter.Value();
260 ssum.Add(to_log(arc.weight));
262 (*weights)[s] = from_log(ssum.Sum());
269 #endif // NLP_GRM2_SFST_STATE_WEIGHTS_H_ bool Less(fst::LogWeightTpl< T > weight1, fst::LogWeightTpl< T > weight2)
bool operator()(const Weight &w1, const Weight &w2) const
fst::SignedLogWeightTpl< T > Weight
void DiffStateWeights(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *weights, typename Arc::Label phi_label, bool fail_arc)
bool PhiTopOrder(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::StateId > *top_order)
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
void WriteWeights(std::ostream &strm, const std::vector< Weight > &weights)
void SumStateWeights(const fst::Fst< Arc > &fst, std::vector< typename Arc::Weight > *weights, typename Arc::Label phi_label, bool fail_arc)
const fst::Log64Weight kApproxZeroWeight
ApproxEqualPred(float delta=fst::kDelta, Weight approx_zero=Weight::NoWeight())
void NormWeights(std::vector< Weight > *weights)
ApproxEqualPred(float delta=fst::kDelta, Weight approx_zero=Weight::NoWeight())
bool ApproxEqualWeights(const std::vector< Weight > &weights1, const std::vector< Weight > &weights2, float delta=fst::kDelta, Weight approx_zero=Weight::NoWeight())
bool operator()(const Weight &w1, const Weight &w2) const
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
void SumStates(const fst::Fst< Arc > &fst, typename Arc::Label phi_label, std::vector< typename Arc::Weight > *weights)