| // shortest-distance.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: allauzen@google.com (Cyril Allauzen) |
| // |
| // \file |
| // Functions and classes to find shortest distance in an FST. |
| |
| #ifndef FST_LIB_SHORTEST_DISTANCE_H__ |
| #define FST_LIB_SHORTEST_DISTANCE_H__ |
| |
| #include <deque> |
| using std::deque; |
| #include <vector> |
| using std::vector; |
| |
| #include <fst/arcfilter.h> |
| #include <fst/cache.h> |
| #include <fst/queue.h> |
| #include <fst/reverse.h> |
| #include <fst/test-properties.h> |
| |
| |
| namespace fst { |
| |
| template <class Arc, class Queue, class ArcFilter> |
| struct ShortestDistanceOptions { |
| typedef typename Arc::StateId StateId; |
| |
| Queue *state_queue; // Queue discipline used; owned by caller |
| ArcFilter arc_filter; // Arc filter (e.g., limit to only epsilon graph) |
| StateId source; // If kNoStateId, use the Fst's initial state |
| float delta; // Determines the degree of convergence required |
| bool first_path; // For a semiring with the path property (o.w. |
| // undefined), compute the shortest-distances along |
| // along the first path to a final state found |
| // by the algorithm. That path is the shortest-path |
| // only if the FST has a unique final state (or all |
| // the final states have the same final weight), the |
| // queue discipline is shortest-first and all the |
| // weights in the FST are between One() and Zero() |
| // according to NaturalLess. |
| |
| ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId, |
| float d = kDelta) |
| : state_queue(q), arc_filter(filt), source(src), delta(d), |
| first_path(false) {} |
| }; |
| |
| |
| // Computation state of the shortest-distance algorithm. Reusable |
| // information is maintained across calls to member function |
| // ShortestDistance(source) when 'retain' is true for improved |
| // efficiency when calling multiple times from different source states |
| // (e.g., in epsilon removal). Contrary to usual conventions, 'fst' |
| // may not be freed before this class. Vector 'distance' should not be |
| // modified by the user between these calls. |
| // The Error() method returns true if an error was encountered. |
| template<class Arc, class Queue, class ArcFilter> |
| class ShortestDistanceState { |
| public: |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| |
| ShortestDistanceState( |
| const Fst<Arc> &fst, |
| vector<Weight> *distance, |
| const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts, |
| bool retain) |
| : fst_(fst), distance_(distance), state_queue_(opts.state_queue), |
| arc_filter_(opts.arc_filter), delta_(opts.delta), |
| first_path_(opts.first_path), retain_(retain), source_id_(0), |
| error_(false) { |
| distance_->clear(); |
| } |
| |
| ~ShortestDistanceState() {} |
| |
| void ShortestDistance(StateId source); |
| |
| bool Error() const { return error_; } |
| |
| private: |
| const Fst<Arc> &fst_; |
| vector<Weight> *distance_; |
| Queue *state_queue_; |
| ArcFilter arc_filter_; |
| float delta_; |
| bool first_path_; |
| bool retain_; // Retain and reuse information across calls |
| |
| vector<Weight> rdistance_; // Relaxation distance. |
| vector<bool> enqueued_; // Is state enqueued? |
| vector<StateId> sources_; // Source ID for ith state in 'distance_', |
| // 'rdistance_', and 'enqueued_' if retained. |
| StateId source_id_; // Unique ID characterizing each call to SD |
| |
| bool error_; |
| }; |
| |
| // Compute the shortest distance. If 'source' is kNoStateId, use |
| // the initial state of the Fst. |
| template <class Arc, class Queue, class ArcFilter> |
| void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance( |
| StateId source) { |
| if (fst_.Start() == kNoStateId) { |
| if (fst_.Properties(kError, false)) error_ = true; |
| return; |
| } |
| |
| if (!(Weight::Properties() & kRightSemiring)) { |
| FSTERROR() << "ShortestDistance: Weight needs to be right distributive: " |
| << Weight::Type(); |
| error_ = true; |
| return; |
| } |
| |
| if (first_path_ && !(Weight::Properties() & kPath)) { |
| FSTERROR() << "ShortestDistance: first_path option disallowed when " |
| << "Weight does not have the path property: " |
| << Weight::Type(); |
| error_ = true; |
| return; |
| } |
| |
| state_queue_->Clear(); |
| |
| if (!retain_) { |
| distance_->clear(); |
| rdistance_.clear(); |
| enqueued_.clear(); |
| } |
| |
| if (source == kNoStateId) |
| source = fst_.Start(); |
| |
| while (distance_->size() <= source) { |
| distance_->push_back(Weight::Zero()); |
| rdistance_.push_back(Weight::Zero()); |
| enqueued_.push_back(false); |
| } |
| if (retain_) { |
| while (sources_.size() <= source) |
| sources_.push_back(kNoStateId); |
| sources_[source] = source_id_; |
| } |
| (*distance_)[source] = Weight::One(); |
| rdistance_[source] = Weight::One(); |
| enqueued_[source] = true; |
| |
| state_queue_->Enqueue(source); |
| |
| while (!state_queue_->Empty()) { |
| StateId s = state_queue_->Head(); |
| state_queue_->Dequeue(); |
| while (distance_->size() <= s) { |
| distance_->push_back(Weight::Zero()); |
| rdistance_.push_back(Weight::Zero()); |
| enqueued_.push_back(false); |
| } |
| if (first_path_ && (fst_.Final(s) != Weight::Zero())) |
| break; |
| enqueued_[s] = false; |
| Weight r = rdistance_[s]; |
| rdistance_[s] = Weight::Zero(); |
| for (ArcIterator< Fst<Arc> > aiter(fst_, s); |
| !aiter.Done(); |
| aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| if (!arc_filter_(arc)) |
| continue; |
| while (distance_->size() <= arc.nextstate) { |
| distance_->push_back(Weight::Zero()); |
| rdistance_.push_back(Weight::Zero()); |
| enqueued_.push_back(false); |
| } |
| if (retain_) { |
| while (sources_.size() <= arc.nextstate) |
| sources_.push_back(kNoStateId); |
| if (sources_[arc.nextstate] != source_id_) { |
| (*distance_)[arc.nextstate] = Weight::Zero(); |
| rdistance_[arc.nextstate] = Weight::Zero(); |
| enqueued_[arc.nextstate] = false; |
| sources_[arc.nextstate] = source_id_; |
| } |
| } |
| Weight &nd = (*distance_)[arc.nextstate]; |
| Weight &nr = rdistance_[arc.nextstate]; |
| Weight w = Times(r, arc.weight); |
| if (!ApproxEqual(nd, Plus(nd, w), delta_)) { |
| nd = Plus(nd, w); |
| nr = Plus(nr, w); |
| if (!nd.Member() || !nr.Member()) { |
| error_ = true; |
| return; |
| } |
| if (!enqueued_[arc.nextstate]) { |
| state_queue_->Enqueue(arc.nextstate); |
| enqueued_[arc.nextstate] = true; |
| } else { |
| state_queue_->Update(arc.nextstate); |
| } |
| } |
| } |
| } |
| ++source_id_; |
| if (fst_.Properties(kError, false)) error_ = true; |
| } |
| |
| |
| // Shortest-distance algorithm: this version allows fine control |
| // via the options argument. See below for a simpler interface. |
| // |
| // This computes the shortest distance from the 'opts.source' state to |
| // each visited state S and stores the value in the 'distance' vector. |
| // An unvisited state S has distance Zero(), which will be stored in |
| // the 'distance' vector if S is less than the maximum visited state. |
| // The state queue discipline, arc filter, and convergence delta are |
| // taken in the options argument. |
| // The 'distance' vector will contain a unique element for which |
| // Member() is false if an error was encountered. |
| // |
| // The weights must must be right distributive and k-closed (i.e., 1 + |
| // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k). |
| // |
| // The algorithm is from Mohri, "Semiring Framweork and Algorithms for |
| // Shortest-Distance Problems", Journal of Automata, Languages and |
| // Combinatorics 7(3):321-350, 2002. The complexity of algorithm |
| // depends on the properties of the semiring and the queue discipline |
| // used. Refer to the paper for more details. |
| template<class Arc, class Queue, class ArcFilter> |
| void ShortestDistance( |
| const Fst<Arc> &fst, |
| vector<typename Arc::Weight> *distance, |
| const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) { |
| |
| ShortestDistanceState<Arc, Queue, ArcFilter> |
| sd_state(fst, distance, opts, false); |
| sd_state.ShortestDistance(opts.source); |
| if (sd_state.Error()) { |
| distance->clear(); |
| distance->resize(1, Arc::Weight::NoWeight()); |
| } |
| } |
| |
| // Shortest-distance algorithm: simplified interface. See above for a |
| // version that allows finer control. |
| // |
| // If 'reverse' is false, this computes the shortest distance from the |
| // initial state to each state S and stores the value in the |
| // 'distance' vector. If 'reverse' is true, this computes the shortest |
| // distance from each state to the final states. An unvisited state S |
| // has distance Zero(), which will be stored in the 'distance' vector |
| // if S is less than the maximum visited state. The state queue |
| // discipline is automatically-selected. |
| // The 'distance' vector will contain a unique element for which |
| // Member() is false if an error was encountered. |
| // |
| // The weights must must be right (left) distributive if reverse is |
| // false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 + |
| // x + x^2 + ... + x^k). |
| // |
| // The algorithm is from Mohri, "Semiring Framweork and Algorithms for |
| // Shortest-Distance Problems", Journal of Automata, Languages and |
| // Combinatorics 7(3):321-350, 2002. The complexity of algorithm |
| // depends on the properties of the semiring and the queue discipline |
| // used. Refer to the paper for more details. |
| template<class Arc> |
| void ShortestDistance(const Fst<Arc> &fst, |
| vector<typename Arc::Weight> *distance, |
| bool reverse = false, |
| float delta = kDelta) { |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| |
| if (!reverse) { |
| AnyArcFilter<Arc> arc_filter; |
| AutoQueue<StateId> state_queue(fst, distance, arc_filter); |
| ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> > |
| opts(&state_queue, arc_filter); |
| opts.delta = delta; |
| ShortestDistance(fst, distance, opts); |
| } else { |
| typedef ReverseArc<Arc> ReverseArc; |
| typedef typename ReverseArc::Weight ReverseWeight; |
| AnyArcFilter<ReverseArc> rarc_filter; |
| VectorFst<ReverseArc> rfst; |
| Reverse(fst, &rfst); |
| vector<ReverseWeight> rdistance; |
| AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter); |
| ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>, |
| AnyArcFilter<ReverseArc> > |
| ropts(&state_queue, rarc_filter); |
| ropts.delta = delta; |
| ShortestDistance(rfst, &rdistance, ropts); |
| distance->clear(); |
| if (rdistance.size() == 1 && !rdistance[0].Member()) { |
| distance->resize(1, Arc::Weight::NoWeight()); |
| return; |
| } |
| while (distance->size() < rdistance.size() - 1) |
| distance->push_back(rdistance[distance->size() + 1].Reverse()); |
| } |
| } |
| |
| |
| // Return the sum of the weight of all successful paths in an FST, i.e., |
| // the shortest-distance from the initial state to the final states. |
| // Returns a weight such that Member() is false if an error was encountered. |
| template <class Arc> |
| typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) { |
| typedef typename Arc::Weight Weight; |
| typedef typename Arc::StateId StateId; |
| vector<Weight> distance; |
| if (Weight::Properties() & kRightSemiring) { |
| ShortestDistance(fst, &distance, false, delta); |
| if (distance.size() == 1 && !distance[0].Member()) |
| return Arc::Weight::NoWeight(); |
| Weight sum = Weight::Zero(); |
| for (StateId s = 0; s < distance.size(); ++s) |
| sum = Plus(sum, Times(distance[s], fst.Final(s))); |
| return sum; |
| } else { |
| ShortestDistance(fst, &distance, true, delta); |
| StateId s = fst.Start(); |
| if (distance.size() == 1 && !distance[0].Member()) |
| return Arc::Weight::NoWeight(); |
| return s != kNoStateId && s < distance.size() ? |
| distance[s] : Weight::Zero(); |
| } |
| } |
| |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_SHORTEST_DISTANCE_H__ |