| // prune.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: allauzen@google.com (Cyril Allauzen) |
| // |
| // \file |
| // Functions implementing pruning. |
| |
| #ifndef FST_LIB_PRUNE_H__ |
| #define FST_LIB_PRUNE_H__ |
| |
| #include <vector> |
| using std::vector; |
| |
| #include <fst/arcfilter.h> |
| #include <fst/heap.h> |
| #include <fst/shortest-distance.h> |
| |
| |
| namespace fst { |
| |
| template <class A, class ArcFilter> |
| class PruneOptions { |
| public: |
| typedef typename A::Weight Weight; |
| typedef typename A::StateId StateId; |
| |
| // Pruning weight threshold. |
| Weight weight_threshold; |
| // Pruning state threshold. |
| StateId state_threshold; |
| // Arc filter. |
| ArcFilter filter; |
| // If non-zero, passes in pre-computed shortest distance to final states. |
| const vector<Weight> *distance; |
| // Determines the degree of convergence required when computing shortest |
| // distances. |
| float delta; |
| |
| explicit PruneOptions(const Weight& w, StateId s, ArcFilter f, |
| vector<Weight> *d = 0, float e = kDelta) |
| : weight_threshold(w), |
| state_threshold(s), |
| filter(f), |
| distance(d), |
| delta(e) {} |
| private: |
| PruneOptions(); // disallow |
| }; |
| |
| |
| template <class S, class W> |
| class PruneCompare { |
| public: |
| typedef S StateId; |
| typedef W Weight; |
| |
| PruneCompare(const vector<Weight> &idistance, |
| const vector<Weight> &fdistance) |
| : idistance_(idistance), fdistance_(fdistance) {} |
| |
| bool operator()(const StateId x, const StateId y) const { |
| Weight wx = Times(x < idistance_.size() ? idistance_[x] : Weight::Zero(), |
| x < fdistance_.size() ? fdistance_[x] : Weight::Zero()); |
| Weight wy = Times(y < idistance_.size() ? idistance_[y] : Weight::Zero(), |
| y < fdistance_.size() ? fdistance_[y] : Weight::Zero()); |
| return less_(wx, wy); |
| } |
| |
| private: |
| const vector<Weight> &idistance_; |
| const vector<Weight> &fdistance_; |
| NaturalLess<Weight> less_; |
| }; |
| |
| |
| |
| // Pruning algorithm: this version modifies its input and it takes an |
| // options class as an argment. Delete states and arcs in 'fst' that |
| // do not belong to a successful path whose weight is no more than |
| // the weight of the shortest path Times() 'opts.weight_threshold'. |
| // When 'opts.state_threshold != kNoStateId', the resulting transducer |
| // will restricted further to have at most 'opts.state_threshold' |
| // states. Weights need to be commutative and have the path |
| // property. The weight 'w' of any cycle needs to be bounded, i.e., |
| // 'Plus(w, W::One()) = One()'. |
| template <class Arc, class ArcFilter> |
| void Prune(MutableFst<Arc> *fst, |
| const PruneOptions<Arc, ArcFilter> &opts) { |
| typedef typename Arc::Weight Weight; |
| typedef typename Arc::StateId StateId; |
| |
| if ((Weight::Properties() & (kPath | kCommutative)) |
| != (kPath | kCommutative)) { |
| FSTERROR() << "Prune: Weight needs to have the path property and" |
| << " be commutative: " |
| << Weight::Type(); |
| fst->SetProperties(kError, kError); |
| return; |
| } |
| StateId ns = fst->NumStates(); |
| if (ns == 0) return; |
| vector<Weight> idistance(ns, Weight::Zero()); |
| vector<Weight> tmp; |
| if (!opts.distance) { |
| tmp.reserve(ns); |
| ShortestDistance(*fst, &tmp, true, opts.delta); |
| } |
| const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; |
| |
| if ((opts.state_threshold == 0) || |
| (fdistance->size() <= fst->Start()) || |
| ((*fdistance)[fst->Start()] == Weight::Zero())) { |
| fst->DeleteStates(); |
| return; |
| } |
| PruneCompare<StateId, Weight> compare(idistance, *fdistance); |
| Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); |
| vector<bool> visited(ns, false); |
| vector<size_t> enqueued(ns, kNoKey); |
| vector<StateId> dead; |
| dead.push_back(fst->AddState()); |
| NaturalLess<Weight> less; |
| Weight limit = Times((*fdistance)[fst->Start()], opts.weight_threshold); |
| |
| StateId num_visited = 0; |
| StateId s = fst->Start(); |
| if (!less(limit, (*fdistance)[s])) { |
| idistance[s] = Weight::One(); |
| enqueued[s] = heap.Insert(s); |
| ++num_visited; |
| } |
| |
| while (!heap.Empty()) { |
| s = heap.Top(); |
| heap.Pop(); |
| enqueued[s] = kNoKey; |
| visited[s] = true; |
| if (less(limit, Times(idistance[s], fst->Final(s)))) |
| fst->SetFinal(s, Weight::Zero()); |
| for (MutableArcIterator< MutableFst<Arc> > ait(fst, s); |
| !ait.Done(); |
| ait.Next()) { |
| Arc arc = ait.Value(); |
| if (!opts.filter(arc)) continue; |
| Weight weight = Times(Times(idistance[s], arc.weight), |
| arc.nextstate < fdistance->size() |
| ? (*fdistance)[arc.nextstate] |
| : Weight::Zero()); |
| if (less(limit, weight)) { |
| arc.nextstate = dead[0]; |
| ait.SetValue(arc); |
| continue; |
| } |
| if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) |
| idistance[arc.nextstate] = Times(idistance[s], arc.weight); |
| if (visited[arc.nextstate]) continue; |
| if ((opts.state_threshold != kNoStateId) && |
| (num_visited >= opts.state_threshold)) |
| continue; |
| if (enqueued[arc.nextstate] == kNoKey) { |
| enqueued[arc.nextstate] = heap.Insert(arc.nextstate); |
| ++num_visited; |
| } else { |
| heap.Update(enqueued[arc.nextstate], arc.nextstate); |
| } |
| } |
| } |
| for (size_t i = 0; i < visited.size(); ++i) |
| if (!visited[i]) dead.push_back(i); |
| fst->DeleteStates(dead); |
| } |
| |
| |
| // Pruning algorithm: this version modifies its input and simply takes |
| // the pruning threshold as an argument. Delete states and arcs in |
| // 'fst' that do not belong to a successful path whose weight is no |
| // more than the weight of the shortest path Times() |
| // 'weight_threshold'. When 'state_threshold != kNoStateId', the |
| // resulting transducer will be restricted further to have at most |
| // 'opts.state_threshold' states. Weights need to be commutative and |
| // have the path property. The weight 'w' of any cycle needs to be |
| // bounded, i.e., 'Plus(w, W::One()) = One()'. |
| template <class Arc> |
| void Prune(MutableFst<Arc> *fst, |
| typename Arc::Weight weight_threshold, |
| typename Arc::StateId state_threshold = kNoStateId, |
| double delta = kDelta) { |
| PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, |
| AnyArcFilter<Arc>(), 0, delta); |
| Prune(fst, opts); |
| } |
| |
| |
| // Pruning algorithm: this version writes the pruned input Fst to an |
| // output MutableFst and it takes an options class as an argument. |
| // 'ofst' contains states and arcs that belong to a successful path in |
| // 'ifst' whose weight is no more than the weight of the shortest path |
| // Times() 'opts.weight_threshold'. When 'opts.state_threshold != |
| // kNoStateId', 'ofst' will be restricted further to have at most |
| // 'opts.state_threshold' states. Weights need to be commutative and |
| // have the path property. The weight 'w' of any cycle needs to be |
| // bounded, i.e., 'Plus(w, W::One()) = One()'. |
| template <class Arc, class ArcFilter> |
| void Prune(const Fst<Arc> &ifst, |
| MutableFst<Arc> *ofst, |
| const PruneOptions<Arc, ArcFilter> &opts) { |
| typedef typename Arc::Weight Weight; |
| typedef typename Arc::StateId StateId; |
| |
| if ((Weight::Properties() & (kPath | kCommutative)) |
| != (kPath | kCommutative)) { |
| FSTERROR() << "Prune: Weight needs to have the path property and" |
| << " be commutative: " |
| << Weight::Type(); |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| ofst->DeleteStates(); |
| ofst->SetInputSymbols(ifst.InputSymbols()); |
| ofst->SetOutputSymbols(ifst.OutputSymbols()); |
| if (ifst.Start() == kNoStateId) |
| return; |
| NaturalLess<Weight> less; |
| if (less(opts.weight_threshold, Weight::One()) || |
| (opts.state_threshold == 0)) |
| return; |
| vector<Weight> idistance; |
| vector<Weight> tmp; |
| if (!opts.distance) |
| ShortestDistance(ifst, &tmp, true, opts.delta); |
| const vector<Weight> *fdistance = opts.distance ? opts.distance : &tmp; |
| |
| if ((fdistance->size() <= ifst.Start()) || |
| ((*fdistance)[ifst.Start()] == Weight::Zero())) { |
| return; |
| } |
| PruneCompare<StateId, Weight> compare(idistance, *fdistance); |
| Heap< StateId, PruneCompare<StateId, Weight>, false> heap(compare); |
| vector<StateId> copy; |
| vector<size_t> enqueued; |
| vector<bool> visited; |
| |
| StateId s = ifst.Start(); |
| Weight limit = Times(s < fdistance->size() ? (*fdistance)[s] : Weight::Zero(), |
| opts.weight_threshold); |
| while (copy.size() <= s) |
| copy.push_back(kNoStateId); |
| copy[s] = ofst->AddState(); |
| ofst->SetStart(copy[s]); |
| while (idistance.size() <= s) |
| idistance.push_back(Weight::Zero()); |
| idistance[s] = Weight::One(); |
| while (enqueued.size() <= s) { |
| enqueued.push_back(kNoKey); |
| visited.push_back(false); |
| } |
| enqueued[s] = heap.Insert(s); |
| |
| while (!heap.Empty()) { |
| s = heap.Top(); |
| heap.Pop(); |
| enqueued[s] = kNoKey; |
| visited[s] = true; |
| if (!less(limit, Times(idistance[s], ifst.Final(s)))) |
| ofst->SetFinal(copy[s], ifst.Final(s)); |
| for (ArcIterator< Fst<Arc> > ait(ifst, s); |
| !ait.Done(); |
| ait.Next()) { |
| const Arc &arc = ait.Value(); |
| if (!opts.filter(arc)) continue; |
| Weight weight = Times(Times(idistance[s], arc.weight), |
| arc.nextstate < fdistance->size() |
| ? (*fdistance)[arc.nextstate] |
| : Weight::Zero()); |
| if (less(limit, weight)) continue; |
| if ((opts.state_threshold != kNoStateId) && |
| (ofst->NumStates() >= opts.state_threshold)) |
| continue; |
| while (idistance.size() <= arc.nextstate) |
| idistance.push_back(Weight::Zero()); |
| if (less(Times(idistance[s], arc.weight), |
| idistance[arc.nextstate])) |
| idistance[arc.nextstate] = Times(idistance[s], arc.weight); |
| while (copy.size() <= arc.nextstate) |
| copy.push_back(kNoStateId); |
| if (copy[arc.nextstate] == kNoStateId) |
| copy[arc.nextstate] = ofst->AddState(); |
| ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, |
| copy[arc.nextstate])); |
| while (enqueued.size() <= arc.nextstate) { |
| enqueued.push_back(kNoKey); |
| visited.push_back(false); |
| } |
| if (visited[arc.nextstate]) continue; |
| if (enqueued[arc.nextstate] == kNoKey) |
| enqueued[arc.nextstate] = heap.Insert(arc.nextstate); |
| else |
| heap.Update(enqueued[arc.nextstate], arc.nextstate); |
| } |
| } |
| } |
| |
| |
| // Pruning algorithm: this version writes the pruned input Fst to an |
| // output MutableFst and simply takes the pruning threshold as an |
| // argument. 'ofst' contains states and arcs that belong to a |
| // successful path in 'ifst' whose weight is no more than |
| // the weight of the shortest path Times() 'weight_threshold'. When |
| // 'state_threshold != kNoStateId', 'ofst' will be restricted further |
| // to have at most 'opts.state_threshold' states. Weights need to be |
| // commutative and have the path property. The weight 'w' of any cycle |
| // needs to be bounded, i.e., 'Plus(w, W::One()) = W::One()'. |
| template <class Arc> |
| void Prune(const Fst<Arc> &ifst, |
| MutableFst<Arc> *ofst, |
| typename Arc::Weight weight_threshold, |
| typename Arc::StateId state_threshold = kNoStateId, |
| float delta = kDelta) { |
| PruneOptions<Arc, AnyArcFilter<Arc> > opts(weight_threshold, state_threshold, |
| AnyArcFilter<Arc>(), 0, delta); |
| Prune(ifst, ofst, opts); |
| } |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_PRUNE_H_ |