27 #include <fst/extensions/far/far.h> 28 #include <fst/arc-map.h> 30 #include <fst/arcsort.h> 31 #include <fst/float-weight.h> 32 #include <fst/mutable-fst.h> 33 #include <fst/properties.h> 34 #include <fst/symbol-table.h> 35 #include <fst/vector-fst.h> 46 for (
size_t s = 0; s < fst->NumStates(); ++s) {
47 for (fst::MutableArcIterator<fst::StdMutableFst> aiter(fst, s);
48 !aiter.Done(); aiter.Next()) {
49 fst::StdArc arc = aiter.Value();
50 auto weight = std::round(std::exp(-arc.weight.Value()));
51 arc.weight = -std::log(weight);
54 if (fst->Final(s) != fst::StdArc::Weight::Zero()) {
55 auto weight = std::round(std::exp(-fst->Final(s).Value()));
56 fst->SetFinal(s, -std::log(weight));
63 const std::pair<std::vector<int>, std::pair<int, double>> &ngram_count,
64 std::string *
ngram,
const fst::SymbolTable &syms) {
65 std::vector<int> ngram_history = ngram_count.first;
67 for (
size_t i = 0; i < ngram_history.size(); ++i) {
69 ngram_history[i] > 0 ? syms.Find(ngram_history[i]) :
"<s>";
70 *ngram += symbol +
" ";
72 if (ngram_count.second.first > 0) {
73 *ngram += syms.Find(ngram_count.second.first);
77 return ngram_count.second.second;
83 fst::FarReader<fst::StdArc> *far_reader,
int fstnumber,
84 fst::SymbolTable *syms) {
85 std::unique_ptr<const fst::StdVectorFst> ifst(
86 new fst::StdVectorFst(*far_reader->GetFst()));
88 LOG(ERROR) << countname <<
": unable to read fst #" << fstnumber;
93 if (ifst->Properties(fst::kString,
true)) {
94 counted = ngram_counter->
Count(*ifst);
96 fst::VectorFst<fst::Log64Arc> log_ifst;
97 ArcMap(*ifst, &log_ifst, fst::StdToLog64Mapper());
98 counted = ngram_counter->
Count(&log_ifst);
100 if (!counted) LOG(ERROR) << countname <<
": fst #" << fstnumber <<
" skipped";
101 if (ifst->InputSymbols() !=
nullptr && syms->NumSymbols() == 0) {
103 *syms = *ifst->InputSymbols();
111 fst::StdMutableFst *
fst,
int fstnumber,
int order,
112 bool epsilon_as_backoff) {
114 if (ngram_counter.
Error()) {
117 fst::SymbolTable syms;
118 if (far_reader->Done() ||
119 !
GetCounts(
"ngramhistcount", &ngram_counter, far_reader, fstnumber,
123 ngram_counter.
GetFst(fst);
124 fst::ArcSort(fst, fst::StdILabelCompare());
125 if (syms.NumSymbols() > 0) {
126 fst->SetInputSymbols(&syms);
127 fst->SetOutputSymbols(&syms);
134 fst::VectorFst<HistogramArc> *
fst,
135 int order,
bool epsilon_as_backoff,
int backoff_label,
136 double norm_eps,
bool check_consistency,
bool normalize,
137 double alpha,
double beta) {
139 std::unique_ptr<NGramHistMerge> ngramrg;
140 while (!far_reader->Done()) {
141 fst::StdVectorFst in_fst;
143 epsilon_as_backoff)) {
144 LOG(ERROR) <<
"failed to count fst number " << fstnumber;
147 if (ngramrg ==
nullptr) {
149 ngramrg = std::make_unique<NGramHistMerge>(fst, backoff_label, norm_eps,
152 fst::VectorFst<HistogramArc> hist_fst;
154 bool norm = normalize && far_reader->Done();
155 ngramrg->MergeNGramModels(hist_fst, alpha, beta, norm);
166 fst::SymbolTable *syms,
bool require_symbols,
167 double add_to_symbol_unigram_count) {
169 while (!far_reader->Done()) {
170 if (!
GetCounts(
"ngramcount", ngram_counter, far_reader, fstnumber, syms))
175 if (require_symbols && syms->NumSymbols() == 0) {
176 LOG(ERROR) <<
"None of the input FSTs had a symbol table";
179 if (add_to_symbol_unigram_count > 0.0 && require_symbols) {
181 *syms, -log(add_to_symbol_unigram_count));
189 bool require_symbols,
bool epsilon_as_backoff,
190 bool round_to_int,
double add_to_symbol_unigram_count) {
192 fst::SymbolTable syms;
194 add_to_symbol_unigram_count)) {
197 ngram_counter.
GetFst(fst);
198 fst::ArcSort(fst, fst::StdILabelCompare());
199 if (syms.NumSymbols() > 0) {
200 fst->SetInputSymbols(&syms);
201 fst->SetOutputSymbols(&syms);
209 std::vector<std::string> *ngrams,
int order,
210 bool epsilon_as_backoff,
211 double add_to_symbol_unigram_count) {
213 fst::SymbolTable syms;
216 add_to_symbol_unigram_count)) {
220 std::vector<std::pair<std::vector<int>, std::pair<int, double>>> ngram_counts;
222 for (
size_t i = 0; i < ngram_counts.size(); ++i) {
225 ngrams->push_back(ngram +
'\t' + std::to_string(count));
bool GetNGramsAndSyms(fst::FarReader< fst::StdArc > *far_reader, NGramCounter< fst::Log64Weight > *ngram_counter, fst::SymbolTable *syms, bool require_symbols, double add_to_symbol_unigram_count)
void RoundCountsToInt(fst::StdMutableFst *fst)
bool GetNGramHistograms(fst::FarReader< fst::StdArc > *far_reader, fst::VectorFst< HistogramArc > *fst, int order, bool epsilon_as_backoff=false, int backoff_label=0, double norm_eps=kNormEps, bool check_consistency=false, bool normalize=false, double alpha=1.0, double beta=1.0)
bool Count(const fst::Fst< Arc > &fst)
void AddCountToSymbolUnigrams(const fst::SymbolTable &syms, Weight neg_log_count)
void GetFst(fst::MutableFst< Arc > *fst)
void GetReverseContextNGrams(std::vector< std::pair< std::vector< int >, std::pair< Label, double >>> *ngram_counts)
double GetNGramAndCount(const std::pair< std::vector< int >, std::pair< int, double >> &ngram_count, std::string *ngram, const fst::SymbolTable &syms)
bool GetNGramCounts(fst::FarReader< fst::StdArc > *far_reader, fst::StdMutableFst *fst, int order, bool require_symbols=true, bool epsilon_as_backoff=false, bool round_to_int=false, double add_to_symbol_unigram_count=0.0)
bool GetSingleCountFst(fst::FarReader< fst::StdArc > *far_reader, fst::StdMutableFst *fst, int fstnumber, int order, bool epsilon_as_backoff)
bool GetCounts(const std::string &countname, NGramCounter< fst::Log64Weight > *ngram_counter, fst::FarReader< fst::StdArc > *far_reader, int fstnumber, fst::SymbolTable *syms)