NGram  ngram-1.3.15
OpenGrm-NGram library
ngrammerge-main.cc
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Merges two input n-gram models into a single model.
17 
18 #include <cmath>
19 #include <cstring>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include <fst/flags.h>
25 #include <fst/arc.h>
26 #include <fst/mutable-fst.h>
27 #include <fst/vector-fst.h>
28 #include <ngram/hist-arc.h>
30 #include <ngram/ngram-complete.h>
32 #include <ngram/ngram-context.h>
34 #include <ngram/ngram-hist-merge.h>
37 
// Command-line flags consumed by this tool. These are declarations only; the
// comments describe how each flag is used within this file.
DECLARE_double(alpha);  // Forwarded as `alpha` to MergeNGramModels.
DECLARE_double(beta);   // Forwarded as `beta` to MergeNGramModels.
// Context pattern paired with the second input for context_merge when
// --contexts is empty (the first input is paired with an empty pattern).
DECLARE_string(context_pattern);
// File name handed to ngram::NGramReadContexts for context_merge.
DECLARE_string(contexts);
DECLARE_bool(normalize);  // Forwarded as the `norm` argument of the merges.
DECLARE_string(method);   // Merge method name; checked by ValidMergeMethod().
DECLARE_int32(max_replace_order);  // Forwarded to the replace_merge merge call.
// Output file name. When empty, exactly two inputs are read and an optional
// third positional argument names the output.
DECLARE_string(ofile);
DECLARE_int64(backoff_label);     // Passed to every merge-class constructor.
DECLARE_double(norm_eps);         // Passed to every merge-class constructor.
DECLARE_bool(check_consistency);  // Passed to constructors that accept it.
DECLARE_bool(complete);  // If set, ReadFst runs NGramComplete on each input.
// If set, count_merge rounds counts via RoundCountsToInt after each merge.
DECLARE_bool(round_to_int);
51 
52 namespace {
53 
54 bool ValidMergeMethod() {
55  if (FST_FLAGS_method == "count_merge" ||
56  FST_FLAGS_method == "context_merge" ||
57  FST_FLAGS_method == "model_merge" ||
58  FST_FLAGS_method == "bayes_model_merge" ||
59  FST_FLAGS_method == "histogram_merge" ||
60  FST_FLAGS_method == "replace_merge") {
61  return true;
62  }
63  return false;
64 }
65 
66 template <class Arc>
67 bool ReadFst(const char *file, std::unique_ptr<fst::VectorFst<Arc>> *fst) {
68  std::string in_name = (strcmp(file, "-") != 0) ? file : "";
69  fst->reset(fst::VectorFst<Arc>::Read(file));
70  if (!*fst ||
71  (FST_FLAGS_complete && !ngram::NGramComplete(fst->get())))
72  return false;
73  return true;
74 }
75 
76 bool GetContexts(int in_count, std::vector<std::string> *contexts) {
77  contexts->clear();
78  if (!FST_FLAGS_contexts.empty()) {
79  ngram::NGramReadContexts(FST_FLAGS_contexts, contexts);
80  } else if (!FST_FLAGS_context_pattern.empty()) {
81  contexts->push_back("");
82  contexts->push_back(FST_FLAGS_context_pattern);
83  } else {
84  LOG(ERROR) << "Context patterns not specified";
85  return false;
86  }
87  if (contexts->size() != in_count) {
88  LOG(ERROR) << "Wrong number of context patterns specified";
89  return false;
90  }
91  return true;
92 }
93 
94 // Rounds -log count to values corresponding to the rounded integer count
95 // Reduces small floating point precision issues when dealing with int counts
96 // Primarily for testing that methods for deriving the same model are identical
97 void RoundCountsToInt(fst::StdMutableFst *fst) {
98  for (size_t s = 0; s < fst->NumStates(); ++s) {
99  for (fst::MutableArcIterator<fst::StdMutableFst> aiter(fst, s);
100  !aiter.Done(); aiter.Next()) {
101  fst::StdArc arc = aiter.Value();
102  auto weight = std::round(std::exp(-arc.weight.Value()));
103  arc.weight = -std::log(weight);
104  aiter.SetValue(arc);
105  }
106  if (fst->Final(s) != fst::StdArc::Weight::Zero()) {
107  auto weight = std::round(std::exp(-fst->Final(s).Value()));
108  fst->SetFinal(s, -std::log(weight));
109  }
110  }
111 }
112 
113 } // namespace
114 
115 int ngrammerge_main(int argc, char **argv) {
116  std::string usage = "Merge n-gram models.\n\n Usage: ";
117  usage += argv[0];
118  usage += " [--options] -ofile=out.fst in1.fst in2.fst [in3.fst ...]\n";
119  SET_FLAGS(usage.c_str(), &argc, &argv, true);
120 
121  if (argc < 3) {
122  ShowUsage();
123  return 1;
124  }
125 
126  std::string out_name = FST_FLAGS_ofile.empty()
127  ? (argc > 3 ? argv[3] : "")
128  : FST_FLAGS_ofile;
129 
130  int in_count = FST_FLAGS_ofile.empty() ? 2 : argc - 1;
131  if (in_count < 2) {
132  LOG(ERROR) << "Only one model given, no merging to do";
133  ShowUsage();
134  return 1;
135  }
136 
137  if (!ValidMergeMethod()) {
138  LOG(ERROR) << argv[0]
139  << ": bad merge method: " << FST_FLAGS_method;
140  return 1;
141  }
142 
143  if (FST_FLAGS_method != "histogram_merge") {
144  std::unique_ptr<fst::StdVectorFst> fst1;
145  if (!ReadFst<fst::StdArc>(argv[1], &fst1)) return 1;
146  std::unique_ptr<fst::StdVectorFst> fst2;
147  if (FST_FLAGS_method == "count_merge") {
148  ngram::NGramCountMerge ngramrg(fst1.get(),
149  FST_FLAGS_backoff_label,
150  FST_FLAGS_norm_eps,
151  FST_FLAGS_check_consistency);
152  for (int i = 2; i <= in_count; ++i) {
153  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
154  bool norm = FST_FLAGS_normalize && i == in_count;
155  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_alpha,
156  FST_FLAGS_beta, norm);
157  if (ngramrg.Error()) return 1;
158  if (FST_FLAGS_round_to_int)
159  RoundCountsToInt(ngramrg.GetMutableFst());
160  }
161  ngramrg.GetFst().Write(out_name);
162  } else if (FST_FLAGS_method == "model_merge") {
163  ngram::NGramModelMerge ngramrg(fst1.get(),
164  FST_FLAGS_backoff_label,
165  FST_FLAGS_norm_eps,
166  FST_FLAGS_check_consistency);
167  for (int i = 2; i <= in_count; ++i) {
168  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
169  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_alpha,
170  FST_FLAGS_beta,
171  FST_FLAGS_normalize);
172  if (ngramrg.Error()) return 1;
173  }
174  ngramrg.GetFst().Write(out_name);
175  } else if (FST_FLAGS_method == "bayes_model_merge") {
176  ngram::NGramBayesModelMerge ngramrg(fst1.get(),
177  FST_FLAGS_backoff_label,
178  FST_FLAGS_norm_eps);
179  for (int i = 2; i <= in_count; ++i) {
180  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
181  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_alpha,
182  FST_FLAGS_beta);
183  if (ngramrg.Error()) return 1;
184  }
185  ngramrg.GetFst().Write(out_name);
186  } else if (FST_FLAGS_method == "replace_merge") {
187  if (in_count != 2) {
188  LOG(ERROR) << argv[0] << "Only 2 models allowed for replace merge";
189  return 1;
190  }
191  ngram::NGramReplaceMerge ngramrg(fst1.get(),
192  FST_FLAGS_backoff_label,
193  FST_FLAGS_norm_eps);
194  if (!ReadFst<fst::StdArc>(argv[2], &fst2)) return 1;
195  ngramrg.MergeNGramModels(*fst2, FST_FLAGS_max_replace_order,
196  FST_FLAGS_normalize);
197  if (ngramrg.Error()) return 1;
198  ngramrg.GetFst().Write(out_name);
199  } else if (FST_FLAGS_method == "context_merge") {
200  ngram::NGramContextMerge ngramrg(fst1.get(),
201  FST_FLAGS_backoff_label,
202  FST_FLAGS_norm_eps,
203  FST_FLAGS_check_consistency);
204  std::vector<std::string> contexts;
205  if (!GetContexts(in_count, &contexts)) return 1;
206  for (int i = 2; i <= in_count; ++i) {
207  if (!ReadFst<fst::StdArc>(argv[i], &fst2)) return 1;
208  bool norm = FST_FLAGS_normalize && i == in_count;
209  ngramrg.MergeNGramModels(*fst2, contexts[i - 1], norm);
210  if (ngramrg.Error()) return 1;
211  }
212  ngramrg.GetFst().Write(out_name);
213  }
214  } else {
215  std::unique_ptr<fst::VectorFst<ngram::HistogramArc>> hist_fst1;
216  if (!ReadFst<ngram::HistogramArc>(argv[1], &hist_fst1)) return 1;
217  ngram::NGramHistMerge ngramrg(
218  hist_fst1.get(), FST_FLAGS_backoff_label,
219  FST_FLAGS_norm_eps, FST_FLAGS_check_consistency);
220  for (int i = 2; i <= in_count; ++i) {
221  std::unique_ptr<fst::VectorFst<ngram::HistogramArc>> hist_fst2;
222  if (!ReadFst<ngram::HistogramArc>(argv[i], &hist_fst2)) return 1;
223  ngramrg.MergeNGramModels(*hist_fst2, FST_FLAGS_alpha,
224  FST_FLAGS_beta,
225  FST_FLAGS_normalize);
226  if (ngramrg.Error()) return 1;
227  }
228  ngramrg.GetFst().Write(out_name);
229  }
230  return 0;
231 }
int ngrammerge_main(int argc, char **argv)
DECLARE_bool(normalize)
void RoundCountsToInt(fst::StdMutableFst *fst)
Definition: ngram-count.cc:45
DECLARE_double(alpha)
DECLARE_int32(max_replace_order)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta)
bool NGramComplete(fst::MutableFst< Arc > *fst, typename Arc::Label backoff_label=0)
void MergeNGramModels(const fst::Fst< HistogramArc > &infst2, double alpha, double beta, bool norm=false)
bool NGramReadContexts(const std::string &file, std::vector< std::string > *contexts)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
DECLARE_string(context_pattern)
void MergeNGramModels(const fst::StdFst &infst2, double alpha, double beta, bool norm=false)
DECLARE_int64(backoff_label)
void MergeNGramModels(const fst::StdFst &infst2, std::string_view context_pattern, bool norm=false)
void MergeNGramModels(const fst::StdFst &infst2, int max_replace_order=-1, bool norm=false)