| // replace-util.h |
| |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: riley@google.com (Michael Riley) |
| // |
| |
| // \file |
| // Utility classes for the recursive replacement of Fsts (RTNs). |
| |
| #ifndef FST_LIB_REPLACE_UTIL_H__ |
| #define FST_LIB_REPLACE_UTIL_H__ |
| |
| #include <vector> |
| using std::vector; |
| #include <unordered_map> |
| using std::tr1::unordered_map; |
| using std::tr1::unordered_multimap; |
| #include <unordered_set> |
| using std::tr1::unordered_set; |
| using std::tr1::unordered_multiset; |
| #include <map> |
| |
| #include <fst/connect.h> |
| #include <fst/mutable-fst.h> |
| #include <fst/topsort.h> |
| |
| |
| namespace fst { |
| |
| template <class Arc> |
| void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&, |
| MutableFst<Arc> *, typename Arc::Label, bool); |
| |
| |
| // Utility class for the recursive replacement of Fsts (RTNs). The |
| // user provides a set of Label, Fst pairs at construction. These are |
| // used by methods for testing cyclic dependencies and connectedness |
| // and doing RTN connection and specific Fst replacement by label or |
| // for various optimization properties. The modified results can be |
| // obtained with the GetFstPairs() or GetMutableFstPairs() methods. |
| template <class Arc> |
| class ReplaceUtil { |
| public: |
| typedef typename Arc::Label Label; |
| typedef typename Arc::Weight Weight; |
| typedef typename Arc::StateId StateId; |
| |
| typedef pair<Label, const Fst<Arc>*> FstPair; |
| typedef pair<Label, MutableFst<Arc>*> MutableFstPair; |
| typedef unordered_map<Label, Label> NonTerminalHash; |
| |
| // Constructs from mutable Fsts; Fst ownership given to ReplaceUtil. |
| ReplaceUtil(const vector<MutableFstPair> &fst_pairs, |
| Label root_label, bool epsilon_on_replace = false); |
| |
| // Constructs from Fsts; Fst ownership retained by caller. |
| ReplaceUtil(const vector<FstPair> &fst_pairs, |
| Label root_label, bool epsilon_on_replace = false); |
| |
| // Constructs from ReplaceFst internals; ownership retained by caller. |
| ReplaceUtil(const vector<const Fst<Arc> *> &fst_array, |
| const NonTerminalHash &nonterminal_hash, Label root_fst, |
| bool epsilon_on_replace = false); |
| |
| ~ReplaceUtil() { |
| for (Label i = 0; i < fst_array_.size(); ++i) |
| delete fst_array_[i]; |
| } |
| |
| // True if the non-terminal dependencies are cyclic. Cyclic |
| // dependencies will result in an unexpandable replace fst. |
| bool CyclicDependencies() const { |
| GetDependencies(false); |
| return depprops_ & kCyclic; |
| } |
| |
| // Returns true if no useless Fsts, states or transitions. |
| bool Connected() const { |
| GetDependencies(false); |
| uint64 props = kAccessible | kCoAccessible; |
| for (Label i = 0; i < fst_array_.size(); ++i) { |
| if (!fst_array_[i]) |
| continue; |
| if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) |
| return false; |
| } |
| return true; |
| } |
| |
| // Removes useless Fsts, states and transitions. |
| void Connect(); |
| |
| // Replaces Fsts specified by labels. |
| // Does nothing if there are cyclic dependencies. |
| void ReplaceLabels(const vector<Label> &labels); |
| |
| // Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and |
| // 'nnonterm' non-terminals (updating in reverse dependency order). |
| // Does nothing if there are cyclic dependencies. |
| void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms); |
| |
| // Replaces singleton Fsts. |
| // Does nothing if there are cyclic dependencies. |
| void ReplaceTrivial() { ReplaceBySize(2, 1, 1); } |
| |
| // Replaces non-terminals that have at most 'ninstances' instances |
| // (updating in dependency order). |
| // Does nothing if there are cyclic dependencies. |
| void ReplaceByInstances(size_t ninstances); |
| |
| // Replaces non-terminals that have only one instance. |
| // Does nothing if there are cyclic dependencies. |
| void ReplaceUnique() { ReplaceByInstances(1); } |
| |
| // Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil. |
| void GetFstPairs(vector<FstPair> *fst_pairs); |
| |
| // Returns Label, MutableFst pairs; Fst ownership given to caller. |
| void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs); |
| |
| private: |
| // Per Fst statistics |
| struct ReplaceStats { |
| StateId nstates; // # of states |
| StateId nfinal; // # of final states |
| size_t narcs; // # of arcs |
| Label nnonterms; // # of non-terminals in Fst |
| size_t nref; // # of non-terminal instances referring to this Fst |
| |
| // # of times that ith Fst references this Fst |
| map<Label, size_t> inref; |
| // # of times that this Fst references the ith Fst |
| map<Label, size_t> outref; |
| |
| ReplaceStats() |
| : nstates(0), |
| nfinal(0), |
| narcs(0), |
| nnonterms(0), |
| nref(0) {} |
| }; |
| |
| // Check Mutable Fsts exist o.w. create them. |
| void CheckMutableFsts(); |
| |
| // Computes the dependency graph of the replace Fsts. |
| // If 'stats' is true, dependency statistics computed as well. |
| void GetDependencies(bool stats) const; |
| |
| void ClearDependencies() const { |
| depfst_.DeleteStates(); |
| stats_.clear(); |
| depprops_ = 0; |
| have_stats_ = false; |
| } |
| |
| // Get topological order of dependencies. Returns false with cyclic input. |
| bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const; |
| |
| // Update statistics assuming that jth Fst will be replaced. |
| void UpdateStats(Label j); |
| |
| Label root_label_; // root non-terminal |
| Label root_fst_; // root Fst ID |
| bool epsilon_on_replace_; // see Replace() |
| vector<const Fst<Arc> *> fst_array_; // Fst per ID |
| vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID |
| vector<Label> nonterminal_array_; // Fst ID to non-terminal |
| NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID |
| mutable VectorFst<Arc> depfst_; // Fst ID dependencies |
| mutable vector<bool> depaccess_; // Fst ID accessibility |
| mutable uint64 depprops_; // dependency Fst props |
| mutable bool have_stats_; // have dependency statistics |
| mutable vector<ReplaceStats> stats_; // Per Fst statistics |
| DISALLOW_COPY_AND_ASSIGN(ReplaceUtil); |
| }; |
| |
| template <class Arc> |
| ReplaceUtil<Arc>::ReplaceUtil( |
| const vector<MutableFstPair> &fst_pairs, |
| Label root_label, bool epsilon_on_replace) |
| : root_label_(root_label), |
| epsilon_on_replace_(epsilon_on_replace), |
| depprops_(0), |
| have_stats_(false) { |
| fst_array_.push_back(0); |
| mutable_fst_array_.push_back(0); |
| nonterminal_array_.push_back(kNoLabel); |
| for (Label i = 0; i < fst_pairs.size(); ++i) { |
| Label label = fst_pairs[i].first; |
| MutableFst<Arc> *fst = fst_pairs[i].second; |
| nonterminal_hash_[label] = fst_array_.size(); |
| nonterminal_array_.push_back(label); |
| fst_array_.push_back(fst); |
| mutable_fst_array_.push_back(fst); |
| } |
| root_fst_ = nonterminal_hash_[root_label_]; |
| if (!root_fst_) |
| FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; |
| } |
| |
| template <class Arc> |
| ReplaceUtil<Arc>::ReplaceUtil( |
| const vector<FstPair> &fst_pairs, |
| Label root_label, bool epsilon_on_replace) |
| : root_label_(root_label), |
| epsilon_on_replace_(epsilon_on_replace), |
| depprops_(0), |
| have_stats_(false) { |
| fst_array_.push_back(0); |
| nonterminal_array_.push_back(kNoLabel); |
| for (Label i = 0; i < fst_pairs.size(); ++i) { |
| Label label = fst_pairs[i].first; |
| const Fst<Arc> *fst = fst_pairs[i].second; |
| nonterminal_hash_[label] = fst_array_.size(); |
| nonterminal_array_.push_back(label); |
| fst_array_.push_back(fst->Copy()); |
| } |
| root_fst_ = nonterminal_hash_[root_label]; |
| if (!root_fst_) |
| FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_; |
| } |
| |
| template <class Arc> |
| ReplaceUtil<Arc>::ReplaceUtil( |
| const vector<const Fst<Arc> *> &fst_array, |
| const NonTerminalHash &nonterminal_hash, Label root_fst, |
| bool epsilon_on_replace) |
| : root_fst_(root_fst), |
| epsilon_on_replace_(epsilon_on_replace), |
| nonterminal_array_(fst_array.size()), |
| nonterminal_hash_(nonterminal_hash), |
| depprops_(0), |
| have_stats_(false) { |
| fst_array_.push_back(0); |
| for (Label i = 1; i < fst_array.size(); ++i) |
| fst_array_.push_back(fst_array[i]->Copy()); |
| for (typename NonTerminalHash::const_iterator it = |
| nonterminal_hash.begin(); it != nonterminal_hash.end(); ++it) |
| nonterminal_array_[it->second] = it->first; |
| root_label_ = nonterminal_array_[root_fst_]; |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::GetDependencies(bool stats) const { |
| if (depfst_.NumStates() > 0) { |
| if (stats && !have_stats_) |
| ClearDependencies(); |
| else |
| return; |
| } |
| |
| have_stats_ = stats; |
| if (have_stats_) |
| stats_.reserve(fst_array_.size()); |
| |
| for (Label i = 0; i < fst_array_.size(); ++i) { |
| depfst_.AddState(); |
| depfst_.SetFinal(i, Weight::One()); |
| if (have_stats_) |
| stats_.push_back(ReplaceStats()); |
| } |
| depfst_.SetStart(root_fst_); |
| |
| // An arc from each state (representing the fst) to the |
| // state representing the fst being replaced |
| for (Label i = 0; i < fst_array_.size(); ++i) { |
| const Fst<Arc> *ifst = fst_array_[i]; |
| if (!ifst) |
| continue; |
| for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) { |
| StateId s = siter.Value(); |
| if (have_stats_) { |
| ++stats_[i].nstates; |
| if (ifst->Final(s) != Weight::Zero()) |
| ++stats_[i].nfinal; |
| } |
| for (ArcIterator<Fst<Arc> > aiter(*ifst, s); |
| !aiter.Done(); aiter.Next()) { |
| if (have_stats_) |
| ++stats_[i].narcs; |
| const Arc& arc = aiter.Value(); |
| |
| typename NonTerminalHash::const_iterator it = |
| nonterminal_hash_.find(arc.olabel); |
| if (it != nonterminal_hash_.end()) { |
| Label j = it->second; |
| depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j)); |
| if (have_stats_) { |
| ++stats_[i].nnonterms; |
| ++stats_[j].nref; |
| ++stats_[j].inref[i]; |
| ++stats_[i].outref[j]; |
| } |
| } |
| } |
| } |
| } |
| |
| // Gets accessibility info |
| SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_); |
| DfsVisit(depfst_, &scc_visitor); |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::UpdateStats(Label j) { |
| if (!have_stats_) { |
| FSTERROR() << "ReplaceUtil::UpdateStats: stats not available"; |
| return; |
| } |
| |
| if (j == root_fst_) // can't replace root |
| return; |
| |
| typedef typename map<Label, size_t>::iterator Iter; |
| for (Iter in = stats_[j].inref.begin(); |
| in != stats_[j].inref.end(); |
| ++in) { |
| Label i = in->first; |
| size_t ni = in->second; |
| stats_[i].nstates += stats_[j].nstates * ni; |
| stats_[i].narcs += (stats_[j].narcs + 1) * ni; // narcs - 1 + 2 (eps) |
| stats_[i].nnonterms += (stats_[j].nnonterms - 1) * ni; |
| stats_[i].outref.erase(stats_[i].outref.find(j)); |
| for (Iter out = stats_[j].outref.begin(); |
| out != stats_[j].outref.end(); |
| ++out) { |
| Label k = out->first; |
| size_t nk = out->second; |
| stats_[i].outref[k] += ni * nk; |
| } |
| } |
| |
| for (Iter out = stats_[j].outref.begin(); |
| out != stats_[j].outref.end(); |
| ++out) { |
| Label k = out->first; |
| size_t nk = out->second; |
| stats_[k].nref -= nk; |
| stats_[k].inref.erase(stats_[k].inref.find(j)); |
| for (Iter in = stats_[j].inref.begin(); |
| in != stats_[j].inref.end(); |
| ++in) { |
| Label i = in->first; |
| size_t ni = in->second; |
| stats_[k].inref[i] += ni * nk; |
| stats_[k].nref += ni * nk; |
| } |
| } |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::CheckMutableFsts() { |
| if (mutable_fst_array_.size() == 0) { |
| for (Label i = 0; i < fst_array_.size(); ++i) { |
| if (!fst_array_[i]) { |
| mutable_fst_array_.push_back(0); |
| } else { |
| mutable_fst_array_.push_back(new VectorFst<Arc>(*fst_array_[i])); |
| delete fst_array_[i]; |
| fst_array_[i] = mutable_fst_array_[i]; |
| } |
| } |
| } |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::Connect() { |
| CheckMutableFsts(); |
| uint64 props = kAccessible | kCoAccessible; |
| for (Label i = 0; i < mutable_fst_array_.size(); ++i) { |
| if (!mutable_fst_array_[i]) |
| continue; |
| if (mutable_fst_array_[i]->Properties(props, false) != props) |
| fst::Connect(mutable_fst_array_[i]); |
| } |
| GetDependencies(false); |
| for (Label i = 0; i < mutable_fst_array_.size(); ++i) { |
| MutableFst<Arc> *fst = mutable_fst_array_[i]; |
| if (fst && !depaccess_[i]) { |
| delete fst; |
| fst_array_[i] = 0; |
| mutable_fst_array_[i] = 0; |
| } |
| } |
| ClearDependencies(); |
| } |
| |
| template <class Arc> |
| bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst, |
| vector<Label> *toporder) const { |
| // Finds topological order of dependencies. |
| vector<StateId> order; |
| bool acyclic = false; |
| |
| TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic); |
| DfsVisit(fst, &top_order_visitor); |
| if (!acyclic) { |
| LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies"; |
| return false; |
| } |
| |
| toporder->resize(order.size()); |
| for (Label i = 0; i < order.size(); ++i) |
| (*toporder)[order[i]] = i; |
| |
| return true; |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) { |
| CheckMutableFsts(); |
| unordered_set<Label> label_set; |
| for (Label i = 0; i < labels.size(); ++i) |
| if (labels[i] != root_label_) // can't replace root |
| label_set.insert(labels[i]); |
| |
| // Finds Fst dependencies restricted to the labels requested. |
| GetDependencies(false); |
| VectorFst<Arc> pfst(depfst_); |
| for (StateId i = 0; i < pfst.NumStates(); ++i) { |
| vector<Arc> arcs; |
| for (ArcIterator< VectorFst<Arc> > aiter(pfst, i); |
| !aiter.Done(); aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| Label label = nonterminal_array_[arc.nextstate]; |
| if (label_set.count(label) > 0) |
| arcs.push_back(arc); |
| } |
| pfst.DeleteArcs(i); |
| for (size_t j = 0; j < arcs.size(); ++j) |
| pfst.AddArc(i, arcs[j]); |
| } |
| |
| vector<Label> toporder; |
| if (!GetTopOrder(pfst, &toporder)) { |
| ClearDependencies(); |
| return; |
| } |
| |
| // Visits Fsts in reverse topological order of dependencies and |
| // performs replacements. |
| for (Label o = toporder.size() - 1; o >= 0; --o) { |
| vector<FstPair> fst_pairs; |
| StateId s = toporder[o]; |
| for (ArcIterator< VectorFst<Arc> > aiter(pfst, s); |
| !aiter.Done(); aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| Label label = nonterminal_array_[arc.nextstate]; |
| const Fst<Arc> *fst = fst_array_[arc.nextstate]; |
| fst_pairs.push_back(make_pair(label, fst)); |
| } |
| if (fst_pairs.empty()) |
| continue; |
| Label label = nonterminal_array_[s]; |
| const Fst<Arc> *fst = fst_array_[s]; |
| fst_pairs.push_back(make_pair(label, fst)); |
| |
| Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_); |
| } |
| ClearDependencies(); |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs, |
| size_t nnonterms) { |
| vector<Label> labels; |
| GetDependencies(true); |
| |
| vector<Label> toporder; |
| if (!GetTopOrder(depfst_, &toporder)) { |
| ClearDependencies(); |
| return; |
| } |
| |
| for (Label o = toporder.size() - 1; o >= 0; --o) { |
| Label j = toporder[o]; |
| if (stats_[j].nstates <= nstates && |
| stats_[j].narcs <= narcs && |
| stats_[j].nnonterms <= nnonterms) { |
| labels.push_back(nonterminal_array_[j]); |
| UpdateStats(j); |
| } |
| } |
| ReplaceLabels(labels); |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) { |
| vector<Label> labels; |
| GetDependencies(true); |
| |
| vector<Label> toporder; |
| if (!GetTopOrder(depfst_, &toporder)) { |
| ClearDependencies(); |
| return; |
| } |
| for (Label o = 0; o < toporder.size(); ++o) { |
| Label j = toporder[o]; |
| if (stats_[j].nref <= ninstances) { |
| labels.push_back(nonterminal_array_[j]); |
| UpdateStats(j); |
| } |
| } |
| ReplaceLabels(labels); |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) { |
| CheckMutableFsts(); |
| fst_pairs->clear(); |
| for (Label i = 0; i < fst_array_.size(); ++i) { |
| Label label = nonterminal_array_[i]; |
| const Fst<Arc> *fst = fst_array_[i]; |
| if (!fst) |
| continue; |
| fst_pairs->push_back(make_pair(label, fst)); |
| } |
| } |
| |
| template <class Arc> |
| void ReplaceUtil<Arc>::GetMutableFstPairs( |
| vector<MutableFstPair> *mutable_fst_pairs) { |
| CheckMutableFsts(); |
| mutable_fst_pairs->clear(); |
| for (Label i = 0; i < mutable_fst_array_.size(); ++i) { |
| Label label = nonterminal_array_[i]; |
| MutableFst<Arc> *fst = mutable_fst_array_[i]; |
| if (!fst) |
| continue; |
| mutable_fst_pairs->push_back(make_pair(label, fst->Copy())); |
| } |
| } |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_REPLACE_UTIL_H__ |