24 #include <fst/flags.h> 26 #include <fst/mutable-fst.h> 27 #include <fst/vector-fst.h> 54 bool ValidMergeMethod() {
// Compares the --method flag (FST_FLAGS_method) against the six merge method
// names handled by this tool: count_merge, context_merge, model_merge,
// bayes_model_merge, histogram_merge and replace_merge.
// The comparison is an exact string match against each literal.
55 if (FST_FLAGS_method ==
"count_merge" ||
56 FST_FLAGS_method ==
"context_merge" ||
57 FST_FLAGS_method ==
"model_merge" ||
58 FST_FLAGS_method ==
"bayes_model_merge" ||
59 FST_FLAGS_method ==
"histogram_merge" ||
60 FST_FLAGS_method ==
"replace_merge") {
// NOTE(review): the branch body and the return statements are not visible in
// this excerpt; presumably a match yields true and anything else false.
// Reads a VectorFst<Arc> from `file` and stores it in *fst (the previous
// contents of the unique_ptr are released by reset()).
//   file: path of the FST to read; "-" selects the empty source name.
//   fst:  output; receives whatever VectorFst<Arc>::Read() returns.
// NOTE(review): the template header, the post-read checks and the return
// statements are not visible in this excerpt.
67 bool ReadFst(
const char *file, std::unique_ptr<fst::VectorFst<Arc>> *
fst) {
// Map the conventional "-" argument to the empty name, which OpenFst's Read()
// documents as standard input; any other value is used as given.
68 std::string in_name = (strcmp(file,
"-") != 0) ? file :
"";
// Read through in_name rather than the raw argument, so that "-" is not
// looked up as a literal file called "-".
69 fst->reset(fst::VectorFst<Arc>::Read(in_name));
76 bool GetContexts(
int in_count, std::vector<std::string> *contexts) {
// Fills *contexts with the context patterns selected by the command-line
// flags and checks that their number matches in_count.
//   in_count: number of patterns the caller expects.
//   contexts: output vector; entries are appended with push_back.
// An error is logged when no pattern source is given or when the count does
// not match. NOTE(review): the return statements are not visible in this
// excerpt.
78 if (!FST_FLAGS_contexts.empty()) {
// (Handling of a non-empty --contexts flag is not visible in this excerpt.)
80 }
else if (!FST_FLAGS_context_pattern.empty()) {
// A single --context_pattern is stored after an empty-string entry, so this
// path appends exactly two patterns.
81 contexts->push_back(
"");
82 contexts->push_back(FST_FLAGS_context_pattern);
// Neither --contexts nor --context_pattern was supplied (the enclosing else
// is not visible in this excerpt).
84 LOG(ERROR) <<
"Context patterns not specified";
// The number of collected patterns must equal in_count.
// NOTE(review): size() is unsigned while in_count is int, so this is a
// mixed-sign comparison; confirm in_count can never be negative here.
87 if (contexts->size() != in_count) {
88 LOG(ERROR) <<
"Wrong number of context patterns specified";
// NOTE(review): fragment; presumably the body of RoundCountsToInt(), whose
// signature is not visible in this excerpt. `fst` is the mutable FST being
// edited.
// Visits every state s in [0, NumStates()). For each arc weight and each
// final weight w, exp(-w) is rounded to the nearest integer and converted
// back with -log, i.e. the value is rounded in the exp(-w) domain.
98 for (
size_t s = 0; s < fst->NumStates(); ++s) {
99 for (fst::MutableArcIterator<fst::StdMutableFst> aiter(fst, s);
100 !aiter.Done(); aiter.Next()) {
// Work on a copy of the current arc; the write-back through the iterator is
// not visible in this excerpt.
101 fst::StdArc arc = aiter.Value();
102 auto weight = std::round(std::exp(-arc.weight.Value()));
// NOTE(review): if the rounded value is 0, -std::log(0) is +infinity; confirm
// that this is the intended result for counts below 0.5.
103 arc.weight = -std::log(weight);
// States whose final weight equals Weight::Zero() are left untouched.
106 if (fst->Final(s) != fst::StdArc::Weight::Zero()) {
107 auto weight = std::round(std::exp(-fst->Final(s).Value()));
108 fst->SetFinal(s, -std::log(weight));
// Body of the n-gram merge command-line driver. The enclosing signature is not
// visible in this excerpt, and gaps in the embedded numbering mark source
// lines that are missing from the listing.
// Visible flow: build the usage text and parse flags, derive the output name
// and the number of input models, validate --method, then run the branch for
// the selected method. Each branch reads the input FSTs with ReadFst(), checks
// ngramrg.Error() and writes ngramrg.GetFst() to out_name. Every visible
// failure path returns 1.
// NOTE(review): the construction of the ngramrg objects and the heads of the
// merge calls are only partly visible; just their trailing arguments appear.
116 std::string usage =
"Merge n-gram models.\n\n Usage: ";
118 usage +=
" [--options] -ofile=out.fst in1.fst in2.fst [in3.fst ...]\n";
119 SET_FLAGS(usage.c_str(), &argc, &argv,
true);
// With --ofile unset the output name is argv[3] when present, otherwise "".
// (The other arm of the outer conditional is not visible in this excerpt.)
126 std::string out_name = FST_FLAGS_ofile.empty()
127 ? (argc > 3 ? argv[3] :
"")
// With --ofile unset exactly two input models are assumed; otherwise the
// count is argc - 1.
130 int in_count = FST_FLAGS_ofile.empty() ? 2 : argc - 1;
// (The condition guarding this diagnostic is not visible in this excerpt.)
132 LOG(ERROR) <<
"Only one model given, no merging to do";
// Reject a --method value that ValidMergeMethod() does not accept.
137 if (!ValidMergeMethod()) {
138 LOG(ERROR) << argv[0]
139 <<
": bad merge method: " << FST_FLAGS_method;
// Every method other than histogram_merge works on StdArc FSTs; the first
// model is read from argv[1].
143 if (FST_FLAGS_method !=
"histogram_merge") {
144 std::unique_ptr<fst::StdVectorFst> fst1;
145 if (!ReadFst<fst::StdArc>(argv[1], &fst1))
return 1;
146 std::unique_ptr<fst::StdVectorFst> fst2;
147 if (FST_FLAGS_method ==
"count_merge") {
149 FST_FLAGS_backoff_label,
151 FST_FLAGS_check_consistency);
// Inputs 2..in_count are read one at a time from argv[i] into fst2.
152 for (
int i = 2; i <= in_count; ++i) {
153 if (!ReadFst<fst::StdArc>(argv[i], &fst2))
return 1;
// Normalization is requested only for the last input, and only when
// --normalize is set.
154 bool norm = FST_FLAGS_normalize && i == in_count;
156 FST_FLAGS_beta, norm);
157 if (ngramrg.Error())
return 1;
// (The statement guarded by --round_to_int is not visible in this excerpt.)
158 if (FST_FLAGS_round_to_int)
161 ngramrg.GetFst().Write(out_name);
162 }
else if (FST_FLAGS_method ==
"model_merge") {
164 FST_FLAGS_backoff_label,
166 FST_FLAGS_check_consistency);
// Unlike count_merge, --normalize is passed unchanged on every iteration.
167 for (
int i = 2; i <= in_count; ++i) {
168 if (!ReadFst<fst::StdArc>(argv[i], &fst2))
return 1;
171 FST_FLAGS_normalize);
172 if (ngramrg.Error())
return 1;
174 ngramrg.GetFst().Write(out_name);
175 }
else if (FST_FLAGS_method ==
"bayes_model_merge") {
177 FST_FLAGS_backoff_label,
179 for (
int i = 2; i <= in_count; ++i) {
180 if (!ReadFst<fst::StdArc>(argv[i], &fst2))
return 1;
183 if (ngramrg.Error())
return 1;
185 ngramrg.GetFst().Write(out_name);
186 }
else if (FST_FLAGS_method ==
"replace_merge") {
// replace_merge accepts exactly two models (the guarding condition is not
// visible in this excerpt). The ": " keeps the program name apart from the
// message, matching the "bad merge method" diagnostic.
188 LOG(ERROR) << argv[0] <<
": Only 2 models allowed for replace merge";
192 FST_FLAGS_backoff_label,
// Only the second model, argv[2], is read in this branch.
194 if (!ReadFst<fst::StdArc>(argv[2], &fst2))
return 1;
196 FST_FLAGS_normalize);
197 if (ngramrg.Error())
return 1;
198 ngramrg.GetFst().Write(out_name);
199 }
else if (FST_FLAGS_method ==
"context_merge") {
201 FST_FLAGS_backoff_label,
203 FST_FLAGS_check_consistency);
// Collect the context patterns for the in_count models; give up if
// GetContexts() reports failure.
204 std::vector<std::string> contexts;
205 if (!GetContexts(in_count, &contexts))
return 1;
206 for (
int i = 2; i <= in_count; ++i) {
207 if (!ReadFst<fst::StdArc>(argv[i], &fst2))
return 1;
// As in count_merge, normalization is requested only for the last input.
208 bool norm = FST_FLAGS_normalize && i == in_count;
210 if (ngramrg.Error())
return 1;
212 ngramrg.GetFst().Write(out_name);
// histogram_merge path, presumably the else arm of the histogram_merge test
// above (its header is not visible in this excerpt). It follows the same
// read / check / write sequence over HistogramArc FSTs, with a fresh
// hist_fst2 for every input.
215 std::unique_ptr<fst::VectorFst<ngram::HistogramArc>> hist_fst1;
216 if (!ReadFst<ngram::HistogramArc>(argv[1], &hist_fst1))
return 1;
218 hist_fst1.get(), FST_FLAGS_backoff_label,
219 FST_FLAGS_norm_eps, FST_FLAGS_check_consistency);
220 for (
int i = 2; i <= in_count; ++i) {
221 std::unique_ptr<fst::VectorFst<ngram::HistogramArc>> hist_fst2;
222 if (!ReadFst<ngram::HistogramArc>(argv[i], &hist_fst2))
return 1;
225 FST_FLAGS_normalize);
226 if (ngramrg.Error())
return 1;
228 ngramrg.GetFst().Write(out_name);
int ngrammerge_main(int argc, char **argv)
void RoundCountsToInt(fst::StdMutableFst *fst)
DECLARE_int32(max_replace_order)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta)
bool NGramComplete(fst::MutableFst< Arc > *fst, typename Arc::Label backoff_label=0)
void MergeNGramModels(const fst::Fst< HistogramArc > &infst2, double alpha, double beta, bool norm=false)
bool NGramReadContexts(const std::string &file, std::vector< std::string > *contexts)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
DECLARE_string(context_pattern)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
DECLARE_int64(backoff_label)
void MergeNGramModels(const fst::StdFst &infst2, std::string_view context_pattern, bool norm=false)
void MergeNGramModels(const fst::StdFst &infst2, int max_replace_order=-1, bool norm=false)