blob: f4a9c05535ab03df6429fbfc9b22b4c52052176a [file] [log] [blame]
// replace-util.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Utility classes for the recursive replacement of Fsts (RTNs).
#ifndef FST_LIB_REPLACE_UTIL_H__
#define FST_LIB_REPLACE_UTIL_H__
#include <vector>
using std::vector;
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <map>
#include <fst/connect.h>
#include <fst/mutable-fst.h>
#include <fst/topsort.h>
namespace fst {
// Declaration of the Replace() operation; its definition is not in this
// header. ReplaceUtil::ReplaceLabels() calls it as
//   Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_)
// i.e. (label, Fst) pairs to expand, the output MutableFst, the root label,
// and the epsilon_on_replace flag.
template <class Arc>
void Replace(const vector<pair<typename Arc::Label, const Fst<Arc>* > >&,
MutableFst<Arc> *, typename Arc::Label, bool);
// Utility class for the recursive replacement of Fsts (RTNs). The
// user provides a set of Label, Fst pairs at construction. These are
// used by methods for testing cyclic dependencies and connectedness
// and doing RTN connection and specific Fst replacement by label or
// for various optimization properties. The modified results can be
// obtained with the GetFstPairs() or GetMutableFstPairs() methods.
template <class Arc>
class ReplaceUtil {
public:
typedef typename Arc::Label Label;
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
typedef pair<Label, const Fst<Arc>*> FstPair;
typedef pair<Label, MutableFst<Arc>*> MutableFstPair;
typedef unordered_map<Label, Label> NonTerminalHash;
// Constructs from mutable Fsts; Fst ownership given to ReplaceUtil.
ReplaceUtil(const vector<MutableFstPair> &fst_pairs,
Label root_label, bool epsilon_on_replace = false);
// Constructs from Fsts; Fst ownership retained by caller.
ReplaceUtil(const vector<FstPair> &fst_pairs,
Label root_label, bool epsilon_on_replace = false);
// Constructs from ReplaceFst internals; ownership retained by caller.
ReplaceUtil(const vector<const Fst<Arc> *> &fst_array,
const NonTerminalHash &nonterminal_hash, Label root_fst,
bool epsilon_on_replace = false);
~ReplaceUtil() {
for (Label i = 0; i < fst_array_.size(); ++i)
delete fst_array_[i];
}
// True if the non-terminal dependencies are cyclic. Cyclic
// dependencies will result in an unexpandable replace fst.
bool CyclicDependencies() const {
GetDependencies(false);
return depprops_ & kCyclic;
}
// Returns true if no useless Fsts, states or transitions.
bool Connected() const {
GetDependencies(false);
uint64 props = kAccessible | kCoAccessible;
for (Label i = 0; i < fst_array_.size(); ++i) {
if (!fst_array_[i])
continue;
if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i])
return false;
}
return true;
}
// Removes useless Fsts, states and transitions.
void Connect();
// Replaces Fsts specified by labels.
// Does nothing if there are cyclic dependencies.
void ReplaceLabels(const vector<Label> &labels);
// Replaces Fsts that have at most 'nstates' states, 'narcs' arcs and
// 'nnonterm' non-terminals (updating in reverse dependency order).
// Does nothing if there are cyclic dependencies.
void ReplaceBySize(size_t nstates, size_t narcs, size_t nnonterms);
// Replaces singleton Fsts.
// Does nothing if there are cyclic dependencies.
void ReplaceTrivial() { ReplaceBySize(2, 1, 1); }
// Replaces non-terminals that have at most 'ninstances' instances
// (updating in dependency order).
// Does nothing if there are cyclic dependencies.
void ReplaceByInstances(size_t ninstances);
// Replaces non-terminals that have only one instance.
// Does nothing if there are cyclic dependencies.
void ReplaceUnique() { ReplaceByInstances(1); }
// Returns Label, Fst pairs; Fst ownership retained by ReplaceUtil.
void GetFstPairs(vector<FstPair> *fst_pairs);
// Returns Label, MutableFst pairs; Fst ownership given to caller.
void GetMutableFstPairs(vector<MutableFstPair> *mutable_fst_pairs);
private:
// Per Fst statistics
struct ReplaceStats {
StateId nstates; // # of states
StateId nfinal; // # of final states
size_t narcs; // # of arcs
Label nnonterms; // # of non-terminals in Fst
size_t nref; // # of non-terminal instances referring to this Fst
// # of times that ith Fst references this Fst
map<Label, size_t> inref;
// # of times that this Fst references the ith Fst
map<Label, size_t> outref;
ReplaceStats()
: nstates(0),
nfinal(0),
narcs(0),
nnonterms(0),
nref(0) {}
};
// Check Mutable Fsts exist o.w. create them.
void CheckMutableFsts();
// Computes the dependency graph of the replace Fsts.
// If 'stats' is true, dependency statistics computed as well.
void GetDependencies(bool stats) const;
void ClearDependencies() const {
depfst_.DeleteStates();
stats_.clear();
depprops_ = 0;
have_stats_ = false;
}
// Get topological order of dependencies. Returns false with cyclic input.
bool GetTopOrder(const Fst<Arc> &fst, vector<Label> *toporder) const;
// Update statistics assuming that jth Fst will be replaced.
void UpdateStats(Label j);
Label root_label_; // root non-terminal
Label root_fst_; // root Fst ID
bool epsilon_on_replace_; // see Replace()
vector<const Fst<Arc> *> fst_array_; // Fst per ID
vector<MutableFst<Arc> *> mutable_fst_array_; // MutableFst per ID
vector<Label> nonterminal_array_; // Fst ID to non-terminal
NonTerminalHash nonterminal_hash_; // non-terminal to Fst ID
mutable VectorFst<Arc> depfst_; // Fst ID dependencies
mutable vector<bool> depaccess_; // Fst ID accessibility
mutable uint64 depprops_; // dependency Fst props
mutable bool have_stats_; // have dependency statistics
mutable vector<ReplaceStats> stats_; // Per Fst statistics
DISALLOW_COPY_AND_ASSIGN(ReplaceUtil);
};
// Constructs from (non-terminal label, MutableFst) pairs. Ownership of the
// Fsts passes to ReplaceUtil (the destructor deletes them). Fst ID 0 is a
// null placeholder, so real Fsts get IDs 1..n and an ID of 0 means "none".
// If two pairs share a label, the label maps to the later pair's Fst.
template <class Arc>
ReplaceUtil<Arc>::ReplaceUtil(
    const vector<MutableFstPair> &fst_pairs,
    Label root_label, bool epsilon_on_replace)
    : root_label_(root_label),
      root_fst_(0),
      epsilon_on_replace_(epsilon_on_replace),
      depprops_(0),
      have_stats_(false) {
  fst_array_.push_back(0);
  mutable_fst_array_.push_back(0);
  nonterminal_array_.push_back(kNoLabel);
  for (size_t i = 0; i < fst_pairs.size(); ++i) {
    const Label label = fst_pairs[i].first;
    MutableFst<Arc> *fst = fst_pairs[i].second;
    nonterminal_hash_[label] = fst_array_.size();
    nonterminal_array_.push_back(label);
    fst_array_.push_back(fst);
    mutable_fst_array_.push_back(fst);
  }
  // Look the root up with find(). operator[] would insert a bogus
  // root_label_ -> 0 entry when the root is missing, and GetDependencies()
  // would then treat arcs with that label as references to the null slot.
  typename NonTerminalHash::const_iterator root =
      nonterminal_hash_.find(root_label_);
  if (root != nonterminal_hash_.end())
    root_fst_ = root->second;
  if (!root_fst_)
    FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
}
// Constructs from (non-terminal label, Fst) pairs. Each Fst is copied with
// Copy(), so the caller keeps ownership of the originals; the copies are
// deleted by the destructor. Fst ID 0 is a null placeholder, so real Fsts get
// IDs 1..n and an ID of 0 means "none". A null Fst pointer yields a null slot
// (which the dependency code skips) instead of a null dereference.
template <class Arc>
ReplaceUtil<Arc>::ReplaceUtil(
    const vector<FstPair> &fst_pairs,
    Label root_label, bool epsilon_on_replace)
    : root_label_(root_label),
      root_fst_(0),
      epsilon_on_replace_(epsilon_on_replace),
      depprops_(0),
      have_stats_(false) {
  fst_array_.push_back(0);
  nonterminal_array_.push_back(kNoLabel);
  for (size_t i = 0; i < fst_pairs.size(); ++i) {
    const Label label = fst_pairs[i].first;
    const Fst<Arc> *fst = fst_pairs[i].second;
    nonterminal_hash_[label] = fst_array_.size();
    nonterminal_array_.push_back(label);
    fst_array_.push_back(fst ? fst->Copy() : 0);
  }
  // Look the root up with find(). operator[] would insert a bogus
  // root_label_ -> 0 entry when the root is missing, and GetDependencies()
  // would then treat arcs with that label as references to the null slot.
  typename NonTerminalHash::const_iterator root =
      nonterminal_hash_.find(root_label_);
  if (root != nonterminal_hash_.end())
    root_fst_ = root->second;
  if (!root_fst_)
    FSTERROR() << "ReplaceUtil: no root FST for label: " << root_label_;
}
// Constructs from ReplaceFst internals: 'fst_array' is indexed by Fst ID
// (entry 0 is not used), 'nonterminal_hash' maps each non-terminal label to
// its Fst ID, and 'root_fst' is the root's Fst ID. The Fsts are copied, so
// ownership of the originals stays with the caller.
template <class Arc>
ReplaceUtil<Arc>::ReplaceUtil(
    const vector<const Fst<Arc> *> &fst_array,
    const NonTerminalHash &nonterminal_hash, Label root_fst,
    bool epsilon_on_replace)
    : root_label_(kNoLabel),
      root_fst_(root_fst),
      epsilon_on_replace_(epsilon_on_replace),
      nonterminal_hash_(nonterminal_hash),
      depprops_(0),
      have_stats_(false) {
  // ID 0 is always a null placeholder. Null entries stay null rather than
  // being dereferenced.
  fst_array_.push_back(0);
  for (size_t i = 1; i < fst_array.size(); ++i)
    fst_array_.push_back(fst_array[i] ? fst_array[i]->Copy() : 0);

  // One label per Fst ID, sized from fst_array_ so it is never shorter than
  // it (even for an empty input). IDs without a non-terminal get kNoLabel,
  // as slot 0 does in the other constructors.
  nonterminal_array_.assign(fst_array_.size(), kNoLabel);
  for (typename NonTerminalHash::const_iterator it = nonterminal_hash.begin();
       it != nonterminal_hash.end(); ++it) {
    if (it->second < 0 ||
        static_cast<size_t>(it->second) >= nonterminal_array_.size()) {
      FSTERROR() << "ReplaceUtil: bad FST ID for label: " << it->first;
      continue;
    }
    nonterminal_array_[it->second] = it->first;
  }

  // Guard the index instead of reading past the end of nonterminal_array_.
  if (root_fst_ < 0 ||
      static_cast<size_t>(root_fst_) >= nonterminal_array_.size()) {
    FSTERROR() << "ReplaceUtil: bad root FST ID: " << root_fst_;
    return;
  }
  root_label_ = nonterminal_array_[root_fst_];
}
// Builds depfst_, the dependency graph over Fst IDs. It has one state per ID,
// all final, with the start at root_fst_. An arc i -> j is labelled with the
// non-terminal whenever Fst i has an arc whose output label maps to Fst ID j.
// When 'stats' is true, per-Fst statistics are also accumulated into stats_.
// A previously built graph is reused as-is, unless statistics are now
// requested but were not gathered before; then it is rebuilt from scratch.
// After a (re)build, depaccess_ and depprops_ are recomputed by an SCC visit.
template <class Arc>
void ReplaceUtil<Arc>::GetDependencies(bool stats) const {
  if (depfst_.NumStates() > 0) {
    const bool need_rebuild = stats && !have_stats_;
    if (!need_rebuild)
      return;
    ClearDependencies();
  }

  have_stats_ = stats;
  const size_t num_fsts = fst_array_.size();
  if (have_stats_)
    stats_.reserve(num_fsts);

  // One final dependency state per Fst ID (including the null slot 0).
  for (Label id = 0; id < num_fsts; ++id) {
    depfst_.AddState();
    depfst_.SetFinal(id, Weight::One());
    if (have_stats_)
      stats_.push_back(ReplaceStats());
  }
  depfst_.SetStart(root_fst_);

  // Add an arc from each Fst's state to the state of every Fst it references.
  for (Label i = 0; i < num_fsts; ++i) {
    const Fst<Arc> *ifst = fst_array_[i];
    if (ifst == 0)
      continue;
    for (StateIterator<Fst<Arc> > siter(*ifst); !siter.Done(); siter.Next()) {
      const StateId s = siter.Value();
      if (have_stats_) {
        ReplaceStats &st = stats_[i];
        ++st.nstates;
        if (ifst->Final(s) != Weight::Zero())
          ++st.nfinal;
      }
      for (ArcIterator<Fst<Arc> > aiter(*ifst, s);
           !aiter.Done(); aiter.Next()) {
        if (have_stats_)
          ++stats_[i].narcs;
        const Arc &arc = aiter.Value();
        typename NonTerminalHash::const_iterator hit =
            nonterminal_hash_.find(arc.olabel);
        if (hit == nonterminal_hash_.end())
          continue;
        const Label j = hit->second;
        depfst_.AddArc(i, Arc(arc.olabel, arc.olabel, Weight::One(), j));
        if (!have_stats_)
          continue;
        ++stats_[i].nnonterms;
        ++stats_[j].nref;
        ++stats_[j].inref[i];
        ++stats_[i].outref[j];
      }
    }
  }

  // Accessibility of each Fst ID from the root, plus graph properties
  // such as kCyclic.
  SccVisitor<Arc> scc_visitor(0, &depaccess_, 0, &depprops_);
  DfsVisit(depfst_, &scc_visitor);
}
// Updates stats_ as if Fst ID 'j' were replaced into every Fst that
// references it. Each referencing Fst i absorbs ni copies of j's states, arcs
// and non-terminals, where ni is the number of references from i to j. The
// references j makes to other Fsts are re-attributed to i. The root is never
// replaced. Requires statistics gathered by GetDependencies(true).
template <class Arc>
void ReplaceUtil<Arc>::UpdateStats(Label j) {
  if (!have_stats_) {
    FSTERROR() << "ReplaceUtil::UpdateStats: stats not available";
    return;
  }
  if (j == root_fst_)  // can't replace root
    return;

  typedef typename map<Label, size_t>::iterator Iter;
  ReplaceStats &replaced = stats_[j];

  // Fsts that reference j ("callers") absorb j's contents.
  for (Iter caller = replaced.inref.begin();
       caller != replaced.inref.end(); ++caller) {
    ReplaceStats &host = stats_[caller->first];
    const size_t ni = caller->second;
    host.nstates += replaced.nstates * ni;
    // Each replaced arc becomes j's arcs plus two epsilons: narcs - 1 + 2.
    host.narcs += (replaced.narcs + 1) * ni;
    // The instance of j disappears and j's own non-terminals are inlined.
    // NOTE: this product is computed in size_t; when j has no non-terminals
    // it wraps and relies on the conversion back to Label.
    host.nnonterms += (replaced.nnonterms - 1) * ni;
    host.outref.erase(host.outref.find(j));
    for (Iter callee = replaced.outref.begin();
         callee != replaced.outref.end(); ++callee)
      host.outref[callee->first] += ni * callee->second;
  }

  // Fsts that j references ("callees") are now referenced by j's callers.
  for (Iter callee = replaced.outref.begin();
       callee != replaced.outref.end(); ++callee) {
    ReplaceStats &target = stats_[callee->first];
    const size_t nk = callee->second;
    target.nref -= nk;
    target.inref.erase(target.inref.find(j));
    for (Iter caller = replaced.inref.begin();
         caller != replaced.inref.end(); ++caller) {
      target.inref[caller->first] += caller->second * nk;
      target.nref += caller->second * nk;
    }
  }
}
// Ensures every held Fst is available as a MutableFst. If mutable_fst_array_
// is still empty, each non-null Fst is replaced by a VectorFst copy: the old
// Fst is deleted, and fst_array_ and mutable_fst_array_ then point at the
// same copy. Does nothing once mutable_fst_array_ is populated.
template <class Arc>
void ReplaceUtil<Arc>::CheckMutableFsts() {
  if (!mutable_fst_array_.empty())
    return;
  for (size_t id = 0; id < fst_array_.size(); ++id) {
    const Fst<Arc> *fst = fst_array_[id];
    if (fst == 0) {
      mutable_fst_array_.push_back(0);
      continue;
    }
    MutableFst<Arc> *copy = new VectorFst<Arc>(*fst);
    mutable_fst_array_.push_back(copy);
    delete fst;
    fst_array_[id] = copy;
  }
}
// Removes useless Fsts, states and transitions. Each component Fst that is
// not known to be accessible and coaccessible is trimmed with fst::Connect().
// Fsts whose IDs are not accessible in the dependency graph are then deleted
// and their slots set to null.
//
// The dependency graph is rebuilt after trimming. A graph cached by an
// earlier query (e.g. Connected() or CyclicDependencies()) describes the
// untrimmed Fsts. It could still mark an Fst as accessible after its only
// references were trimmed away, leaving that Fst in place.
template <class Arc>
void ReplaceUtil<Arc>::Connect() {
  CheckMutableFsts();
  const uint64 props = kAccessible | kCoAccessible;
  for (size_t i = 0; i < mutable_fst_array_.size(); ++i) {
    MutableFst<Arc> *fst = mutable_fst_array_[i];
    if (!fst)
      continue;
    if (fst->Properties(props, false) != props)
      fst::Connect(fst);
  }
  // Discard any dependency graph computed before trimming.
  ClearDependencies();
  GetDependencies(false);
  for (size_t i = 0; i < mutable_fst_array_.size(); ++i) {
    MutableFst<Arc> *fst = mutable_fst_array_[i];
    if (fst && !depaccess_[i]) {
      delete fst;
      fst_array_[i] = 0;
      mutable_fst_array_[i] = 0;
    }
  }
  ClearDependencies();
}
// Computes a topological order of the dependency graph 'fst'. On success,
// (*toporder)[p] is the state at position p and true is returned. If the
// graph is cyclic, a warning is logged, *toporder is left untouched, and
// false is returned.
template <class Arc>
bool ReplaceUtil<Arc>::GetTopOrder(const Fst<Arc> &fst,
                                   vector<Label> *toporder) const {
  vector<StateId> rank;  // rank[s]: position of state s in the order
  bool acyclic = false;
  TopOrderVisitor<Arc> visitor(&rank, &acyclic);
  DfsVisit(fst, &visitor);
  if (acyclic) {
    // Invert the state -> position map into position -> state.
    toporder->resize(rank.size());
    for (size_t s = 0; s < rank.size(); ++s)
      (*toporder)[rank[s]] = s;
    return true;
  }
  LOG(WARNING) << "ReplaceUtil::GetTopOrder: Cyclical label dependencies";
  return false;
}
// Expands, in place, every non-terminal whose label appears in 'labels' (the
// root label is ignored). The dependency graph is restricted to arcs into the
// requested Fsts and visited in reverse topological order, so an Fst is
// expanded after the Fsts it refers to. Does nothing if that restricted graph
// is cyclic. The cached dependency graph is cleared on exit.
template <class Arc>
void ReplaceUtil<Arc>::ReplaceLabels(const vector<Label> &labels) {
  CheckMutableFsts();
  unordered_set<Label> label_set;
  for (typename vector<Label>::const_iterator it = labels.begin();
       it != labels.end(); ++it) {
    if (*it != root_label_)  // can't replace root
      label_set.insert(*it);
  }

  // Copy of the dependency graph keeping only arcs into requested Fsts.
  GetDependencies(false);
  VectorFst<Arc> pfst(depfst_);
  for (StateId s = 0; s < pfst.NumStates(); ++s) {
    vector<Arc> kept;
    for (ArcIterator< VectorFst<Arc> > aiter(pfst, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (label_set.count(nonterminal_array_[arc.nextstate]) > 0)
        kept.push_back(arc);
    }
    pfst.DeleteArcs(s);
    for (typename vector<Arc>::const_iterator a = kept.begin();
         a != kept.end(); ++a)
      pfst.AddArc(s, *a);
  }

  vector<Label> toporder;
  if (GetTopOrder(pfst, &toporder)) {
    // Reverse topological order: dependencies are expanded first.
    for (size_t pos = toporder.size(); pos > 0; --pos) {
      const StateId s = toporder[pos - 1];
      vector<FstPair> fst_pairs;
      for (ArcIterator< VectorFst<Arc> > aiter(pfst, s);
           !aiter.Done(); aiter.Next()) {
        const StateId dep = aiter.Value().nextstate;
        fst_pairs.push_back(make_pair(nonterminal_array_[dep],
                                      fst_array_[dep]));
      }
      if (fst_pairs.empty())
        continue;
      const Label label = nonterminal_array_[s];
      fst_pairs.push_back(make_pair(label, fst_array_[s]));
      Replace(fst_pairs, mutable_fst_array_[s], label, epsilon_on_replace_);
    }
  }
  ClearDependencies();
}
// Walks the Fsts in reverse topological order of their dependencies and
// selects every Fst within the given state, arc and non-terminal limits.
// After each selection the statistics are updated as if that Fst had been
// expanded. The selected labels are then expanded via ReplaceLabels().
// Does nothing if the dependencies are cyclic.
template <class Arc>
void ReplaceUtil<Arc>::ReplaceBySize(size_t nstates, size_t narcs,
                                     size_t nnonterms) {
  GetDependencies(true);
  vector<Label> toporder;
  if (!GetTopOrder(depfst_, &toporder)) {
    ClearDependencies();
    return;
  }
  vector<Label> labels;
  for (size_t pos = toporder.size(); pos-- > 0; ) {
    const Label j = toporder[pos];
    const ReplaceStats &st = stats_[j];
    const bool small_enough = st.nstates <= nstates &&
                              st.narcs <= narcs &&
                              st.nnonterms <= nnonterms;
    if (!small_enough)
      continue;
    labels.push_back(nonterminal_array_[j]);
    UpdateStats(j);
  }
  ReplaceLabels(labels);
}
// Walks the Fsts in topological order of their dependencies and selects every
// Fst referenced by at most 'ninstances' non-terminal instances. After each
// selection the statistics are updated as if that Fst had been expanded. The
// selected labels are then expanded via ReplaceLabels(). Does nothing if the
// dependencies are cyclic.
template <class Arc>
void ReplaceUtil<Arc>::ReplaceByInstances(size_t ninstances) {
  GetDependencies(true);
  vector<Label> toporder;
  if (!GetTopOrder(depfst_, &toporder)) {
    ClearDependencies();
    return;
  }
  vector<Label> labels;
  for (typename vector<Label>::const_iterator it = toporder.begin();
       it != toporder.end(); ++it) {
    const Label j = *it;
    if (stats_[j].nref > ninstances)
      continue;
    labels.push_back(nonterminal_array_[j]);
    UpdateStats(j);
  }
  ReplaceLabels(labels);
}
// Replaces the contents of *fst_pairs with a (label, Fst) pair for every Fst
// still held. Null slots (the ID 0 placeholder and Fsts removed by Connect())
// are skipped. The Fsts are first converted to mutable form. ReplaceUtil
// keeps ownership of the returned pointers.
template <class Arc>
void ReplaceUtil<Arc>::GetFstPairs(vector<FstPair> *fst_pairs) {
  CheckMutableFsts();
  fst_pairs->clear();
  for (size_t id = 0; id < fst_array_.size(); ++id) {
    const Fst<Arc> *fst = fst_array_[id];
    if (fst == 0)
      continue;
    fst_pairs->push_back(FstPair(nonterminal_array_[id], fst));
  }
}
// Replaces the contents of *mutable_fst_pairs with a (label, copy) pair for
// every Fst still held, skipping null slots. The copies are made with Copy()
// and are owned by the caller; ReplaceUtil keeps its own Fsts.
template <class Arc>
void ReplaceUtil<Arc>::GetMutableFstPairs(
    vector<MutableFstPair> *mutable_fst_pairs) {
  CheckMutableFsts();
  mutable_fst_pairs->clear();
  for (size_t id = 0; id < mutable_fst_array_.size(); ++id) {
    const MutableFst<Arc> *fst = mutable_fst_array_[id];
    if (fst == 0)
      continue;
    mutable_fst_pairs->push_back(
        MutableFstPair(nonterminal_array_[id], fst->Copy()));
  }
}
} // namespace fst
#endif // FST_LIB_REPLACE_UTIL_H__