| // expand.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: riley@google.com (Michael Riley) |
| // |
| // \file |
| // Expand a PDT to an FST. |
| |
| #ifndef FST_EXTENSIONS_PDT_EXPAND_H__ |
| #define FST_EXTENSIONS_PDT_EXPAND_H__ |
| |
| #include <vector> |
| using std::vector; |
| |
| #include <fst/extensions/pdt/pdt.h> |
| #include <fst/extensions/pdt/paren.h> |
| #include <fst/extensions/pdt/shortest-path.h> |
| #include <fst/extensions/pdt/reverse.h> |
| #include <fst/cache.h> |
| #include <fst/mutable-fst.h> |
| #include <fst/queue.h> |
| #include <fst/state-table.h> |
| #include <fst/test-properties.h> |
| |
| namespace fst { |
| |
| template <class Arc> |
| struct ExpandFstOptions : public CacheOptions { |
| bool keep_parentheses; |
| PdtStack<typename Arc::StateId, typename Arc::Label> *stack; |
| PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table; |
| |
| ExpandFstOptions( |
| const CacheOptions &opts = CacheOptions(), |
| bool kp = false, |
| PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0, |
| PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0) |
| : CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {} |
| }; |
| |
| // Properties for an expanded PDT. |
| inline uint64 ExpandProperties(uint64 inprops) { |
| return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted); |
| } |
| |
| |
| // Implementation class for ExpandFst |
| template <class A> |
| class ExpandFstImpl |
| : public CacheImpl<A> { |
| public: |
| using FstImpl<A>::SetType; |
| using FstImpl<A>::SetProperties; |
| using FstImpl<A>::Properties; |
| using FstImpl<A>::SetInputSymbols; |
| using FstImpl<A>::SetOutputSymbols; |
| |
| using CacheBaseImpl< CacheState<A> >::PushArc; |
| using CacheBaseImpl< CacheState<A> >::HasArcs; |
| using CacheBaseImpl< CacheState<A> >::HasFinal; |
| using CacheBaseImpl< CacheState<A> >::HasStart; |
| using CacheBaseImpl< CacheState<A> >::SetArcs; |
| using CacheBaseImpl< CacheState<A> >::SetFinal; |
| using CacheBaseImpl< CacheState<A> >::SetStart; |
| |
| typedef A Arc; |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| typedef typename A::StateId StateId; |
| typedef StateId StackId; |
| typedef PdtStateTuple<StateId, StackId> StateTuple; |
| |
| ExpandFstImpl(const Fst<A> &fst, |
| const vector<pair<typename Arc::Label, |
| typename Arc::Label> > &parens, |
| const ExpandFstOptions<A> &opts) |
| : CacheImpl<A>(opts), fst_(fst.Copy()), |
| stack_(opts.stack ? opts.stack: new PdtStack<StateId, Label>(parens)), |
| state_table_(opts.state_table ? opts.state_table : |
| new PdtStateTable<StateId, StackId>()), |
| own_stack_(opts.stack == 0), own_state_table_(opts.state_table == 0), |
| keep_parentheses_(opts.keep_parentheses) { |
| SetType("expand"); |
| |
| uint64 props = fst.Properties(kFstProperties, false); |
| SetProperties(ExpandProperties(props), kCopyProperties); |
| |
| SetInputSymbols(fst.InputSymbols()); |
| SetOutputSymbols(fst.OutputSymbols()); |
| } |
| |
| ExpandFstImpl(const ExpandFstImpl &impl) |
| : CacheImpl<A>(impl), |
| fst_(impl.fst_->Copy(true)), |
| stack_(new PdtStack<StateId, Label>(*impl.stack_)), |
| state_table_(new PdtStateTable<StateId, StackId>()), |
| own_stack_(true), own_state_table_(true), |
| keep_parentheses_(impl.keep_parentheses_) { |
| SetType("expand"); |
| SetProperties(impl.Properties(), kCopyProperties); |
| SetInputSymbols(impl.InputSymbols()); |
| SetOutputSymbols(impl.OutputSymbols()); |
| } |
| |
| ~ExpandFstImpl() { |
| delete fst_; |
| if (own_stack_) |
| delete stack_; |
| if (own_state_table_) |
| delete state_table_; |
| } |
| |
| StateId Start() { |
| if (!HasStart()) { |
| StateId s = fst_->Start(); |
| if (s == kNoStateId) |
| return kNoStateId; |
| StateTuple tuple(s, 0); |
| StateId start = state_table_->FindState(tuple); |
| SetStart(start); |
| } |
| return CacheImpl<A>::Start(); |
| } |
| |
| Weight Final(StateId s) { |
| if (!HasFinal(s)) { |
| const StateTuple &tuple = state_table_->Tuple(s); |
| Weight w = fst_->Final(tuple.state_id); |
| if (w != Weight::Zero() && tuple.stack_id == 0) |
| SetFinal(s, w); |
| else |
| SetFinal(s, Weight::Zero()); |
| } |
| return CacheImpl<A>::Final(s); |
| } |
| |
| size_t NumArcs(StateId s) { |
| if (!HasArcs(s)) { |
| ExpandState(s); |
| } |
| return CacheImpl<A>::NumArcs(s); |
| } |
| |
| size_t NumInputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| ExpandState(s); |
| return CacheImpl<A>::NumInputEpsilons(s); |
| } |
| |
| size_t NumOutputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| ExpandState(s); |
| return CacheImpl<A>::NumOutputEpsilons(s); |
| } |
| |
| void InitArcIterator(StateId s, ArcIteratorData<A> *data) { |
| if (!HasArcs(s)) |
| ExpandState(s); |
| CacheImpl<A>::InitArcIterator(s, data); |
| } |
| |
| // Computes the outgoing transitions from a state, creating new destination |
| // states as needed. |
| void ExpandState(StateId s) { |
| StateTuple tuple = state_table_->Tuple(s); |
| for (ArcIterator< Fst<A> > aiter(*fst_, tuple.state_id); |
| !aiter.Done(); aiter.Next()) { |
| Arc arc = aiter.Value(); |
| StackId stack_id = stack_->Find(tuple.stack_id, arc.ilabel); |
| if (stack_id == -1) { |
| // Non-matching close parenthesis |
| continue; |
| } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) { |
| // Stack push/pop |
| arc.ilabel = arc.olabel = 0; |
| } |
| |
| StateTuple ntuple(arc.nextstate, stack_id); |
| arc.nextstate = state_table_->FindState(ntuple); |
| PushArc(s, arc); |
| } |
| SetArcs(s); |
| } |
| |
| const PdtStack<StackId, Label> &GetStack() const { return *stack_; } |
| |
| const PdtStateTable<StateId, StackId> &GetStateTable() const { |
| return *state_table_; |
| } |
| |
| private: |
| const Fst<A> *fst_; |
| |
| PdtStack<StackId, Label> *stack_; |
| PdtStateTable<StateId, StackId> *state_table_; |
| bool own_stack_; |
| bool own_state_table_; |
| bool keep_parentheses_; |
| |
| void operator=(const ExpandFstImpl<A> &); // disallow |
| }; |
| |
| // Expands a pushdown transducer (PDT) encoded as an FST into an FST. |
| // This version is a delayed Fst. In the PDT, some transitions are |
| // labeled with open or close parentheses. To be interpreted as a PDT, |
| // the parens must balance on a path. The open-close parenthesis label |
| // pairs are passed in 'parens'. The expansion enforces the |
| // parenthesis constraints. The PDT must be expandable as an FST. |
| // |
| // This class attaches interface to implementation and handles |
| // reference counting, delegating most methods to ImplToFst. |
| template <class A> |
| class ExpandFst : public ImplToFst< ExpandFstImpl<A> > { |
| public: |
| friend class ArcIterator< ExpandFst<A> >; |
| friend class StateIterator< ExpandFst<A> >; |
| |
| typedef A Arc; |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| typedef typename A::StateId StateId; |
| typedef StateId StackId; |
| typedef CacheState<A> State; |
| typedef ExpandFstImpl<A> Impl; |
| |
| ExpandFst(const Fst<A> &fst, |
| const vector<pair<typename Arc::Label, |
| typename Arc::Label> > &parens) |
| : ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {} |
| |
| ExpandFst(const Fst<A> &fst, |
| const vector<pair<typename Arc::Label, |
| typename Arc::Label> > &parens, |
| const ExpandFstOptions<A> &opts) |
| : ImplToFst<Impl>(new Impl(fst, parens, opts)) {} |
| |
| // See Fst<>::Copy() for doc. |
| ExpandFst(const ExpandFst<A> &fst, bool safe = false) |
| : ImplToFst<Impl>(fst, safe) {} |
| |
| // Get a copy of this ExpandFst. See Fst<>::Copy() for further doc. |
| virtual ExpandFst<A> *Copy(bool safe = false) const { |
| return new ExpandFst<A>(*this, safe); |
| } |
| |
| virtual inline void InitStateIterator(StateIteratorData<A> *data) const; |
| |
| virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const { |
| GetImpl()->InitArcIterator(s, data); |
| } |
| |
| const PdtStack<StackId, Label> &GetStack() const { |
| return GetImpl()->GetStack(); |
| } |
| |
| const PdtStateTable<StateId, StackId> &GetStateTable() const { |
| return GetImpl()->GetStateTable(); |
| } |
| |
| private: |
| // Makes visible to friends. |
| Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } |
| |
| void operator=(const ExpandFst<A> &fst); // Disallow |
| }; |
| |
| |
| // Specialization for ExpandFst. |
| template<class A> |
| class StateIterator< ExpandFst<A> > |
| : public CacheStateIterator< ExpandFst<A> > { |
| public: |
| explicit StateIterator(const ExpandFst<A> &fst) |
| : CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {} |
| }; |
| |
| |
| // Specialization for ExpandFst. |
| template <class A> |
| class ArcIterator< ExpandFst<A> > |
| : public CacheArcIterator< ExpandFst<A> > { |
| public: |
| typedef typename A::StateId StateId; |
| |
| ArcIterator(const ExpandFst<A> &fst, StateId s) |
| : CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) { |
| if (!fst.GetImpl()->HasArcs(s)) |
| fst.GetImpl()->ExpandState(s); |
| } |
| |
| private: |
| DISALLOW_COPY_AND_ASSIGN(ArcIterator); |
| }; |
| |
| |
| template <class A> inline |
| void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const |
| { |
| data->base = new StateIterator< ExpandFst<A> >(*this); |
| } |
| |
| // |
| // PrunedExpand Class |
| // |
| |
| // Prunes the delayed expansion of a pushdown transducer (PDT) encoded |
| // as an FST into an FST. In the PDT, some transitions are labeled |
| // with open or close parentheses. To be interpreted as a PDT, the |
| // parens must balance on a path. The open-close parenthesis label |
| // pairs are passed in 'parens'. The expansion enforces the |
| // parenthesis constraints. |
| // |
| // The algorithm works by visiting the delayed ExpandFst using a |
| // shortest-stack first queue discipline and relies on the |
| // shortest-distance information computed using a reverse |
| // shortest-path call to perform the pruning. |
| // |
| // The algorithm maintains the same state ordering between the ExpandFst |
| // being visited 'efst_' and the result of pruning written into the |
| // MutableFst 'ofst_' to improve readability of the code. |
| // |
| template <class A> |
| class PrunedExpand { |
| public: |
| typedef A Arc; |
| typedef typename A::Label Label; |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| typedef StateId StackId; |
| typedef PdtStack<StackId, Label> Stack; |
| typedef PdtStateTable<StateId, StackId> StateTable; |
| typedef typename PdtBalanceData<Arc>::SetIterator SetIterator; |
| |
| // Constructor taking as input a PDT specified by 'ifst' and 'parens'. |
| // 'keep_parentheses' specifies whether parentheses are replaced by |
| // epsilons or not during the expansion. 'opts' is the cache options |
| // used to instantiate the underlying ExpandFst. |
| PrunedExpand(const Fst<A> &ifst, |
| const vector<pair<Label, Label> > &parens, |
| bool keep_parentheses = false, |
| const CacheOptions &opts = CacheOptions()) |
| : ifst_(ifst.Copy()), |
| keep_parentheses_(keep_parentheses), |
| stack_(parens), |
| efst_(ifst, parens, |
| ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)), |
| queue_(state_table_, stack_, stack_length_, distance_, fdistance_) { |
| Reverse(*ifst_, parens, &rfst_); |
| VectorFst<Arc> path; |
| reverse_shortest_path_ = new SP( |
| rfst_, parens, |
| PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false)); |
| reverse_shortest_path_->ShortestPath(&path); |
| balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse( |
| rfst_.NumStates(), 10, -1); |
| |
| InitCloseParenMultimap(parens); |
| } |
| |
| ~PrunedExpand() { |
| delete ifst_; |
| delete reverse_shortest_path_; |
| delete balance_data_; |
| } |
| |
| // Expands and prunes with weight threshold 'threshold' the input PDT. |
| // Writes the result in 'ofst'. |
| void Expand(MutableFst<A> *ofst, const Weight &threshold); |
| |
| private: |
| static const uint8 kEnqueued; |
| static const uint8 kExpanded; |
| static const uint8 kSourceState; |
| |
| // Comparison functor used by the queue: |
| // 1. states corresponding to shortest stack first, |
| // 2. among stacks of the same length, reverse lexicographic order is used, |
| // 3. among states with the same stack, shortest-first order is used. |
| class StackCompare { |
| public: |
| StackCompare(const StateTable &st, |
| const Stack &s, const vector<StackId> &sl, |
| const vector<Weight> &d, const vector<Weight> &fd) |
| : state_table_(st), stack_(s), stack_length_(sl), |
| distance_(d), fdistance_(fd) {} |
| |
| bool operator()(StateId s1, StateId s2) const { |
| StackId si1 = state_table_.Tuple(s1).stack_id; |
| StackId si2 = state_table_.Tuple(s2).stack_id; |
| if (stack_length_[si1] < stack_length_[si2]) |
| return true; |
| if (stack_length_[si1] > stack_length_[si2]) |
| return false; |
| // If stack id equal, use A* |
| if (si1 == si2) { |
| Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ? |
| Times(distance_[s1], fdistance_[s1]) : Weight::Zero(); |
| Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ? |
| Times(distance_[s2], fdistance_[s2]) : Weight::Zero(); |
| return less_(w1, w2); |
| } |
| // If lenghts are equal, use reverse lexico. |
| for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) { |
| if (stack_.Top(si1) < stack_.Top(si2)) return true; |
| if (stack_.Top(si1) > stack_.Top(si2)) return false; |
| } |
| return false; |
| } |
| |
| private: |
| const StateTable &state_table_; |
| const Stack &stack_; |
| const vector<StackId> &stack_length_; |
| const vector<Weight> &distance_; |
| const vector<Weight> &fdistance_; |
| NaturalLess<Weight> less_; |
| }; |
| |
| class ShortestStackFirstQueue |
| : public ShortestFirstQueue<StateId, StackCompare> { |
| public: |
| ShortestStackFirstQueue( |
| const PdtStateTable<StateId, StackId> &st, |
| const Stack &s, |
| const vector<StackId> &sl, |
| const vector<Weight> &d, const vector<Weight> &fd) |
| : ShortestFirstQueue<StateId, StackCompare>( |
| StackCompare(st, s, sl, d, fd)) {} |
| }; |
| |
| |
| void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens); |
| Weight DistanceToDest(StateId state, StateId source) const; |
| uint8 Flags(StateId s) const; |
| void SetFlags(StateId s, uint8 flags, uint8 mask); |
| Weight Distance(StateId s) const; |
| void SetDistance(StateId s, Weight w); |
| Weight FinalDistance(StateId s) const; |
| void SetFinalDistance(StateId s, Weight w); |
| StateId SourceState(StateId s) const; |
| void SetSourceState(StateId s, StateId p); |
| void AddStateAndEnqueue(StateId s); |
| void Relax(StateId s, const A &arc, Weight w); |
| bool PruneArc(StateId s, const A &arc); |
| void ProcStart(); |
| void ProcFinal(StateId s); |
| bool ProcNonParen(StateId s, const A &arc, bool add_arc); |
| bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi); |
| bool ProcCloseParen(StateId s, const A &arc); |
| void ProcDestStates(StateId s, StackId si); |
| |
| Fst<A> *ifst_; // Input PDT |
| VectorFst<Arc> rfst_; // Reversed PDT |
| bool keep_parentheses_; // Keep parentheses in ofst? |
| StateTable state_table_; // State table for efst_ |
| Stack stack_; // Stack trie |
| ExpandFst<Arc> efst_; // Expanded PDT |
| vector<StackId> stack_length_; // Length of stack for given stack id |
| vector<Weight> distance_; // Distance from initial state in efst_/ofst |
| vector<Weight> fdistance_; // Distance to final states in efst_/ofst |
| ShortestStackFirstQueue queue_; // Queue used to visit efst_ |
| vector<uint8> flags_; // Status flags for states in efst_/ofst |
| vector<StateId> sources_; // PDT source state for each expanded state |
| |
| typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP; |
| typedef typename SP::CloseParenMultimap ParenMultimap; |
| SP *reverse_shortest_path_; // Shortest path for rfst_ |
| PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_ |
| ParenMultimap close_paren_multimap_; // Maps open paren arcs to |
| // balancing close paren arcs. |
| |
| MutableFst<Arc> *ofst_; // Output fst |
| Weight limit_; // Weight limit |
| |
| typedef unordered_map<StateId, Weight> DestMap; |
| DestMap dest_map_; |
| StackId current_stack_id_; |
| // 'current_stack_id_' is the stack id of the states currently at the top |
| // of queue, i.e., the states currently being popped and processed. |
| // 'dest_map_' maps a state 's' in 'ifst_' that is the source |
| // of a close parentheses matching the top of 'current_stack_id_; to |
| // the shortest-distance from '(s, current_stack_id_)' to the final |
| // states in 'efst_'. |
| ssize_t current_paren_id_; // Paren id at top of current stack |
| ssize_t cached_stack_id_; |
| StateId cached_source_; |
| slist<pair<StateId, Weight> > cached_dest_list_; |
| // 'cached_dest_list_' contains the set of pair of destination |
| // states and weight to final states for source state |
| // 'cached_source_' and paren id 'cached_paren_id': the set of |
| // source state of a close parenthesis with paren id |
| // 'cached_paren_id' balancing an incoming open parenthesis with |
| // paren id 'cached_paren_id' in state 'cached_source_'. |
| |
| NaturalLess<Weight> less_; |
| }; |
| |
| template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01; |
| template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02; |
| template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04; |
| |
| |
| // Initializes close paren multimap, mapping pairs (s,paren_id) to |
| // all the arcs out of s labeled with close parenthese for paren_id. |
| template <class A> |
| void PrunedExpand<A>::InitCloseParenMultimap( |
| const vector<pair<Label, Label> > &parens) { |
| unordered_map<Label, Label> paren_id_map; |
| for (Label i = 0; i < parens.size(); ++i) { |
| const pair<Label, Label> &p = parens[i]; |
| paren_id_map[p.first] = i; |
| paren_id_map[p.second] = i; |
| } |
| |
| for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) { |
| StateId s = siter.Value(); |
| for (ArcIterator<Fst<Arc> > aiter(*ifst_, s); |
| !aiter.Done(); aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| typename unordered_map<Label, Label>::const_iterator pit |
| = paren_id_map.find(arc.ilabel); |
| if (pit == paren_id_map.end()) continue; |
| if (arc.ilabel == parens[pit->second].second) { // Close paren |
| ParenState<Arc> paren_state(pit->second, s); |
| close_paren_multimap_.insert(make_pair(paren_state, arc)); |
| } |
| } |
| } |
| } |
| |
| |
| // Returns the weight of the shortest balanced path from 'source' to 'dest' |
| // in 'ifst_', 'dest' must be the source state of a close paren arc. |
| template <class A> |
| typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source, |
| StateId dest) const { |
| typename SP::SearchState s(source + 1, dest + 1); |
| VLOG(2) << "D(" << source << ", " << dest << ") =" |
| << reverse_shortest_path_->GetShortestPathData().Distance(s); |
| return reverse_shortest_path_->GetShortestPathData().Distance(s); |
| } |
| |
| // Returns the flags for state 's' in 'ofst_'. |
| template <class A> |
| uint8 PrunedExpand<A>::Flags(StateId s) const { |
| return s < flags_.size() ? flags_[s] : 0; |
| } |
| |
| // Modifies the flags for state 's' in 'ofst_'. |
| template <class A> |
| void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) { |
| while (flags_.size() <= s) flags_.push_back(0); |
| flags_[s] &= ~mask; |
| flags_[s] |= flags & mask; |
| } |
| |
| |
| // Returns the shortest distance from the initial state to 's' in 'ofst_'. |
| template <class A> |
| typename A::Weight PrunedExpand<A>::Distance(StateId s) const { |
| return s < distance_.size() ? distance_[s] : Weight::Zero(); |
| } |
| |
| // Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'. |
| template <class A> |
| void PrunedExpand<A>::SetDistance(StateId s, Weight w) { |
| while (distance_.size() <= s ) distance_.push_back(Weight::Zero()); |
| distance_[s] = w; |
| } |
| |
| |
| // Returns the shortest distance from 's' to the final states in 'ofst_'. |
| template <class A> |
| typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const { |
| return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); |
| } |
| |
| // Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'. |
| template <class A> |
| void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) { |
| while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero()); |
| fdistance_[s] = w; |
| } |
| |
| // Returns the PDT "source" state of state 's' in 'ofst_'. |
| template <class A> |
| typename A::StateId PrunedExpand<A>::SourceState(StateId s) const { |
| return s < sources_.size() ? sources_[s] : kNoStateId; |
| } |
| |
| // Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'. |
| template <class A> |
| void PrunedExpand<A>::SetSourceState(StateId s, StateId p) { |
| while (sources_.size() <= s) sources_.push_back(kNoStateId); |
| sources_[s] = p; |
| } |
| |
| // Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue, |
| // modifying the flags for 's' accordingly. |
| template <class A> |
| void PrunedExpand<A>::AddStateAndEnqueue(StateId s) { |
| if (!(Flags(s) & (kEnqueued | kExpanded))) { |
| while (ofst_->NumStates() <= s) ofst_->AddState(); |
| queue_.Enqueue(s); |
| SetFlags(s, kEnqueued, kEnqueued); |
| } else if (Flags(s) & kEnqueued) { |
| queue_.Update(s); |
| } |
| // TODO(allauzen): Check everything is fine when kExpanded? |
| } |
| |
| // Relaxes arc 'arc' out of state 's' in 'ofst_': |
| // * if the distance to 's' times the weight of 'arc' is smaller than |
| // the currently stored distance for 'arc.nextstate', |
| // updates 'Distance(arc.nextstate)' with new estimate; |
| // * if 'fd' is less than the currently stored distance from 'arc.nextstate' |
| // to the final state, updates with new estimate. |
| template <class A> |
| void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) { |
| Weight nd = Times(Distance(s), arc.weight); |
| if (less_(nd, Distance(arc.nextstate))) { |
| SetDistance(arc.nextstate, nd); |
| SetSourceState(arc.nextstate, SourceState(s)); |
| } |
| if (less_(fd, FinalDistance(arc.nextstate))) |
| SetFinalDistance(arc.nextstate, fd); |
| VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to " |
| << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate) |
| << ", nd = " << nd; |
| } |
| |
| // Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to |
| // be pruned. |
| template <class A> |
| bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) { |
| VLOG(2) << "Prune ?"; |
| Weight fd = Weight::Zero(); |
| |
| if ((cached_source_ != SourceState(s)) || |
| (cached_stack_id_ != current_stack_id_)) { |
| cached_source_ = SourceState(s); |
| cached_stack_id_ = current_stack_id_; |
| cached_dest_list_.clear(); |
| if (cached_source_ != ifst_->Start()) { |
| for (SetIterator set_iter = |
| balance_data_->Find(current_paren_id_, cached_source_); |
| !set_iter.Done(); set_iter.Next()) { |
| StateId dest = set_iter.Element(); |
| typename DestMap::const_iterator iter = dest_map_.find(dest); |
| cached_dest_list_.push_front(*iter); |
| } |
| } else { |
| // TODO(allauzen): queue discipline should prevent this never |
| // from happening; replace by a check. |
| cached_dest_list_.push_front( |
| make_pair(rfst_.Start() -1, Weight::One())); |
| } |
| } |
| |
| for (typename slist<pair<StateId, Weight> >::const_iterator iter = |
| cached_dest_list_.begin(); |
| iter != cached_dest_list_.end(); |
| ++iter) { |
| fd = Plus(fd, |
| Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, |
| iter->first), |
| iter->second)); |
| } |
| Relax(s, arc, fd); |
| Weight w = Times(Distance(s), Times(arc.weight, fd)); |
| return less_(limit_, w); |
| } |
| |
| // Adds start state of 'efst_' to 'ofst_', enqueues it and initializes |
| // the distance data structures. |
| template <class A> |
| void PrunedExpand<A>::ProcStart() { |
| StateId s = efst_.Start(); |
| AddStateAndEnqueue(s); |
| ofst_->SetStart(s); |
| SetSourceState(s, ifst_->Start()); |
| |
| current_stack_id_ = 0; |
| current_paren_id_ = -1; |
| stack_length_.push_back(0); |
| dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed |
| |
| cached_source_ = ifst_->Start(); |
| cached_stack_id_ = 0; |
| cached_dest_list_.push_front( |
| make_pair(rfst_.Start() -1, Weight::One())); |
| |
| PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0); |
| SetFinalDistance(state_table_.FindState(tuple), Weight::One()); |
| SetDistance(s, Weight::One()); |
| SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1)); |
| VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1); |
| } |
| |
| // Makes 's' final in 'ofst_' if shortest accepting path ending in 's' |
| // is below threshold. |
| template <class A> |
| void PrunedExpand<A>::ProcFinal(StateId s) { |
| Weight final = efst_.Final(s); |
| if ((final == Weight::Zero()) || less_(limit_, Times(Distance(s), final))) |
| return; |
| ofst_->SetFinal(s, final); |
| } |
| |
| // Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is |
| // below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'. |
| template <class A> |
| bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) { |
| VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate |
| << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight |
| << ", add_arc = " << (add_arc ? "true" : "false"); |
| if (PruneArc(s, arc)) return false; |
| if(add_arc) ofst_->AddArc(s, arc); |
| AddStateAndEnqueue(arc.nextstate); |
| return true; |
| } |
| |
| // Processes an open paren arc 'arc' out of state 's' in 'ofst_'. |
| // When 'arc' is labeled with an open paren, |
| // 1. considers each (shortest) balanced path starting in 's' by |
| // taking 'arc' and ending by a close paren balancing the open |
| // paren of 'arc' as a meta-arc, processes and prunes each meta-arc |
| // as a non-paren arc, inserting its destination to the queue; |
| // 2. if at least one of these meta-arcs has not been pruned, |
| // adds the destination of 'arc' to 'ofst_' as a new source state |
| // for the stack id 'nsi' and inserts it in the queue. |
| template <class A> |
| bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si, |
| StackId nsi) { |
| // Update the stack lenght when needed: |nsi| = |si| + 1. |
| while (stack_length_.size() <= nsi) stack_length_.push_back(-1); |
| if (stack_length_[nsi] == -1) |
| stack_length_[nsi] = stack_length_[si] + 1; |
| |
| StateId ns = arc.nextstate; |
| VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id |
| << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")"; |
| bool proc_arc = false; |
| Weight fd = Weight::Zero(); |
| ssize_t paren_id = stack_.ParenId(arc.ilabel); |
| slist<StateId> sources; |
| for (SetIterator set_iter = |
| balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id); |
| !set_iter.Done(); set_iter.Next()) { |
| sources.push_front(set_iter.Element()); |
| } |
| for (typename slist<StateId>::const_iterator sources_iter = sources.begin(); |
| sources_iter != sources.end(); |
| ++ sources_iter) { |
| StateId source = *sources_iter; |
| VLOG(2) << "Close paren source: " << source; |
| ParenState<Arc> paren_state(paren_id, source); |
| for (typename ParenMultimap::const_iterator iter = |
| close_paren_multimap_.find(paren_state); |
| iter != close_paren_multimap_.end() && paren_state == iter->first; |
| ++iter) { |
| Arc meta_arc = iter->second; |
| PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si); |
| meta_arc.nextstate = state_table_.FindState(tuple); |
| VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source; |
| VLOG(2) << "Meta arc weight = " << arc.weight << " Times " |
| << DistanceToDest(state_table_.Tuple(ns).state_id, source) |
| << " Times " << meta_arc.weight; |
| meta_arc.weight = Times( |
| arc.weight, |
| Times(DistanceToDest(state_table_.Tuple(ns).state_id, source), |
| meta_arc.weight)); |
| proc_arc |= ProcNonParen(s, meta_arc, false); |
| fd = Plus(fd, Times( |
| Times( |
| DistanceToDest(state_table_.Tuple(ns).state_id, source), |
| iter->second.weight), |
| FinalDistance(meta_arc.nextstate))); |
| } |
| } |
| if (proc_arc) { |
| VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate; |
| ofst_->AddArc( |
| s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); |
| AddStateAndEnqueue(arc.nextstate); |
| Weight nd = Times(Distance(s), arc.weight); |
| if(less_(nd, Distance(arc.nextstate))) |
| SetDistance(arc.nextstate, nd); |
| // FinalDistance not necessary for source state since pruning |
| // decided using the meta-arcs above. But this is a problem with |
| // A*, hence: |
| if (less_(fd, FinalDistance(arc.nextstate))) |
| SetFinalDistance(arc.nextstate, fd); |
| SetFlags(arc.nextstate, kSourceState, kSourceState); |
| } |
| return proc_arc; |
| } |
| |
| // Checks that shortest path through close paren arc in 'efst_' is |
| // below threshold, if so adds it to 'ofst_'. |
| template <class A> |
| bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) { |
| Weight w = Times(Distance(s), |
| Times(arc.weight, FinalDistance(arc.nextstate))); |
| if (less_(limit_, w)) |
| return false; |
| ofst_->AddArc( |
| s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); |
| return true; |
| } |
| |
| // When 's' in 'ofst_' is a source state for stack id 'si', identifies |
| // all the corresponding possible destination states, that is, all the |
| // states in 'ifst_' that have an outgoing close paren arc balancing |
| // the incoming open paren taken to get to 's', and for each such |
| // state 't', computes the shortest distance from (t, si) to the final |
| // states in 'ofst_'. Stores this information in 'dest_map_'. |
| template <class A> |
| void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) { |
| if (!(Flags(s) & kSourceState)) return; |
| if (si != current_stack_id_) { |
| dest_map_.clear(); |
| current_stack_id_ = si; |
| current_paren_id_ = stack_.Top(current_stack_id_); |
| VLOG(2) << "StackID " << si << " dequeued for first time"; |
| } |
| // TODO(allauzen): clean up source state business; rename current function to |
| // ProcSourceState. |
| SetSourceState(s, state_table_.Tuple(s).state_id); |
| |
| ssize_t paren_id = stack_.Top(si); |
| for (SetIterator set_iter = |
| balance_data_->Find(paren_id, state_table_.Tuple(s).state_id); |
| !set_iter.Done(); set_iter.Next()) { |
| StateId dest_state = set_iter.Element(); |
| if (dest_map_.find(dest_state) != dest_map_.end()) |
| continue; |
| Weight dest_weight = Weight::Zero(); |
| ParenState<Arc> paren_state(paren_id, dest_state); |
| for (typename ParenMultimap::const_iterator iter = |
| close_paren_multimap_.find(paren_state); |
| iter != close_paren_multimap_.end() && paren_state == iter->first; |
| ++iter) { |
| const Arc &arc = iter->second; |
| PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si)); |
| dest_weight = Plus(dest_weight, |
| Times(arc.weight, |
| FinalDistance(state_table_.FindState(tuple)))); |
| } |
| dest_map_[dest_state] = dest_weight; |
| VLOG(2) << "State " << dest_state << " is a dest state for stack id " |
| << si << " with weight " << dest_weight; |
| } |
| } |
| |
| // Expands and prunes with weight threshold 'threshold' the input PDT. |
| // Writes the result in 'ofst'. |
| template <class A> |
| void PrunedExpand<A>::Expand( |
| MutableFst<A> *ofst, const typename A::Weight &threshold) { |
| ofst_ = ofst; |
| ofst_->DeleteStates(); |
| ofst_->SetInputSymbols(ifst_->InputSymbols()); |
| ofst_->SetOutputSymbols(ifst_->OutputSymbols()); |
| |
| limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold); |
| flags_.clear(); |
| |
| ProcStart(); |
| |
| while (!queue_.Empty()) { |
| StateId s = queue_.Head(); |
| queue_.Dequeue(); |
| SetFlags(s, kExpanded, kExpanded | kEnqueued); |
| VLOG(2) << s << " dequeued!"; |
| |
| ProcFinal(s); |
| StackId stack_id = state_table_.Tuple(s).stack_id; |
| ProcDestStates(s, stack_id); |
| |
| for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s); |
| !aiter.Done(); |
| aiter.Next()) { |
| Arc arc = aiter.Value(); |
| StackId nextstack_id = state_table_.Tuple(arc.nextstate).stack_id; |
| if (stack_id == nextstack_id) |
| ProcNonParen(s, arc, true); |
| else if (stack_id == stack_.Pop(nextstack_id)) |
| ProcOpenParen(s, arc, stack_id, nextstack_id); |
| else |
| ProcCloseParen(s, arc); |
| } |
| VLOG(2) << "d[" << s << "] = " << Distance(s) |
| << ", fd[" << s << "] = " << FinalDistance(s); |
| } |
| } |
| |
| // |
| // Expand() Functions |
| // |
| |
| template <class Arc> |
| struct ExpandOptions { |
| bool connect; |
| bool keep_parentheses; |
| typename Arc::Weight weight_threshold; |
| |
| ExpandOptions(bool c = true, bool k = false, |
| typename Arc::Weight w = Arc::Weight::Zero()) |
| : connect(c), keep_parentheses(k), weight_threshold(w) {} |
| }; |
| |
| // Expands a pushdown transducer (PDT) encoded as an FST into an FST. |
| // This version writes the expanded PDT result to a MutableFst. |
| // In the PDT, some transitions are labeled with open or close |
| // parentheses. To be interpreted as a PDT, the parens must balance on |
| // a path. The open-close parenthesis label pairs are passed in |
| // 'parens'. The expansion enforces the parenthesis constraints. The |
| // PDT must be expandable as an FST. |
| template <class Arc> |
| void Expand( |
| const Fst<Arc> &ifst, |
| const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, |
| MutableFst<Arc> *ofst, |
| const ExpandOptions<Arc> &opts) { |
| typedef typename Arc::Label Label; |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| typedef typename ExpandFst<Arc>::StackId StackId; |
| |
| ExpandFstOptions<Arc> eopts; |
| eopts.gc_limit = 0; |
| if (opts.weight_threshold == Weight::Zero()) { |
| eopts.keep_parentheses = opts.keep_parentheses; |
| *ofst = ExpandFst<Arc>(ifst, parens, eopts); |
| } else { |
| PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses); |
| pruned_expand.Expand(ofst, opts.weight_threshold); |
| } |
| |
| if (opts.connect) |
| Connect(ofst); |
| } |
| |
| // Expands a pushdown transducer (PDT) encoded as an FST into an FST. |
| // This version writes the expanded PDT result to a MutableFst. |
| // In the PDT, some transitions are labeled with open or close |
| // parentheses. To be interpreted as a PDT, the parens must balance on |
| // a path. The open-close parenthesis label pairs are passed in |
| // 'parens'. The expansion enforces the parenthesis constraints. The |
| // PDT must be expandable as an FST. |
| template<class Arc> |
| void Expand( |
| const Fst<Arc> &ifst, |
| const vector<pair<typename Arc::Label, typename Arc::Label> > &parens, |
| MutableFst<Arc> *ofst, |
| bool connect = true, bool keep_parentheses = false) { |
| Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses)); |
| } |
| |
| } // namespace fst |
| |
| #endif // FST_EXTENSIONS_PDT_EXPAND_H__ |