| // rmepsilon.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: allauzen@google.com (Cyril Allauzen) |
| // |
| // \file |
| // Functions and classes that implement epsilon-removal. |
| |
| #ifndef FST_LIB_RMEPSILON_H__ |
| #define FST_LIB_RMEPSILON_H__ |
| |
| #include <unordered_map> |
| using std::tr1::unordered_map; |
| using std::tr1::unordered_multimap; |
| #include <fst/slist.h> |
| #include <stack> |
| #include <string> |
| #include <utility> |
| using std::pair; using std::make_pair; |
| #include <vector> |
| using std::vector; |
| |
| #include <fst/arcfilter.h> |
| #include <fst/cache.h> |
| #include <fst/connect.h> |
| #include <fst/factor-weight.h> |
| #include <fst/invert.h> |
| #include <fst/prune.h> |
| #include <fst/queue.h> |
| #include <fst/shortest-distance.h> |
| #include <fst/topsort.h> |
| |
| |
| namespace fst { |
| |
| template <class Arc, class Queue> |
| class RmEpsilonOptions |
| : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > { |
| public: |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| |
| bool connect; // Connect output |
| Weight weight_threshold; // Pruning weight threshold. |
| StateId state_threshold; // Pruning state threshold. |
| |
| explicit RmEpsilonOptions(Queue *q, float d = kDelta, bool c = true, |
| Weight w = Weight::Zero(), |
| StateId n = kNoStateId) |
| : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >( |
| q, EpsilonArcFilter<Arc>(), kNoStateId, d), |
| connect(c), weight_threshold(w), state_threshold(n) {} |
| private: |
| RmEpsilonOptions(); // disallow |
| }; |
| |
| // Computation state of the epsilon-removal algorithm. |
| template <class Arc, class Queue> |
| class RmEpsilonState { |
| public: |
| typedef typename Arc::Label Label; |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| |
| RmEpsilonState(const Fst<Arc> &fst, |
| vector<Weight> *distance, |
| const RmEpsilonOptions<Arc, Queue> &opts) |
| : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true), |
| expand_id_(0) {} |
| |
| // Compute arcs and final weight for state 's' |
| void Expand(StateId s); |
| |
| // Returns arcs of expanded state. |
| vector<Arc> &Arcs() { return arcs_; } |
| |
| // Returns final weight of expanded state. |
| const Weight &Final() const { return final_; } |
| |
| // Return true if an error has occured. |
| bool Error() const { return sd_state_.Error(); } |
| |
| private: |
| static const size_t kPrime0 = 7853; |
| static const size_t kPrime1 = 7867; |
| |
| struct Element { |
| Label ilabel; |
| Label olabel; |
| StateId nextstate; |
| |
| Element() {} |
| |
| Element(Label i, Label o, StateId s) |
| : ilabel(i), olabel(o), nextstate(s) {} |
| }; |
| |
| class ElementKey { |
| public: |
| size_t operator()(const Element& e) const { |
| return static_cast<size_t>(e.nextstate); |
| return static_cast<size_t>(e.nextstate + |
| e.ilabel * kPrime0 + |
| e.olabel * kPrime1); |
| } |
| |
| private: |
| }; |
| |
| class ElementEqual { |
| public: |
| bool operator()(const Element &e1, const Element &e2) const { |
| return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) |
| && (e1.nextstate == e2.nextstate); |
| } |
| }; |
| |
| typedef unordered_map<Element, pair<StateId, size_t>, |
| ElementKey, ElementEqual> ElementMap; |
| |
| const Fst<Arc> &fst_; |
| // Distance from state being expanded in epsilon-closure. |
| vector<Weight> *distance_; |
| // Shortest distance algorithm computation state. |
| ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_; |
| // Maps an element 'e' to a pair 'p' corresponding to a position |
| // in the arcs vector of the state being expanded. 'e' corresponds |
| // to the position 'p.second' in the 'arcs_' vector if 'p.first' is |
| // equal to the state being expanded. |
| ElementMap element_map_; |
| EpsilonArcFilter<Arc> eps_filter_; |
| stack<StateId> eps_queue_; // Queue used to visit the epsilon-closure |
| vector<bool> visited_; // '[i] = true' if state 'i' has been visited |
| slist<StateId> visited_states_; // List of visited states |
| vector<Arc> arcs_; // Arcs of state being expanded |
| Weight final_; // Final weight of state being expanded |
| StateId expand_id_; // Unique ID for each call to Expand |
| |
| DISALLOW_COPY_AND_ASSIGN(RmEpsilonState); |
| }; |
| |
| template <class Arc, class Queue> |
| const size_t RmEpsilonState<Arc, Queue>::kPrime0; |
| template <class Arc, class Queue> |
| const size_t RmEpsilonState<Arc, Queue>::kPrime1; |
| |
| |
| template <class Arc, class Queue> |
| void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) { |
| final_ = Weight::Zero(); |
| arcs_.clear(); |
| sd_state_.ShortestDistance(source); |
| if (sd_state_.Error()) |
| return; |
| eps_queue_.push(source); |
| |
| while (!eps_queue_.empty()) { |
| StateId state = eps_queue_.top(); |
| eps_queue_.pop(); |
| |
| while (visited_.size() <= state) visited_.push_back(false); |
| if (visited_[state]) continue; |
| visited_[state] = true; |
| visited_states_.push_front(state); |
| |
| for (ArcIterator< Fst<Arc> > ait(fst_, state); |
| !ait.Done(); |
| ait.Next()) { |
| Arc arc = ait.Value(); |
| arc.weight = Times((*distance_)[state], arc.weight); |
| |
| if (eps_filter_(arc)) { |
| while (visited_.size() <= arc.nextstate) |
| visited_.push_back(false); |
| if (!visited_[arc.nextstate]) |
| eps_queue_.push(arc.nextstate); |
| } else { |
| Element element(arc.ilabel, arc.olabel, arc.nextstate); |
| typename ElementMap::iterator it = element_map_.find(element); |
| if (it == element_map_.end()) { |
| element_map_.insert( |
| pair<Element, pair<StateId, size_t> > |
| (element, pair<StateId, size_t>(expand_id_, arcs_.size()))); |
| arcs_.push_back(arc); |
| } else { |
| if (((*it).second).first == expand_id_) { |
| Weight &w = arcs_[((*it).second).second].weight; |
| w = Plus(w, arc.weight); |
| } else { |
| ((*it).second).first = expand_id_; |
| ((*it).second).second = arcs_.size(); |
| arcs_.push_back(arc); |
| } |
| } |
| } |
| } |
| final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); |
| } |
| |
| while (!visited_states_.empty()) { |
| visited_[visited_states_.front()] = false; |
| visited_states_.pop_front(); |
| } |
| ++expand_id_; |
| } |
| |
| // Removes epsilon-transitions (when both the input and output label |
| // are an epsilon) from a transducer. The result will be an equivalent |
| // FST that has no such epsilon transitions. This version modifies |
| // its input. It allows fine control via the options argument; see |
| // below for a simpler interface. |
| // |
| // The vector 'distance' will be used to hold the shortest distances |
| // during the epsilon-closure computation. The state queue discipline |
| // and convergence delta are taken in the options argument. |
| template <class Arc, class Queue> |
| void RmEpsilon(MutableFst<Arc> *fst, |
| vector<typename Arc::Weight> *distance, |
| const RmEpsilonOptions<Arc, Queue> &opts) { |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| typedef typename Arc::Label Label; |
| |
| if (fst->Start() == kNoStateId) { |
| return; |
| } |
| |
| // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon |
| // incoming transition or is the start state. |
| vector<bool> noneps_in(fst->NumStates(), false); |
| noneps_in[fst->Start()] = true; |
| for (StateId i = 0; i < fst->NumStates(); ++i) { |
| for (ArcIterator<Fst<Arc> > aiter(*fst, i); |
| !aiter.Done(); |
| aiter.Next()) { |
| if (aiter.Value().ilabel != 0 || aiter.Value().olabel != 0) |
| noneps_in[aiter.Value().nextstate] = true; |
| } |
| } |
| |
| // States sorted in topological order when (acyclic) or generic |
| // topological order (cyclic). |
| vector<StateId> states; |
| states.reserve(fst->NumStates()); |
| |
| if (fst->Properties(kTopSorted, false) & kTopSorted) { |
| for (StateId i = 0; i < fst->NumStates(); i++) |
| states.push_back(i); |
| } else if (fst->Properties(kAcyclic, false) & kAcyclic) { |
| vector<StateId> order; |
| bool acyclic; |
| TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); |
| DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>()); |
| // Sanity check: should be acyclic if property bit is set. |
| if(!acyclic) { |
| FSTERROR() << "RmEpsilon: inconsistent acyclic property bit"; |
| fst->SetProperties(kError, kError); |
| return; |
| } |
| states.resize(order.size()); |
| for (StateId i = 0; i < order.size(); i++) |
| states[order[i]] = i; |
| } else { |
| uint64 props; |
| vector<StateId> scc; |
| SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props); |
| DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>()); |
| vector<StateId> first(scc.size(), kNoStateId); |
| vector<StateId> next(scc.size(), kNoStateId); |
| for (StateId i = 0; i < scc.size(); i++) { |
| if (first[scc[i]] != kNoStateId) |
| next[i] = first[scc[i]]; |
| first[scc[i]] = i; |
| } |
| for (StateId i = 0; i < first.size(); i++) |
| for (StateId j = first[i]; j != kNoStateId; j = next[j]) |
| states.push_back(j); |
| } |
| |
| RmEpsilonState<Arc, Queue> |
| rmeps_state(*fst, distance, opts); |
| |
| while (!states.empty()) { |
| StateId state = states.back(); |
| states.pop_back(); |
| if (!noneps_in[state]) |
| continue; |
| rmeps_state.Expand(state); |
| fst->SetFinal(state, rmeps_state.Final()); |
| fst->DeleteArcs(state); |
| vector<Arc> &arcs = rmeps_state.Arcs(); |
| fst->ReserveArcs(state, arcs.size()); |
| while (!arcs.empty()) { |
| fst->AddArc(state, arcs.back()); |
| arcs.pop_back(); |
| } |
| } |
| |
| for (StateId s = 0; s < fst->NumStates(); ++s) { |
| if (!noneps_in[s]) |
| fst->DeleteArcs(s); |
| } |
| |
| if(rmeps_state.Error()) |
| fst->SetProperties(kError, kError); |
| fst->SetProperties( |
| RmEpsilonProperties(fst->Properties(kFstProperties, false)), |
| kFstProperties); |
| |
| if (opts.weight_threshold != Weight::Zero() || |
| opts.state_threshold != kNoStateId) |
| Prune(fst, opts.weight_threshold, opts.state_threshold); |
| if (opts.connect && (opts.weight_threshold == Weight::Zero() || |
| opts.state_threshold != kNoStateId)) |
| Connect(fst); |
| } |
| |
| // Removes epsilon-transitions (when both the input and output label |
| // are an epsilon) from a transducer. The result will be an equivalent |
| // FST that has no such epsilon transitions. This version modifies its |
| // input. It has a simplified interface; see above for a version that |
| // allows finer control. |
| // |
| // Complexity: |
| // - Time: |
| // - Unweighted: O(V2 + V E) |
| // - Acyclic: O(V2 + V E) |
| // - Tropical semiring: O(V2 log V + V E) |
| // - General: exponential |
| // - Space: O(V E) |
| // where V = # of states visited, E = # of arcs. |
| // |
| // References: |
| // - Mehryar Mohri. Generic Epsilon-Removal and Input |
| // Epsilon-Normalization Algorithms for Weighted Transducers, |
| // "International Journal of Computer Science", 13(1):129-143 (2002). |
| template <class Arc> |
| void RmEpsilon(MutableFst<Arc> *fst, |
| bool connect = true, |
| typename Arc::Weight weight_threshold = Arc::Weight::Zero(), |
| typename Arc::StateId state_threshold = kNoStateId, |
| float delta = kDelta) { |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| typedef typename Arc::Label Label; |
| |
| vector<Weight> distance; |
| AutoQueue<StateId> state_queue(*fst, &distance, EpsilonArcFilter<Arc>()); |
| RmEpsilonOptions<Arc, AutoQueue<StateId> > |
| opts(&state_queue, delta, connect, weight_threshold, state_threshold); |
| |
| RmEpsilon(fst, &distance, opts); |
| } |
| |
| |
| struct RmEpsilonFstOptions : CacheOptions { |
| float delta; |
| |
| RmEpsilonFstOptions(const CacheOptions &opts, float delta = kDelta) |
| : CacheOptions(opts), delta(delta) {} |
| |
| explicit RmEpsilonFstOptions(float delta = kDelta) : delta(delta) {} |
| }; |
| |
| |
| // Implementation of delayed RmEpsilonFst. |
| template <class A> |
| class RmEpsilonFstImpl : public CacheImpl<A> { |
| public: |
| using FstImpl<A>::SetType; |
| using FstImpl<A>::SetProperties; |
| using FstImpl<A>::SetInputSymbols; |
| using FstImpl<A>::SetOutputSymbols; |
| |
| using CacheBaseImpl< CacheState<A> >::PushArc; |
| using CacheBaseImpl< CacheState<A> >::HasArcs; |
| using CacheBaseImpl< CacheState<A> >::HasFinal; |
| using CacheBaseImpl< CacheState<A> >::HasStart; |
| using CacheBaseImpl< CacheState<A> >::SetArcs; |
| using CacheBaseImpl< CacheState<A> >::SetFinal; |
| using CacheBaseImpl< CacheState<A> >::SetStart; |
| |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| typedef typename A::StateId StateId; |
| typedef CacheState<A> State; |
| |
| RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts) |
| : CacheImpl<A>(opts), |
| fst_(fst.Copy()), |
| delta_(opts.delta), |
| rmeps_state_( |
| *fst_, |
| &distance_, |
| RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { |
| SetType("rmepsilon"); |
| uint64 props = fst.Properties(kFstProperties, false); |
| SetProperties(RmEpsilonProperties(props, true), kCopyProperties); |
| SetInputSymbols(fst.InputSymbols()); |
| SetOutputSymbols(fst.OutputSymbols()); |
| } |
| |
| RmEpsilonFstImpl(const RmEpsilonFstImpl &impl) |
| : CacheImpl<A>(impl), |
| fst_(impl.fst_->Copy(true)), |
| delta_(impl.delta_), |
| rmeps_state_( |
| *fst_, |
| &distance_, |
| RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) { |
| SetType("rmepsilon"); |
| SetProperties(impl.Properties(), kCopyProperties); |
| SetInputSymbols(impl.InputSymbols()); |
| SetOutputSymbols(impl.OutputSymbols()); |
| } |
| |
| ~RmEpsilonFstImpl() { |
| delete fst_; |
| } |
| |
| StateId Start() { |
| if (!HasStart()) { |
| SetStart(fst_->Start()); |
| } |
| return CacheImpl<A>::Start(); |
| } |
| |
| Weight Final(StateId s) { |
| if (!HasFinal(s)) { |
| Expand(s); |
| } |
| return CacheImpl<A>::Final(s); |
| } |
| |
| size_t NumArcs(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<A>::NumArcs(s); |
| } |
| |
| size_t NumInputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<A>::NumInputEpsilons(s); |
| } |
| |
| size_t NumOutputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<A>::NumOutputEpsilons(s); |
| } |
| |
| uint64 Properties() const { return Properties(kFstProperties); } |
| |
| // Set error if found; return FST impl properties. |
| uint64 Properties(uint64 mask) const { |
| if ((mask & kError) && |
| (fst_->Properties(kError, false) || rmeps_state_.Error())) |
| SetProperties(kError, kError); |
| return FstImpl<A>::Properties(mask); |
| } |
| |
| void InitArcIterator(StateId s, ArcIteratorData<A> *data) { |
| if (!HasArcs(s)) |
| Expand(s); |
| CacheImpl<A>::InitArcIterator(s, data); |
| } |
| |
| void Expand(StateId s) { |
| rmeps_state_.Expand(s); |
| SetFinal(s, rmeps_state_.Final()); |
| vector<A> &arcs = rmeps_state_.Arcs(); |
| while (!arcs.empty()) { |
| PushArc(s, arcs.back()); |
| arcs.pop_back(); |
| } |
| SetArcs(s); |
| } |
| |
| private: |
| const Fst<A> *fst_; |
| float delta_; |
| vector<Weight> distance_; |
| FifoQueue<StateId> queue_; |
| RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_; |
| |
| void operator=(const RmEpsilonFstImpl<A> &); // disallow |
| }; |
| |
| |
| // Removes epsilon-transitions (when both the input and output label |
| // are an epsilon) from a transducer. The result will be an equivalent |
| // FST that has no such epsilon transitions. This version is a |
| // delayed Fst. |
| // |
| // Complexity: |
| // - Time: |
| // - Unweighted: O(v^2 + v e) |
| // - General: exponential |
| // - Space: O(v e) |
| // where v = # of states visited, e = # of arcs visited. Constant time |
| // to visit an input state or arc is assumed and exclusive of caching. |
| // |
| // References: |
| // - Mehryar Mohri. Generic Epsilon-Removal and Input |
| // Epsilon-Normalization Algorithms for Weighted Transducers, |
| // "International Journal of Computer Science", 13(1):129-143 (2002). |
| // |
| // This class attaches interface to implementation and handles |
| // reference counting, delegating most methods to ImplToFst. |
| template <class A> |
| class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > { |
| public: |
| friend class ArcIterator< RmEpsilonFst<A> >; |
| friend class StateIterator< RmEpsilonFst<A> >; |
| |
| typedef A Arc; |
| typedef typename A::StateId StateId; |
| typedef CacheState<A> State; |
| typedef RmEpsilonFstImpl<A> Impl; |
| |
| RmEpsilonFst(const Fst<A> &fst) |
| : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {} |
| |
| RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts) |
| : ImplToFst<Impl>(new Impl(fst, opts)) {} |
| |
| // See Fst<>::Copy() for doc. |
| RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false) |
| : ImplToFst<Impl>(fst, safe) {} |
| |
| // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc. |
| virtual RmEpsilonFst<A> *Copy(bool safe = false) const { |
| return new RmEpsilonFst<A>(*this, safe); |
| } |
| |
| virtual inline void InitStateIterator(StateIteratorData<A> *data) const; |
| |
| virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const { |
| GetImpl()->InitArcIterator(s, data); |
| } |
| |
| private: |
| // Makes visible to friends. |
| Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } |
| |
| void operator=(const RmEpsilonFst<A> &fst); // disallow |
| }; |
| |
| // Specialization for RmEpsilonFst. |
| template<class A> |
| class StateIterator< RmEpsilonFst<A> > |
| : public CacheStateIterator< RmEpsilonFst<A> > { |
| public: |
| explicit StateIterator(const RmEpsilonFst<A> &fst) |
| : CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {} |
| }; |
| |
| |
| // Specialization for RmEpsilonFst. |
| template <class A> |
| class ArcIterator< RmEpsilonFst<A> > |
| : public CacheArcIterator< RmEpsilonFst<A> > { |
| public: |
| typedef typename A::StateId StateId; |
| |
| ArcIterator(const RmEpsilonFst<A> &fst, StateId s) |
| : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) { |
| if (!fst.GetImpl()->HasArcs(s)) |
| fst.GetImpl()->Expand(s); |
| } |
| |
| private: |
| DISALLOW_COPY_AND_ASSIGN(ArcIterator); |
| }; |
| |
| |
| template <class A> inline |
| void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const { |
| data->base = new StateIterator< RmEpsilonFst<A> >(*this); |
| } |
| |
| |
| // Useful alias when using StdArc. |
| typedef RmEpsilonFst<StdArc> StdRmEpsilonFst; |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_RMEPSILON_H__ |