NGram  ngram-1.3.14
OpenGrm-NGram library
ngrammerge-main.cc
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Merges two input n-gram models into a single model.
17 
18 #include <cmath>
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include <fst/flags.h>
25 #include <ngram/ngram-complete.h>
28 #include <ngram/ngram-hist-merge.h>
31 
32 DECLARE_double(alpha);
33 DECLARE_double(beta);
34 DECLARE_string(context_pattern);
35 DECLARE_string(contexts);
36 DECLARE_bool(normalize);
37 DECLARE_string(method);
38 DECLARE_int32(max_replace_order);
39 DECLARE_string(ofile);
40 DECLARE_int64(backoff_label);
41 DECLARE_double(norm_eps);
42 DECLARE_bool(check_consistency);
43 DECLARE_bool(complete);
44 DECLARE_bool(round_to_int);
45 
46 namespace {
47 
48 bool ValidMergeMethod() {
49  if (FST_FLAGS_method == "count_merge" ||
50  FST_FLAGS_method == "context_merge" ||
51  FST_FLAGS_method == "model_merge" ||
52  FST_FLAGS_method == "bayes_model_merge" ||
53  FST_FLAGS_method == "histogram_merge" ||
54  FST_FLAGS_method == "replace_merge") {
55  return true;
56  }
57  return false;
58 }
59 
60 template <class Arc>
61 bool ReadFst(const char *file, std::unique_ptr<fst::VectorFst<Arc>> *fst) {
62  std::string in_name = (strcmp(file, "-") != 0) ? file : "";
63  fst->reset(fst::VectorFst<Arc>::Read(file));
64  if (!*fst ||
65  (FST_FLAGS_complete && !ngram::NGramComplete(fst->get())))
66  return false;
67  return true;
68 }
69 
70 bool GetContexts(int in_count, std::vector<std::string> *contexts) {
71  contexts->clear();
72  if (!FST_FLAGS_contexts.empty()) {
73  ngram::NGramReadContexts(FST_FLAGS_contexts, contexts);
74  } else if (!FST_FLAGS_context_pattern.empty()) {
75  contexts->push_back("");
76  contexts->push_back(FST_FLAGS_context_pattern);
77  } else {
78  LOG(ERROR) << "Context patterns not specified";
79  return false;
80  }
81  if (contexts->size() != in_count) {
82  LOG(ERROR) << "Wrong number of context patterns specified";
83  return false;
84  }
85  return true;
86 }
87 
88 // Rounds -log count to values corresponding to the rounded integer count
89 // Reduces small floating point precision issues when dealing with int counts
90 // Primarily for testing that methods for deriving the same model are identical
91 void RoundCountsToInt(fst::StdMutableFst *fst) {
92  for (size_t s = 0; s < fst->NumStates(); ++s) {
93  for (fst::MutableArcIterator<fst::StdMutableFst> aiter(fst, s);
94  !aiter.Done(); aiter.Next()) {
95  fst::StdArc arc = aiter.Value();
96  auto weight = std::round(std::exp(-arc.weight.Value()));
97  arc.weight = -std::log(weight);
98  aiter.SetValue(arc);
99  }
100  if (fst->Final(s) != fst::StdArc::Weight::Zero()) {
101  auto weight = std::round(std::exp(-fst->Final(s).Value()));
102  fst->SetFinal(s, -std::log(weight));
103  }
104  }
105 }
106 
107 } // namespace
108 
109 int ngrammerge_main(int argc, char **argv) {
110  std::string usage = "Merge n-gram models.\n\n Usage: ";
111  usage += argv[0];
112  usage += " [--options] -ofile=out.fst in1.fst in2.fst [in3.fst ...]\n";
113  std::set_new_handler(FailedNewHandler);
114  SET_FLAGS(usage.c_str(), &argc, &argv, true);
115 
116  if (argc < 3) {
117  ShowUsage();
118  return 1;
119  }
120 
121  std::string out_name = FST_FLAGS_ofile.empty()
122  ? (argc > 3 ? argv[3] : "")
123  : FST_FLAGS_ofile;
124 
125  int in_count = FST_FLAGS_ofile.empty() ? 2 : argc - 1;
126  if (in_count < 2) {
127  LOG(ERROR) << "Only one model given, no merging to do";
128  ShowUsage();
129  return 1;
130  }
131 
132  if (!ValidMergeMethod()) {
133  LOG(ERROR) << argv[0]
134  << ": bad merge method: " << FST_FLAGS_method;
135  return 1;
136  }
137 
138  if (FST_FLAGS_method != "histogram_merge") {
139  std::unique_ptr<fst::StdVectorFst> fst1;
140  if (!ReadFst<fst::StdArc>(argv[1], &fst1)) return 1;
141  std::unique_ptr<fst::StdVectorFst> fst2;
142  if (FST_FLAGS_method == "count_merge") {
143  ngram::NGramCountMerge ngramrg(fst1.get(),
144  FST_FLAGS_backoff_label,
145  FST_FLAGS_norm_eps,
146  FST_FLAGS_check_consistency);
147  for (int i = 2; i <= in_count; ++i) {
148  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
149  bool norm = FST_FLAGS_normalize && i == in_count;
150  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_alpha,
151  FST_FLAGS_beta, norm);
152  if (ngramrg.Error()) return 1;
153  if (FST_FLAGS_round_to_int)
154  RoundCountsToInt(ngramrg.GetMutableFst());
155  }
156  ngramrg.GetFst().Write(out_name);
157  } else if (FST_FLAGS_method == "model_merge") {
158  ngram::NGramModelMerge ngramrg(fst1.get(),
159  FST_FLAGS_backoff_label,
160  FST_FLAGS_norm_eps,
161  FST_FLAGS_check_consistency);
162  for (int i = 2; i <= in_count; ++i) {
163  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
164  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_alpha,
165  FST_FLAGS_beta,
166  FST_FLAGS_normalize);
167  if (ngramrg.Error()) return 1;
168  }
169  ngramrg.GetFst().Write(out_name);
170  } else if (FST_FLAGS_method == "bayes_model_merge") {
171  ngram::NGramBayesModelMerge ngramrg(fst1.get(),
172  FST_FLAGS_backoff_label,
173  FST_FLAGS_norm_eps);
174  for (int i = 2; i <= in_count; ++i) {
175  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
176  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_alpha,
177  FST_FLAGS_beta);
178  if (ngramrg.Error()) return 1;
179  }
180  ngramrg.GetFst().Write(out_name);
181  } else if (FST_FLAGS_method == "replace_merge") {
182  if (in_count != 2) {
183  LOG(ERROR) << argv[0] << "Only 2 models allowed for replace merge";
184  return 1;
185  }
186  ngram::NGramReplaceMerge ngramrg(fst1.get(),
187  FST_FLAGS_backoff_label,
188  FST_FLAGS_norm_eps);
189  if (!ReadFst<fst::StdArc>(argv[2], &fst2)) return 1;
190  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_max_replace_order,
191  FST_FLAGS_normalize);
192  if (ngramrg.Error()) return 1;
193  ngramrg.GetFst().Write(out_name);
194  } else if (FST_FLAGS_method == "context_merge") {
195  ngram::NGramContextMerge ngramrg(fst1.get(),
196  FST_FLAGS_backoff_label,
197  FST_FLAGS_norm_eps,
198  FST_FLAGS_check_consistency);
199  std::vector<std::string> contexts;
200  if (!GetContexts(in_count, &contexts)) return 1;
201  for (int i = 2; i <= in_count; ++i) {
202  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
203  bool norm = FST_FLAGS_normalize && i == in_count;
204  ngramrg.MergeNGramModels(*fst2, contexts[i - 1], norm);
205  if (ngramrg.Error()) return 1;
206  }
207  ngramrg.GetFst().Write(out_name);
208  }
209  } else {
210  std::unique_ptr<fst::VectorFst<ngram::HistogramArc>> hist_fst1;
211  if (!ReadFst<ngram::HistogramArc>(argv[1], &hist_fst1)) return 1;
212  ngram::NGramHistMerge ngramrg(
213  hist_fst1.get(), FST_FLAGS_backoff_label,
214  FST_FLAGS_norm_eps, FST_FLAGS_check_consistency);
215  for (int i = 2; i <= in_count; ++i) {
216  std::unique_ptr<fst::VectorFst<ngram::HistogramArc>> hist_fst2;
217  if (!ReadFst<ngram::HistogramArc>(argv[i], &hist_fst2)) return 1;
218  ngramrg.MergeNGramModels(*hist_fst2, FST_FLAGS_alpha,
219  FST_FLAGS_beta,
220  FST_FLAGS_normalize);
221  if (ngramrg.Error()) return 1;
222  }
223  ngramrg.GetFst().Write(out_name);
224  }
225  return 0;
226 }
int ngrammerge_main(int argc, char **argv)
void MergeNGramModels(const fst::StdFst &infst2, const std::string &context_pattern, bool norm=false)
DECLARE_bool(normalize)
void RoundCountsToInt(fst::StdMutableFst *fst)
Definition: ngram-count.cc:30
DECLARE_double(alpha)
DECLARE_int32(max_replace_order)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta)
bool NGramComplete(fst::MutableFst< Arc > *fst, typename Arc::Label backoff_label=0)
void MergeNGramModels(const fst::Fst< HistogramArc > &infst2, double alpha, double beta, bool norm=false)
bool NGramReadContexts(const std::string &file, std::vector< std::string > *contexts)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
DECLARE_string(context_pattern)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
DECLARE_int64(backoff_label)
void MergeNGramModels(const fst::StdFst &infst2, int max_replace_order=-1, bool norm=false)