NGram  ngram-1.3.15
OpenGrm-NGram library
ngram-count.cc
Go to the documentation of this file.
1 // Copyright 2005-2013 Brian Roark
2 // Copyright 2005-2020 Google LLC
3 //
4 // Licensed under the Apache License, Version 2.0 (the 'License');
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an 'AS IS' BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 // Get counts from input strings.
17 
18 #include <ngram/ngram-count.h>
19 
20 #include <cmath>
21 #include <cstddef>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include <fst/extensions/far/far.h>
28 #include <fst/arc-map.h>
29 #include <fst/arc.h>
30 #include <fst/arcsort.h>
31 #include <fst/float-weight.h>
32 #include <fst/mutable-fst.h>
33 #include <fst/properties.h>
34 #include <fst/symbol-table.h>
35 #include <fst/vector-fst.h>
36 #include <ngram/hist-arc.h>
37 #include <ngram/hist-mapper.h>
38 #include <ngram/ngram-hist-merge.h>
39 
40 namespace ngram {
41 
42 // Rounds -log count to values corresponding to the rounded integer count;
43 // reduces small floating point precision issues when dealing with int counts;
44 // primarily for testing that methods for deriving the same model are identical.
45 void RoundCountsToInt(fst::StdMutableFst *fst) {
46  for (size_t s = 0; s < fst->NumStates(); ++s) {
47  for (fst::MutableArcIterator<fst::StdMutableFst> aiter(fst, s);
48  !aiter.Done(); aiter.Next()) {
49  fst::StdArc arc = aiter.Value();
50  auto weight = std::round(std::exp(-arc.weight.Value()));
51  arc.weight = -std::log(weight);
52  aiter.SetValue(arc);
53  }
54  if (fst->Final(s) != fst::StdArc::Weight::Zero()) {
55  auto weight = std::round(std::exp(-fst->Final(s).Value()));
56  fst->SetFinal(s, -std::log(weight));
57  }
58  }
59 }
60 
61 // Returns an n-gram string and double count from a (history, ngram) pair.
63  const std::pair<std::vector<int>, std::pair<int, double>> &ngram_count,
64  std::string *ngram, const fst::SymbolTable &syms) {
65  std::vector<int> ngram_history = ngram_count.first;
66  *ngram = "";
67  for (size_t i = 0; i < ngram_history.size(); ++i) {
68  std::string symbol =
69  ngram_history[i] > 0 ? syms.Find(ngram_history[i]) : "<s>";
70  *ngram += symbol + " ";
71  }
72  if (ngram_count.second.first > 0) {
73  *ngram += syms.Find(ngram_count.second.first);
74  } else {
75  *ngram += "</s>";
76  }
77  return ngram_count.second.second;
78 }
79 
80 // Gets ngram counts for the next fst in far_reader.
81 bool GetCounts(const std::string &countname,
82  NGramCounter<fst::Log64Weight> *ngram_counter,
83  fst::FarReader<fst::StdArc> *far_reader, int fstnumber,
84  fst::SymbolTable *syms) {
85  std::unique_ptr<const fst::StdVectorFst> ifst(
86  new fst::StdVectorFst(*far_reader->GetFst()));
87  if (!ifst) {
88  LOG(ERROR) << countname << ": unable to read fst #" << fstnumber;
89  return false;
90  }
91 
92  bool counted = false;
93  if (ifst->Properties(fst::kString, true)) {
94  counted = ngram_counter->Count(*ifst);
95  } else {
96  fst::VectorFst<fst::Log64Arc> log_ifst;
97  ArcMap(*ifst, &log_ifst, fst::StdToLog64Mapper());
98  counted = ngram_counter->Count(&log_ifst);
99  }
100  if (!counted) LOG(ERROR) << countname << ": fst #" << fstnumber << " skipped";
101  if (ifst->InputSymbols() != nullptr && syms->NumSymbols() == 0) {
102  // Retains symbol table if available and not yet retained.
103  *syms = *ifst->InputSymbols();
104  }
105 
106  return true;
107 }
108 
109 // Builds a count WFST from a single input.
110 bool GetSingleCountFst(fst::FarReader<fst::StdArc> *far_reader,
111  fst::StdMutableFst *fst, int fstnumber, int order,
112  bool epsilon_as_backoff) {
113  NGramCounter<fst::Log64Weight> ngram_counter(order, epsilon_as_backoff);
114  if (ngram_counter.Error()) {
115  return false;
116  }
117  fst::SymbolTable syms;
118  if (far_reader->Done() ||
119  !GetCounts("ngramhistcount", &ngram_counter, far_reader, fstnumber,
120  &syms)) {
121  return false;
122  }
123  ngram_counter.GetFst(fst);
124  fst::ArcSort(fst, fst::StdILabelCompare());
125  if (syms.NumSymbols() > 0) {
126  fst->SetInputSymbols(&syms);
127  fst->SetOutputSymbols(&syms);
128  }
129  return true;
130 }
131 
132 // Computes counts using the HistogramArc template.
133 bool GetNGramHistograms(fst::FarReader<fst::StdArc> *far_reader,
134  fst::VectorFst<HistogramArc> *fst,
135  int order, bool epsilon_as_backoff, int backoff_label,
136  double norm_eps, bool check_consistency, bool normalize,
137  double alpha, double beta) {
138  int fstnumber = 1;
139  std::unique_ptr<NGramHistMerge> ngramrg;
140  while (!far_reader->Done()) {
141  fst::StdVectorFst in_fst;
142  if (!GetSingleCountFst(far_reader, &in_fst, fstnumber, order,
143  epsilon_as_backoff)) {
144  LOG(ERROR) << "failed to count fst number " << fstnumber;
145  return false;
146  }
147  if (ngramrg == nullptr) {
148  ArcMap(in_fst, fst, ToHistogramMapper<fst::StdArc>());
149  ngramrg = std::make_unique<NGramHistMerge>(fst, backoff_label, norm_eps,
150  check_consistency);
151  } else {
152  fst::VectorFst<HistogramArc> hist_fst;
153  ArcMap(in_fst, &hist_fst, ToHistogramMapper<fst::StdArc>());
154  bool norm = normalize && far_reader->Done();
155  ngramrg->MergeNGramModels(hist_fst, alpha, beta, norm);
156  }
157  far_reader->Next();
158  ++fstnumber;
159  }
160  return true;
161 }
162 
163 // Derives n-gram counts (and symbols) from input FAR reader.
164 bool GetNGramsAndSyms(fst::FarReader<fst::StdArc> *far_reader,
165  NGramCounter<fst::Log64Weight> *ngram_counter,
166  fst::SymbolTable *syms, bool require_symbols,
167  double add_to_symbol_unigram_count) {
168  int fstnumber = 1;
169  while (!far_reader->Done()) {
170  if (!GetCounts("ngramcount", ngram_counter, far_reader, fstnumber, syms))
171  return false;
172  far_reader->Next();
173  ++fstnumber;
174  }
175  if (require_symbols && syms->NumSymbols() == 0) {
176  LOG(ERROR) << "None of the input FSTs had a symbol table";
177  return false;
178  }
179  if (add_to_symbol_unigram_count > 0.0 && require_symbols) {
180  ngram_counter->AddCountToSymbolUnigrams(
181  *syms, /*neg_log_count=*/-log(add_to_symbol_unigram_count));
182  }
183  return true;
184 }
185 
186 // Computes ngram counts and returns ngram format FST.
187 bool GetNGramCounts(fst::FarReader<fst::StdArc> *far_reader,
188  fst::StdMutableFst *fst, int order,
189  bool require_symbols, bool epsilon_as_backoff,
190  bool round_to_int, double add_to_symbol_unigram_count) {
191  NGramCounter<fst::Log64Weight> ngram_counter(order, epsilon_as_backoff);
192  fst::SymbolTable syms;
193  if (!GetNGramsAndSyms(far_reader, &ngram_counter, &syms, require_symbols,
194  add_to_symbol_unigram_count)) {
195  return false;
196  }
197  ngram_counter.GetFst(fst);
198  fst::ArcSort(fst, fst::StdILabelCompare());
199  if (syms.NumSymbols() > 0) {
200  fst->SetInputSymbols(&syms);
201  fst->SetOutputSymbols(&syms);
202  }
203  if (round_to_int) RoundCountsToInt(fst);
204  return true;
205 }
206 
207 // Computes ngram counts and returns vector of strings.
208 bool GetNGramCounts(fst::FarReader<fst::StdArc> *far_reader,
209  std::vector<std::string> *ngrams, int order,
210  bool epsilon_as_backoff,
211  double add_to_symbol_unigram_count) {
212  NGramCounter<fst::Log64Weight> ngram_counter(order, epsilon_as_backoff);
213  fst::SymbolTable syms;
214  if (!GetNGramsAndSyms(far_reader, &ngram_counter, &syms,
215  /* require_symbols = */ true,
216  add_to_symbol_unigram_count)) {
217  // Requires symbols from input far to output as vector of strings.
218  return false;
219  }
220  std::vector<std::pair<std::vector<int>, std::pair<int, double>>> ngram_counts;
221  ngram_counter.GetReverseContextNGrams<fst::StdArc>(&ngram_counts);
222  for (size_t i = 0; i < ngram_counts.size(); ++i) {
223  std::string ngram;
224  double count = GetNGramAndCount(ngram_counts[i], &ngram, syms);
225  ngrams->push_back(ngram + '\t' + std::to_string(count));
226  }
227  return true;
228 }
229 
230 } // namespace ngram
bool GetNGramsAndSyms(fst::FarReader< fst::StdArc > *far_reader, NGramCounter< fst::Log64Weight > *ngram_counter, fst::SymbolTable *syms, bool require_symbols, double add_to_symbol_unigram_count)
Definition: ngram-count.cc:164
void RoundCountsToInt(fst::StdMutableFst *fst)
Definition: ngram-count.cc:45
bool GetNGramHistograms(fst::FarReader< fst::StdArc > *far_reader, fst::VectorFst< HistogramArc > *fst, int order, bool epsilon_as_backoff=false, int backoff_label=0, double norm_eps=kNormEps, bool check_consistency=false, bool normalize=false, double alpha=1.0, double beta=1.0)
Definition: ngram-count.cc:133
bool Count(const fst::Fst< Arc > &fst)
Definition: ngram-count.h:77
void AddCountToSymbolUnigrams(const fst::SymbolTable &syms, Weight neg_log_count)
Definition: ngram-count.h:210
void GetFst(fst::MutableFst< Arc > *fst)
Definition: ngram-count.h:111
void GetReverseContextNGrams(std::vector< std::pair< std::vector< int >, std::pair< Label, double >>> *ngram_counts)
Definition: ngram-count.h:133
bool Error() const
Definition: ngram-count.h:233
double GetNGramAndCount(const std::pair< std::vector< int >, std::pair< int, double >> &ngram_count, std::string *ngram, const fst::SymbolTable &syms)
Definition: ngram-count.cc:62
bool GetNGramCounts(fst::FarReader< fst::StdArc > *far_reader, fst::StdMutableFst *fst, int order, bool require_symbols=true, bool epsilon_as_backoff=false, bool round_to_int=false, double add_to_symbol_unigram_count=0.0)
Definition: ngram-count.cc:187
bool GetSingleCountFst(fst::FarReader< fst::StdArc > *far_reader, fst::StdMutableFst *fst, int fstnumber, int order, bool epsilon_as_backoff)
Definition: ngram-count.cc:110
bool GetCounts(const std::string &countname, NGramCounter< fst::Log64Weight > *ngram_counter, fst::FarReader< fst::StdArc > *far_reader, int fstnumber, fst::SymbolTable *syms)
Definition: ngram-count.cc:81