GRM-SFST  sfst-1.2.1
OpenGrm SFst Library
rmphi.h
Go to the documentation of this file.
1 // Copyright 2018-2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // Replaces failure with epsilon transitions, correcting for
15 // over-acceptance by adding negatively-weighted transitions or by
16 // subtracting weights from existing transitions and final weights
17 
18 #ifndef NLP_GRM2_SFST_RMPHI_H_
19 #define NLP_GRM2_SFST_RMPHI_H_
20 
21 #include <sys/types.h>
22 
23 #include <cstddef>
24 #include <cstdint>
25 #include <memory>
26 #include <vector>
27 
28 #include <fst/arc-map.h>
29 #include <fst/arcsort.h>
30 #include <fst/fst.h>
31 #include <fst/matcher.h>
32 #include <fst/properties.h>
33 #include <fst/state-map.h>
34 #include <fst/weight.h>
35 #include <sfst/sfst.h>
36 
37 namespace sfst {
38 
39 namespace internal {
40 
41 // Replaces failure with epsilon transitions, correcting for
42 // over-acceptance by adding negatively-weighted transitions or by
43 // subtracting weights from existing transitions and final weights.
44 // Weights must be a ring (i.e., have a Minus() defined), e.g.
45 // SignedLogWeight. The phi label can be an 0 here, but then all
46 // epsilons will be interpreted as failure of course. (Non-failure)
47 // epsilons in the input are treated as regular symbols where each
48 // instance behaves as if it is uniquely labeled (i.e, they are not
49 // constrained by failure transitions). Assumes (but does not check)
50 // it has a canonical topology (see canonical.h) for a stochastic FST
51 // (when match_input = true, o.w. fst^-1 is assumed canonical). The
52 // rewrite_mode determines if both input and output failure labels are
53 // replaced. If this mapper is used to directly modify its input, it
54 // should be applied only to a phi-topsorted input.
55 template <class Arc>
56 class RmPhiMapper {
57  public:
58  typedef Arc FromArc;
59  typedef Arc ToArc;
60 
61  typedef typename Arc::StateId StateId;
62  typedef typename Arc::Weight Weight;
63  typedef typename Arc::Label Label;
64 
65  RmPhiMapper(const fst::Fst<Arc> &fst, Label phi, bool match_input = true,
66  fst::MatcherRewriteMode rewrite_mode =
67  fst::MATCHER_REWRITE_AUTO)
68  : fst_(fst.Copy()),
69  phi_(phi),
70  match_input_(match_input),
71  s_(fst::kNoStateId),
72  i_(0),
73  final_(Weight::Zero()),
74  matcher_(*fst_, match_input_ ?
75  fst::MATCH_INPUT : fst::MATCH_OUTPUT),
76  failpath_(*fst_, phi, match_input) {
77  namespace f = fst;
78  if (rewrite_mode == f::MATCHER_REWRITE_AUTO) {
79  rewrite_both_ = fst.Properties(f::kAcceptor, true);
80  } else if (rewrite_mode == f::MATCHER_REWRITE_ALWAYS) {
81  rewrite_both_ = true;
82  } else {
83  rewrite_both_ = false;
84  }
85  }
86 
87  // Allows updating Fst argument; pass only if changed.
89  const fst::Fst<Arc> *fst = nullptr)
90  : fst_(fst ? fst->Copy() : mapper.fst_->Copy()),
91  phi_(mapper.phi_),
92  match_input_(mapper.match_input_),
93  rewrite_both_(mapper.rewrite_both_),
94  s_(fst::kNoStateId),
95  i_(0),
96  final_(Weight::Zero()),
97  matcher_(*fst_,
98  match_input_ ? fst::MATCH_INPUT : fst::MATCH_OUTPUT),
99  failpath_(*fst_, phi_, match_input_) {}
100 
101  StateId Start() { return fst_->Start(); }
102  Weight Final(StateId s) const;
103 
104  void SetState(StateId s);
105 
106  bool Done() const { return i_ >= arcs_.size(); }
107  const Arc &Value() const { return arcs_[i_]; }
108  void Next() { ++i_; }
109 
110  fst::MapSymbolsAction
111  InputSymbolsAction() const { return fst::MAP_COPY_SYMBOLS; }
112 
113  fst::MapSymbolsAction
114  OutputSymbolsAction() const { return fst::MAP_COPY_SYMBOLS; }
115 
116  uint64_t Properties(uint64_t iprops) const {
117  namespace f = fst;
118  uint64_t oprops = iprops & f::kAddArcProperties & f::kSetFinalProperties &
119  f::kWeightInvariantProperties;
120  if (phi_ != 0 || !rewrite_both_) oprops &= f::kSetArcProperties;
121  return oprops;
122  }
123 
124  private:
125  // Constructor arguments
126  std::unique_ptr<fst::Fst<Arc>> fst_;
127  Label phi_;
128  bool match_input_;
129  bool rewrite_both_;
130 
131  mutable StateId s_; // current state
132 
133  // Arc and final weight data
134  std::vector<Arc> arcs_; // current arcs
135  ssize_t i_; // current arc position
136  Weight final_; // current final weight
137 
138  mutable fst::Matcher< fst::Fst<Arc> > matcher_;
139  // Failure path for current state
140  mutable FailurePath<Arc> failpath_;
141 };
142 
143 template <class Arc>
144 typename Arc::Weight RmPhiMapper<Arc>::Final(StateId s) const {
145  Weight sfinal = fst_->Final(s);
146  if (sfinal == Weight::Zero())
147  return sfinal;
148 
149  failpath_.SetState(s);
150  Weight fail_weight = Weight::One();
151  for (size_t i = 0; i < failpath_.Length(); ++i) {
152  fail_weight = Times(fail_weight, failpath_.GetWeight(i));
153  Weight dfinal = fst_->Final(failpath_.GetNextState(i));
154  if (dfinal != Weight::Zero()) {
155  // Subtracts corrective weight from final weight
156  sfinal = Minus(sfinal, Times(fail_weight, dfinal));
157  break;
158  }
159  }
160 
161  return sfinal;
162 }
163 
164 template <class Arc>
166  namespace f = fst;
167 
168  failpath_.SetState(s);
169 
170  i_ = 0;
171  arcs_.clear();
172  arcs_.reserve(fst_->NumArcs(s));
173 
174  Label prev_label = f::kNoLabel;
175  for (f::ArcIterator<f::Fst<Arc>> aiter(*fst_, s);
176  !aiter.Done(); aiter.Next()) {
177  Arc arc = aiter.Value();
178  Label label = match_input_ ? arc.ilabel : arc.olabel;
179  if (label == phi_) { // rewrites failure as epsilon transition
180  if (arc.ilabel == phi_ && (rewrite_both_ || match_input_))
181  arc.ilabel = 0;
182  if (arc.olabel == phi_ && (rewrite_both_ || !match_input_))
183  arc.olabel = 0;
184  } else if (label != 0) {
185  Weight fail_weight = Weight::One();
186  // skips any correction if label already processed
187  bool matched = label == prev_label;
188  for (size_t i = 0; i < failpath_.Length() && !matched; ++i) {
189  fail_weight = f::Times(fail_weight, failpath_.GetWeight(i));
190  matcher_.SetState(failpath_.GetNextState(i));
191  for (matcher_.Find(label); !matcher_.Done(); matcher_.Next()) {
192  Arc failarc = matcher_.Value();
193  Label faillabel = match_input_ ? failarc.ilabel : failarc.olabel;
194  Weight corr_weight = f::Times(fail_weight, failarc.weight);
195  Weight comb_weight = Minus(arc.weight, corr_weight);
196  if (faillabel == f::kNoLabel) {
197  // implicit self-loop
198  continue;
199  } else if (failarc.nextstate == arc.nextstate &&
200  failarc.ilabel == arc.ilabel &&
201  failarc.olabel == arc.olabel &&
202  !IsNegative(comb_weight) &&
203  comb_weight != Weight::Zero()) {
204  // Subtracts corrective weight from arc.
205  // Not applied if result is a negative weight
206  // since these arcs need to be treated specially
207  // by the shortest distance algorithm
208  arc.weight = comb_weight;
209  matched = true;
210  } else {
211  // Adds negatively-weighted correcting arc
212  failarc.weight = Minus(Weight::Zero(), corr_weight);
213  arcs_.push_back(failarc);
214  matched = true;
215  }
216  }
217  }
218  }
219  arcs_.push_back(arc); // adds input FST transitions
220  prev_label = label;
221  }
222 }
223 
224 // Converts a canonical SFST possibly with phi labels to an equivalent FST
225 // where the failure transitions are now epsilon transitions. Assumes input has
226 // no (non-phi) epsilons (or treats such epsilons as if they were regular
227 // symbols that are uniquely labeled wrt equivalence). The rewrite_mode
228 // determines if both input and output failure labels are replaced. The
229 // 'OArc' weight (e.g. for SignedLog64Arc) must have a Minus() operation
230 // (forming a ring) and a WeightConvert method from 'IArc'.
231 template <class IArc, class OArc,
232  class WC = fst::WeightConvert<typename IArc::Weight,
233  typename OArc::Weight>>
234 void RmPhi(const fst::Fst<IArc>& ifst,
235  fst::MutableFst<OArc> *ofst,
236  typename IArc::Label phi_label = fst::kNoLabel,
237  fst::MatcherRewriteMode rewrite_mode =
238  fst::MATCHER_REWRITE_AUTO,
239  const WC &weight_convert = WC()) {
240  namespace f = fst;
241  using WCM = f::WeightConvertMapper<IArc, OArc, WC>;
242 
243  WCM to_signed_mapper(weight_convert);
244  if (phi_label != f::kNoLabel) {
245  f::ArcMapFst tfst(ifst, to_signed_mapper);
246  RmPhiMapper<OArc> rm_phi_mapper(
247  tfst, phi_label, true, rewrite_mode);
248  f::StateMap(tfst, ofst, rm_phi_mapper);
249  f::ArcSort(ofst, f::ILabelCompare<OArc>());
250  } else {
251  f::ArcMap(ifst, ofst, to_signed_mapper);
252  }
253 }
254 
255 } // namespace internal
256 
257 } // namespace sfst
258 
259 #endif // NLP_GRM2_SFST_RMPHI_H_
Definition: perplexity.h:32
Entropy64Weight Minus(Entropy64Weight w1, Entropy64Weight w2)
Definition: perplexity.h:129
bool IsNegative(fst::SignedLog64Weight w)
Definition: sfst.h:79
bool Done() const
Definition: rmphi.h:106
Definition: sfstinfo.cc:39
uint64_t Properties(uint64_t iprops) const
Definition: rmphi.h:116
void RmPhi(const fst::Fst< IArc > &ifst, fst::MutableFst< OArc > *ofst, typename IArc::Label phi_label=fst::kNoLabel, fst::MatcherRewriteMode rewrite_mode=fst::MATCHER_REWRITE_AUTO, const WC &weight_convert=WC())
Definition: rmphi.h:234
fst::MapSymbolsAction InputSymbolsAction() const
Definition: rmphi.h:111
Arc::StateId StateId
Definition: rmphi.h:61
RmPhiMapper(const fst::Fst< Arc > &fst, Label phi, bool match_input=true, fst::MatcherRewriteMode rewrite_mode=fst::MATCHER_REWRITE_AUTO)
Definition: rmphi.h:65
fst::MapSymbolsAction OutputSymbolsAction() const
Definition: rmphi.h:114
const Arc & Value() const
Definition: rmphi.h:107
Arc::Weight Weight
Definition: rmphi.h:62
void SetState(StateId s)
Definition: rmphi.h:165
RmPhiMapper(const RmPhiMapper< Arc > &mapper, const fst::Fst< Arc > *fst=nullptr)
Definition: rmphi.h:88
Weight Final(StateId s) const
Definition: rmphi.h:144
Real64Weight Times(SignedLog64Weight w1, Real64Weight w2)
Definition: perplexity.h:41