| // shortest-path.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: allauzen@google.com (Cyril Allauzen) |
| // |
| // \file |
| // Functions to find shortest paths in an FST. |
| |
| #ifndef FST_LIB_SHORTEST_PATH_H__ |
| #define FST_LIB_SHORTEST_PATH_H__ |
| |
#include <algorithm>
#include <functional>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
| |
| #include <fst/cache.h> |
| #include <fst/determinize.h> |
| #include <fst/queue.h> |
| #include <fst/shortest-distance.h> |
| #include <fst/test-properties.h> |
| |
| |
| namespace fst { |
| |
| template <class Arc, class Queue, class ArcFilter> |
| struct ShortestPathOptions |
| : public ShortestDistanceOptions<Arc, Queue, ArcFilter> { |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| size_t nshortest; // return n-shortest paths |
| bool unique; // only return paths with distinct input strings |
| bool has_distance; // distance vector already contains the |
| // shortest distance from the initial state |
| bool first_path; // Single shortest path stops after finding the first |
| // path to a final state. That path is the shortest path |
| // only when using the ShortestFirstQueue and |
| // only when all the weights in the FST are between |
| // One() and Zero() according to NaturalLess. |
| Weight weight_threshold; // pruning weight threshold. |
| StateId state_threshold; // pruning state threshold. |
| |
| ShortestPathOptions(Queue *q, ArcFilter filt, size_t n = 1, bool u = false, |
| bool hasdist = false, float d = kDelta, |
| bool fp = false, Weight w = Weight::Zero(), |
| StateId s = kNoStateId) |
| : ShortestDistanceOptions<Arc, Queue, ArcFilter>(q, filt, kNoStateId, d), |
| nshortest(n), unique(u), has_distance(hasdist), first_path(fp), |
| weight_threshold(w), state_threshold(s) {} |
| }; |
| |
| |
| // Shortest-path algorithm: normally not called directly; prefer |
| // 'ShortestPath' below with n=1. 'ofst' contains the shortest path in |
| // 'ifst'. 'distance' returns the shortest distances from the source |
| // state to each state in 'ifst'. 'opts' is used to specify options |
| // such as the queue discipline, the arc filter and delta. |
| // |
| // The shortest path is the lowest weight path w.r.t. the natural |
| // semiring order. |
| // |
| // The weights need to be right distributive and have the path (kPath) |
| // property. |
| template<class Arc, class Queue, class ArcFilter> |
| void SingleShortestPath(const Fst<Arc> &ifst, |
| MutableFst<Arc> *ofst, |
| vector<typename Arc::Weight> *distance, |
| ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| |
| ofst->DeleteStates(); |
| ofst->SetInputSymbols(ifst.InputSymbols()); |
| ofst->SetOutputSymbols(ifst.OutputSymbols()); |
| |
| if (ifst.Start() == kNoStateId) { |
| if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); |
| return; |
| } |
| |
| vector<bool> enqueued; |
| vector<StateId> parent; |
| vector<Arc> arc_parent; |
| |
| Queue *state_queue = opts.state_queue; |
| StateId source = opts.source == kNoStateId ? ifst.Start() : opts.source; |
| Weight f_distance = Weight::Zero(); |
| StateId f_parent = kNoStateId; |
| |
| distance->clear(); |
| state_queue->Clear(); |
| if (opts.nshortest != 1) { |
| FSTERROR() << "SingleShortestPath: for nshortest > 1, use ShortestPath" |
| << " instead"; |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| if (opts.weight_threshold != Weight::Zero() || |
| opts.state_threshold != kNoStateId) { |
| FSTERROR() << |
| "SingleShortestPath: weight and state thresholds not applicable"; |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| if ((Weight::Properties() & (kPath | kRightSemiring)) |
| != (kPath | kRightSemiring)) { |
| FSTERROR() << "SingleShortestPath: Weight needs to have the path" |
| << " property and be right distributive: " << Weight::Type(); |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| while (distance->size() < source) { |
| distance->push_back(Weight::Zero()); |
| enqueued.push_back(false); |
| parent.push_back(kNoStateId); |
| arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); |
| } |
| distance->push_back(Weight::One()); |
| parent.push_back(kNoStateId); |
| arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), kNoStateId)); |
| state_queue->Enqueue(source); |
| enqueued.push_back(true); |
| |
| while (!state_queue->Empty()) { |
| StateId s = state_queue->Head(); |
| state_queue->Dequeue(); |
| enqueued[s] = false; |
| Weight sd = (*distance)[s]; |
| if (ifst.Final(s) != Weight::Zero()) { |
| Weight w = Times(sd, ifst.Final(s)); |
| if (f_distance != Plus(f_distance, w)) { |
| f_distance = Plus(f_distance, w); |
| f_parent = s; |
| } |
| if (!f_distance.Member()) { |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| if (opts.first_path) |
| break; |
| } |
| for (ArcIterator< Fst<Arc> > aiter(ifst, s); |
| !aiter.Done(); |
| aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| while (distance->size() <= arc.nextstate) { |
| distance->push_back(Weight::Zero()); |
| enqueued.push_back(false); |
| parent.push_back(kNoStateId); |
| arc_parent.push_back(Arc(kNoLabel, kNoLabel, Weight::Zero(), |
| kNoStateId)); |
| } |
| Weight &nd = (*distance)[arc.nextstate]; |
| Weight w = Times(sd, arc.weight); |
| if (nd != Plus(nd, w)) { |
| nd = Plus(nd, w); |
| if (!nd.Member()) { |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| parent[arc.nextstate] = s; |
| arc_parent[arc.nextstate] = arc; |
| if (!enqueued[arc.nextstate]) { |
| state_queue->Enqueue(arc.nextstate); |
| enqueued[arc.nextstate] = true; |
| } else { |
| state_queue->Update(arc.nextstate); |
| } |
| } |
| } |
| } |
| |
| StateId s_p = kNoStateId, d_p = kNoStateId; |
| for (StateId s = f_parent, d = kNoStateId; |
| s != kNoStateId; |
| d = s, s = parent[s]) { |
| d_p = s_p; |
| s_p = ofst->AddState(); |
| if (d == kNoStateId) { |
| ofst->SetFinal(s_p, ifst.Final(f_parent)); |
| } else { |
| arc_parent[d].nextstate = d_p; |
| ofst->AddArc(s_p, arc_parent[d]); |
| } |
| } |
| ofst->SetStart(s_p); |
| if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); |
| ofst->SetProperties( |
| ShortestPathProperties(ofst->Properties(kFstProperties, false)), |
| kFstProperties); |
| } |
| |
| |
| template <class S, class W> |
| class ShortestPathCompare { |
| public: |
| typedef S StateId; |
| typedef W Weight; |
| typedef pair<StateId, Weight> Pair; |
| |
| ShortestPathCompare(const vector<Pair>& pairs, |
| const vector<Weight>& distance, |
| StateId sfinal, float d) |
| : pairs_(pairs), distance_(distance), superfinal_(sfinal), delta_(d) {} |
| |
| bool operator()(const StateId x, const StateId y) const { |
| const Pair &px = pairs_[x]; |
| const Pair &py = pairs_[y]; |
| Weight dx = px.first == superfinal_ ? Weight::One() : |
| px.first < distance_.size() ? distance_[px.first] : Weight::Zero(); |
| Weight dy = py.first == superfinal_ ? Weight::One() : |
| py.first < distance_.size() ? distance_[py.first] : Weight::Zero(); |
| Weight wx = Times(dx, px.second); |
| Weight wy = Times(dy, py.second); |
| // Penalize complete paths to ensure correct results with inexact weights. |
| // This forms a strict weak order so long as ApproxEqual(a, b) => |
| // ApproxEqual(a, c) for all c s.t. less_(a, c) && less_(c, b). |
| if (px.first == superfinal_ && py.first != superfinal_) { |
| return less_(wy, wx) || ApproxEqual(wx, wy, delta_); |
| } else if (py.first == superfinal_ && px.first != superfinal_) { |
| return less_(wy, wx) && !ApproxEqual(wx, wy, delta_); |
| } else { |
| return less_(wy, wx); |
| } |
| } |
| |
| private: |
| const vector<Pair> &pairs_; |
| const vector<Weight> &distance_; |
| StateId superfinal_; |
| float delta_; |
| NaturalLess<Weight> less_; |
| }; |
| |
| |
| // N-Shortest-path algorithm: implements the core n-shortest path |
| // algorithm. The output is built REVERSED. See below for versions with |
| // more options and not reversed. |
| // |
| // 'ofst' contains the REVERSE of 'n'-shortest paths in 'ifst'. |
| // 'distance' must contain the shortest distance from each state to a final |
| // state in 'ifst'. 'delta' is the convergence delta. |
| // |
| // The n-shortest paths are the n-lowest weight paths w.r.t. the |
| // natural semiring order. The single path that can be read from the |
| // ith of at most n transitions leaving the initial state of 'ofst' is |
| // the ith shortest path. Disregarding the initial state and initial |
| // transitions, the n-shortest paths, in fact, form a tree rooted at |
| // the single final state. |
| // |
| // The weights need to be left and right distributive (kSemiring) and |
| // have the path (kPath) property. |
| // |
| // The algorithm is from Mohri and Riley, "An Efficient Algorithm for |
| // the n-best-strings problem", ICSLP 2002. The algorithm relies on |
| // the shortest-distance algorithm. There are some issues with the |
| // pseudo-code as written in the paper (viz., line 11). |
| // |
| // IMPLEMENTATION NOTE: The input fst 'ifst' can be a delayed fst and |
| // and at any state in its expansion the values of distance vector need only |
| // be defined at that time for the states that are known to exist. |
| template<class Arc, class RevArc> |
| void NShortestPath(const Fst<RevArc> &ifst, |
| MutableFst<Arc> *ofst, |
| const vector<typename Arc::Weight> &distance, |
| size_t n, |
| float delta = kDelta, |
| typename Arc::Weight weight_threshold = Arc::Weight::Zero(), |
| typename Arc::StateId state_threshold = kNoStateId) { |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| typedef pair<StateId, Weight> Pair; |
| typedef typename RevArc::Weight RevWeight; |
| |
| if (n <= 0) return; |
| if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { |
| FSTERROR() << "NShortestPath: Weight needs to have the " |
| << "path property and be distributive: " |
| << Weight::Type(); |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| ofst->DeleteStates(); |
| ofst->SetInputSymbols(ifst.InputSymbols()); |
| ofst->SetOutputSymbols(ifst.OutputSymbols()); |
| // Each state in 'ofst' corresponds to a path with weight w from the |
| // initial state of 'ifst' to a state s in 'ifst', that can be |
| // characterized by a pair (s,w). The vector 'pairs' maps each |
| // state in 'ofst' to the corresponding pair maps states in OFST to |
| // the corresponding pair (s,w). |
| vector<Pair> pairs; |
| // The supefinal state is denoted by -1, 'compare' knows that the |
| // distance from 'superfinal' to the final state is 'Weight::One()', |
| // hence 'distance[superfinal]' is not needed. |
| StateId superfinal = -1; |
| ShortestPathCompare<StateId, Weight> |
| compare(pairs, distance, superfinal, delta); |
| vector<StateId> heap; |
| // 'r[s + 1]', 's' state in 'fst', is the number of states in 'ofst' |
| // which corresponding pair contains 's' ,i.e. , it is number of |
| // paths computed so far to 's'. Valid for 's == -1' (superfinal). |
| vector<int> r; |
| NaturalLess<Weight> less; |
| if (ifst.Start() == kNoStateId || |
| distance.size() <= ifst.Start() || |
| distance[ifst.Start()] == Weight::Zero() || |
| less(weight_threshold, Weight::One()) || |
| state_threshold == 0) { |
| if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); |
| return; |
| } |
| ofst->SetStart(ofst->AddState()); |
| StateId final = ofst->AddState(); |
| ofst->SetFinal(final, Weight::One()); |
| while (pairs.size() <= final) |
| pairs.push_back(Pair(kNoStateId, Weight::Zero())); |
| pairs[final] = Pair(ifst.Start(), Weight::One()); |
| heap.push_back(final); |
| Weight limit = Times(distance[ifst.Start()], weight_threshold); |
| |
| while (!heap.empty()) { |
| pop_heap(heap.begin(), heap.end(), compare); |
| StateId state = heap.back(); |
| Pair p = pairs[state]; |
| heap.pop_back(); |
| Weight d = p.first == superfinal ? Weight::One() : |
| p.first < distance.size() ? distance[p.first] : Weight::Zero(); |
| |
| if (less(limit, Times(d, p.second)) || |
| (state_threshold != kNoStateId && |
| ofst->NumStates() >= state_threshold)) |
| continue; |
| |
| while (r.size() <= p.first + 1) r.push_back(0); |
| ++r[p.first + 1]; |
| if (p.first == superfinal) |
| ofst->AddArc(ofst->Start(), Arc(0, 0, Weight::One(), state)); |
| if ((p.first == superfinal) && (r[p.first + 1] == n)) break; |
| if (r[p.first + 1] > n) continue; |
| if (p.first == superfinal) continue; |
| |
| for (ArcIterator< Fst<RevArc> > aiter(ifst, p.first); |
| !aiter.Done(); |
| aiter.Next()) { |
| const RevArc &rarc = aiter.Value(); |
| Arc arc(rarc.ilabel, rarc.olabel, rarc.weight.Reverse(), rarc.nextstate); |
| Weight w = Times(p.second, arc.weight); |
| StateId next = ofst->AddState(); |
| pairs.push_back(Pair(arc.nextstate, w)); |
| arc.nextstate = state; |
| ofst->AddArc(next, arc); |
| heap.push_back(next); |
| push_heap(heap.begin(), heap.end(), compare); |
| } |
| |
| Weight finalw = ifst.Final(p.first).Reverse(); |
| if (finalw != Weight::Zero()) { |
| Weight w = Times(p.second, finalw); |
| StateId next = ofst->AddState(); |
| pairs.push_back(Pair(superfinal, w)); |
| ofst->AddArc(next, Arc(0, 0, finalw, state)); |
| heap.push_back(next); |
| push_heap(heap.begin(), heap.end(), compare); |
| } |
| } |
| Connect(ofst); |
| if (ifst.Properties(kError, false)) ofst->SetProperties(kError, kError); |
| ofst->SetProperties( |
| ShortestPathProperties(ofst->Properties(kFstProperties, false)), |
| kFstProperties); |
| } |
| |
| |
| // N-Shortest-path algorithm: this version allow fine control |
| // via the options argument. See below for a simpler interface. |
| // |
| // 'ofst' contains the n-shortest paths in 'ifst'. 'distance' returns |
| // the shortest distances from the source state to each state in |
| // 'ifst'. 'opts' is used to specify options such as the number of |
| // paths to return, whether they need to have distinct input |
| // strings, the queue discipline, the arc filter and the convergence |
| // delta. |
| // |
| // The n-shortest paths are the n-lowest weight paths w.r.t. the |
| // natural semiring order. The single path that can be read from the |
| // ith of at most n transitions leaving the initial state of 'ofst' is |
| // the ith shortest path. Disregarding the initial state and initial |
| // transitions, The n-shortest paths, in fact, form a tree rooted at |
| // the single final state. |
| |
| // The weights need to be right distributive and have the path (kPath) |
| // property. They need to be left distributive as well for nshortest |
| // > 1. |
| // |
| // The algorithm is from Mohri and Riley, "An Efficient Algorithm for |
| // the n-best-strings problem", ICSLP 2002. The algorithm relies on |
| // the shortest-distance algorithm. There are some issues with the |
| // pseudo-code as written in the paper (viz., line 11). |
| template<class Arc, class Queue, class ArcFilter> |
| void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, |
| vector<typename Arc::Weight> *distance, |
| ShortestPathOptions<Arc, Queue, ArcFilter> &opts) { |
| typedef typename Arc::StateId StateId; |
| typedef typename Arc::Weight Weight; |
| typedef ReverseArc<Arc> ReverseArc; |
| |
| size_t n = opts.nshortest; |
| if (n == 1) { |
| SingleShortestPath(ifst, ofst, distance, opts); |
| return; |
| } |
| if (n <= 0) return; |
| if ((Weight::Properties() & (kPath | kSemiring)) != (kPath | kSemiring)) { |
| FSTERROR() << "ShortestPath: n-shortest: Weight needs to have the " |
| << "path property and be distributive: " |
| << Weight::Type(); |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| if (!opts.has_distance) { |
| ShortestDistance(ifst, distance, opts); |
| if (distance->size() == 1 && !(*distance)[0].Member()) { |
| ofst->SetProperties(kError, kError); |
| return; |
| } |
| } |
| // Algorithm works on the reverse of 'fst' : 'rfst', 'distance' is |
| // the distance to the final state in 'rfst', 'ofst' is built as the |
| // reverse of the tree of n-shortest path in 'rfst'. |
| VectorFst<ReverseArc> rfst; |
| Reverse(ifst, &rfst); |
| Weight d = Weight::Zero(); |
| for (ArcIterator< VectorFst<ReverseArc> > aiter(rfst, 0); |
| !aiter.Done(); aiter.Next()) { |
| const ReverseArc &arc = aiter.Value(); |
| StateId s = arc.nextstate - 1; |
| if (s < distance->size()) |
| d = Plus(d, Times(arc.weight.Reverse(), (*distance)[s])); |
| } |
| distance->insert(distance->begin(), d); |
| |
| if (!opts.unique) { |
| NShortestPath(rfst, ofst, *distance, n, opts.delta, |
| opts.weight_threshold, opts.state_threshold); |
| } else { |
| vector<Weight> ddistance; |
| DeterminizeFstOptions<ReverseArc> dopts(opts.delta); |
| DeterminizeFst<ReverseArc> dfst(rfst, *distance, &ddistance, dopts); |
| NShortestPath(dfst, ofst, ddistance, n, opts.delta, |
| opts.weight_threshold, opts.state_threshold); |
| } |
| distance->erase(distance->begin()); |
| } |
| |
| |
| // Shortest-path algorithm: simplified interface. See above for a |
| // version that allows finer control. |
| // |
| // 'ofst' contains the 'n'-shortest paths in 'ifst'. The queue |
| // discipline is automatically selected. When 'unique' == true, only |
| // paths with distinct input labels are returned. |
| // |
| // The n-shortest paths are the n-lowest weight paths w.r.t. the |
| // natural semiring order. The single path that can be read from the |
| // ith of at most n transitions leaving the initial state of 'ofst' is |
| // the ith best path. |
| // |
| // The weights need to be right distributive and have the path |
| // (kPath) property. |
| template<class Arc> |
| void ShortestPath(const Fst<Arc> &ifst, MutableFst<Arc> *ofst, |
| size_t n = 1, bool unique = false, |
| bool first_path = false, |
| typename Arc::Weight weight_threshold = Arc::Weight::Zero(), |
| typename Arc::StateId state_threshold = kNoStateId) { |
| vector<typename Arc::Weight> distance; |
| AnyArcFilter<Arc> arc_filter; |
| AutoQueue<typename Arc::StateId> state_queue(ifst, &distance, arc_filter); |
| ShortestPathOptions< Arc, AutoQueue<typename Arc::StateId>, |
| AnyArcFilter<Arc> > opts(&state_queue, arc_filter, n, unique, false, |
| kDelta, first_path, weight_threshold, |
| state_threshold); |
| ShortestPath(ifst, ofst, &distance, opts); |
| } |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_SHORTEST_PATH_H__ |