blob: 89b8178aae6794d866a6d18dca8b12f9eb8fca17 [file] [log] [blame]
// rmepsilon.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Functions and classes that implement epsilon-removal.
#ifndef FST_LIB_RMEPSILON_H__
#define FST_LIB_RMEPSILON_H__
#include <tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <fst/slist.h>
#include <stack>
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
#include <fst/arcfilter.h>
#include <fst/cache.h>
#include <fst/connect.h>
#include <fst/factor-weight.h>
#include <fst/invert.h>
#include <fst/prune.h>
#include <fst/queue.h>
#include <fst/shortest-distance.h>
#include <fst/topsort.h>
namespace fst {
// Options for the in-place RmEpsilon() algorithm. Extends the
// shortest-distance options (instantiated with EpsilonArcFilter, so the
// distance computation only follows epsilon arcs) with three settings
// that control post-processing of the result.
template <class Arc, class Queue>
class RmEpsilonOptions
    : public ShortestDistanceOptions<Arc, Queue, EpsilonArcFilter<Arc> > {
 public:
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  // Arguments:
  //   queue:             state queue forwarded to the shortest-distance base.
  //   convergence_delta: delta forwarded to the shortest-distance base.
  //   do_connect:        stored in 'connect'.
  //   weight_limit:      stored in 'weight_threshold'.
  //   state_limit:       stored in 'state_threshold'.
  // The base is always constructed with a fresh EpsilonArcFilter and
  // kNoStateId as its third argument.
  explicit RmEpsilonOptions(Queue *queue,
                            float convergence_delta = kDelta,
                            bool do_connect = true,
                            Weight weight_limit = Weight::Zero(),
                            StateId state_limit = kNoStateId)
      : ShortestDistanceOptions< Arc, Queue, EpsilonArcFilter<Arc> >(
            queue, EpsilonArcFilter<Arc>(), kNoStateId, convergence_delta),
        connect(do_connect),
        weight_threshold(weight_limit),
        state_threshold(state_limit) {}

  bool connect;              // Connect output
  Weight weight_threshold;   // Pruning weight threshold.
  StateId state_threshold;   // Pruning state threshold.

 private:
  RmEpsilonOptions();  // disallow
};
// Computation state of the epsilon-removal algorithm.
//
// An instance is bound to one input FST (held by reference, so the FST
// must outlive this object) and to a caller-supplied distance vector.
// Each call to Expand(s) computes the non-epsilon arcs and the final
// weight that state 's' has once its epsilon-closure is folded in; the
// results are then read back through Arcs() and Final().
template <class Arc, class Queue>
class RmEpsilonState {
 public:
  typedef typename Arc::Label Label;
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  // 'distance' is the vector shared with the shortest-distance
  // computation; 'opts' is forwarded to the shortest-distance state.
  RmEpsilonState(const Fst<Arc> &fst,
                 vector<Weight> *distance,
                 const RmEpsilonOptions<Arc, Queue> &opts)
      : fst_(fst), distance_(distance), sd_state_(fst_, distance, opts, true),
        expand_id_(0) {}

  // Compute arcs and final weight for state 's'
  void Expand(StateId s);

  // Returns arcs of expanded state. The reference is non-const: callers in
  // this file drain the vector after reading it.
  vector<Arc> &Arcs() { return arcs_; }

  // Returns final weight of expanded state.
  const Weight &Final() const { return final_; }

  // Return true if an error has occurred.
  bool Error() const { return sd_state_.Error(); }

 private:
  // Multipliers used to mix the two labels into the Element hash.
  static const size_t kPrime0 = 7853;
  static const size_t kPrime1 = 7867;

  // Identifies an output arc up to its weight: (ilabel, olabel, nextstate).
  struct Element {
    Label ilabel;
    Label olabel;
    StateId nextstate;

    // Leaves all fields uninitialized.
    Element() {}

    Element(Label i, Label o, StateId s)
        : ilabel(i), olabel(o), nextstate(s) {}
  };

  // Hash functor for Element.
  class ElementKey {
   public:
    size_t operator()(const Element& e) const {
      return static_cast<size_t>(e.nextstate +
                                 e.ilabel * kPrime0 +
                                 e.olabel * kPrime1);
    }
  };

  // Equality functor for Element: all three fields must match.
  class ElementEqual {
   public:
    bool operator()(const Element &e1, const Element &e2) const {
      return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel)
          && (e1.nextstate == e2.nextstate);
    }
  };

  typedef unordered_map<Element, pair<StateId, size_t>,
                        ElementKey, ElementEqual> ElementMap;

  const Fst<Arc> &fst_;

  // Distance from state being expanded in epsilon-closure.
  vector<Weight> *distance_;

  // Shortest distance algorithm computation state.
  ShortestDistanceState<Arc, Queue, EpsilonArcFilter<Arc> > sd_state_;

  // Maps an element 'e' to a pair 'p' corresponding to a position
  // in the arcs vector of the state being expanded. 'e' corresponds
  // to the position 'p.second' in the 'arcs_' vector if 'p.first' is
  // equal to the state being expanded.
  ElementMap element_map_;

  EpsilonArcFilter<Arc> eps_filter_;

  // Queue used to visit the epsilon-closure. Spelled std::stack because
  // this header has no using-declaration for 'stack' and should not rely
  // on another header leaking one.
  std::stack<StateId> eps_queue_;

  vector<bool> visited_;           // '[i] = true' if state 'i' has been visited
  slist<StateId> visited_states_;  // List of visited states
  vector<Arc> arcs_;               // Arcs of state being expanded
  Weight final_;                   // Final weight of state being expanded
  StateId expand_id_;              // Unique ID for each call to Expand

  DISALLOW_COPY_AND_ASSIGN(RmEpsilonState);
};
// Namespace-scope definitions of the static constants kPrime0 and kPrime1
// that RmEpsilonState declares in its class body. No initializer is
// repeated here; the values are the ones given in the class body.
template <class Arc, class Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime0;
template <class Arc, class Queue>
const size_t RmEpsilonState<Arc, Queue>::kPrime1;
// Computes the arcs and final weight of 'source' after folding in its
// epsilon-closure. On return, arcs_ holds one arc per distinct
// (ilabel, olabel, nextstate) triple leaving any state of the closure, and
// final_ holds the combined final weight. If the shortest-distance
// computation reports an error, arcs_ is left empty and final_ is
// Weight::Zero().
template <class Arc, class Queue>
void RmEpsilonState<Arc,Queue>::Expand(typename Arc::StateId source) {
  final_ = Weight::Zero();
  arcs_.clear();

  // Fills *distance_ with the distances from 'source' along epsilon paths.
  sd_state_.ShortestDistance(source);
  if (sd_state_.Error())
    return;

  // Depth-first walk over the epsilon-closure of 'source'.
  eps_queue_.push(source);
  while (!eps_queue_.empty()) {
    const StateId current = eps_queue_.top();
    eps_queue_.pop();

    // Grow the visited bitmap on demand; a state may be pushed more than
    // once, so skip it if it has already been handled.
    while (visited_.size() <= current) visited_.push_back(false);
    if (visited_[current]) continue;
    visited_[current] = true;
    visited_states_.push_front(current);

    ArcIterator< Fst<Arc> > ait(fst_, current);
    for (; !ait.Done(); ait.Next()) {
      // Work on a copy whose weight is prefixed by the closure distance.
      Arc arc = ait.Value();
      arc.weight = Times((*distance_)[current], arc.weight);

      if (eps_filter_(arc)) {
        // Epsilon arc: keep exploring the closure.
        while (visited_.size() <= arc.nextstate)
          visited_.push_back(false);
        if (!visited_[arc.nextstate])
          eps_queue_.push(arc.nextstate);
        continue;
      }

      // Non-epsilon arc: merge it with an identical arc already emitted
      // during this expansion, or emit it as a new arc.
      const Element key(arc.ilabel, arc.olabel, arc.nextstate);
      typename ElementMap::iterator pos = element_map_.find(key);
      if (pos == element_map_.end()) {
        element_map_.insert(
            std::make_pair(key,
                           pair<StateId, size_t>(expand_id_, arcs_.size())));
        arcs_.push_back(arc);
        continue;
      }

      pair<StateId, size_t> &slot = pos->second;
      if (slot.first != expand_id_) {
        // Entry left over from an earlier Expand() call: recycle it.
        slot.first = expand_id_;
        slot.second = arcs_.size();
        arcs_.push_back(arc);
      } else {
        Weight &sum = arcs_[slot.second].weight;
        sum = Plus(sum, arc.weight);
      }
    }

    final_ = Plus(final_, Times((*distance_)[current], fst_.Final(current)));
  }

  // Reset only the bits that were set, so visited_ is all-false again.
  for (; !visited_states_.empty(); visited_states_.pop_front())
    visited_[visited_states_.front()] = false;

  ++expand_id_;
}
// Removes epsilon-transitions (when both the input and output label
// are an epsilon) from a transducer. The result will be an equivalent
// FST that has no such epsilon transitions. This version modifies
// its input. It allows fine control via the options argument; see
// below for a simpler interface.
//
// The vector 'distance' will be used to hold the shortest distances
// during the epsilon-closure computation. The state queue discipline
// and convergence delta are taken in the options argument.
template <class Arc, class Queue>
void RmEpsilon(MutableFst<Arc> *fst,
               vector<typename Arc::Weight> *distance,
               const RmEpsilonOptions<Arc, Queue> &opts) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;

  // An FST without a start state is left untouched.
  const StateId start = fst->Start();
  if (start == kNoStateId)
    return;

  // 'noneps_in[s]' will be set to true iff 's' admits a non-epsilon
  // incoming transition or is the start state.
  vector<bool> noneps_in(fst->NumStates(), false);
  noneps_in[start] = true;
  for (StateId s = 0; s < fst->NumStates(); ++s) {
    ArcIterator<Fst<Arc> > aiter(*fst, s);
    for (; !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      if (!(arc.ilabel == 0 && arc.olabel == 0))
        noneps_in[arc.nextstate] = true;
    }
  }

  // States sorted in topological order when (acyclic) or generic
  // topological order (cyclic). The list is consumed from the back below.
  vector<StateId> states;
  states.reserve(fst->NumStates());
  if (fst->Properties(kTopSorted, false) & kTopSorted) {
    // Already top-sorted: state ids are the order.
    for (StateId s = 0; s < fst->NumStates(); ++s)
      states.push_back(s);
  } else if (fst->Properties(kAcyclic, false) & kAcyclic) {
    vector<StateId> order;
    bool acyclic;
    TopOrderVisitor<Arc> top_order_visitor(&order, &acyclic);
    DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter<Arc>());
    // Sanity check: should be acyclic if property bit is set.
    if (!acyclic) {
      FSTERROR() << "RmEpsilon: inconsistent acyclic property bit";
      fst->SetProperties(kError, kError);
      return;
    }
    // Invert the mapping: order[s] is the position of state 's'.
    states.resize(order.size());
    for (StateId s = 0; s < order.size(); ++s)
      states[order[s]] = s;
  } else {
    uint64 props;
    vector<StateId> scc;
    SccVisitor<Arc> scc_visitor(&scc, 0, 0, &props);
    DfsVisit(*fst, &scc_visitor, EpsilonArcFilter<Arc>());
    // Bucket the states by component number: first[c] heads a singly
    // linked list (threaded through 'next') of the states in component c.
    vector<StateId> first(scc.size(), kNoStateId);
    vector<StateId> next(scc.size(), kNoStateId);
    for (StateId s = 0; s < scc.size(); ++s) {
      const StateId component = scc[s];
      next[s] = first[component];  // kNoStateId when the bucket is empty
      first[component] = s;
    }
    for (StateId c = 0; c < first.size(); ++c)
      for (StateId s = first[c]; s != kNoStateId; s = next[s])
        states.push_back(s);
  }

  // Replace the arcs and final weight of every state that can still be
  // entered (start state or non-epsilon incoming arc).
  RmEpsilonState<Arc, Queue> rmeps_state(*fst, distance, opts);
  for (; !states.empty(); states.pop_back()) {
    const StateId state = states.back();
    if (!noneps_in[state])
      continue;
    rmeps_state.Expand(state);
    fst->SetFinal(state, rmeps_state.Final());
    fst->DeleteArcs(state);
    vector<Arc> &arcs = rmeps_state.Arcs();
    fst->ReserveArcs(state, arcs.size());
    // Arcs are moved over back-to-front, draining the expansion buffer.
    for (; !arcs.empty(); arcs.pop_back())
      fst->AddArc(state, arcs.back());
  }

  // States that were not expanded keep no outgoing arcs.
  for (StateId s = 0; s < fst->NumStates(); ++s) {
    if (noneps_in[s])
      continue;
    fst->DeleteArcs(s);
  }

  if (rmeps_state.Error())
    fst->SetProperties(kError, kError);
  fst->SetProperties(
      RmEpsilonProperties(fst->Properties(kFstProperties, false)),
      kFstProperties);

  // Optional post-processing. Connect() is skipped when pruning was
  // requested by weight threshold alone.
  const bool has_state_limit = opts.state_threshold != kNoStateId;
  if (opts.weight_threshold != Weight::Zero() || has_state_limit)
    Prune(fst, opts.weight_threshold, opts.state_threshold);
  if (opts.connect &&
      (opts.weight_threshold == Weight::Zero() || has_state_limit))
    Connect(fst);
}
// Removes epsilon-transitions (when both the input and output label
// are an epsilon) from a transducer. The result will be an equivalent
// FST that has no such epsilon transitions. This version modifies its
// input. It has a simplified interface; see above for a version that
// allows finer control.
//
// Complexity:
// - Time:
// - Unweighted: O(V^2 + V E)
// - Acyclic: O(V^2 + V E)
// - Tropical semiring: O(V^2 log V + V E)
// - General: exponential
// - Space: O(V E)
// where V = # of states visited, E = # of arcs.
//
// References:
// - Mehryar Mohri. Generic Epsilon-Removal and Input
// Epsilon-Normalization Algorithms for Weighted Transducers,
// "International Journal of Computer Science", 13(1):129-143 (2002).
template <class Arc>
void RmEpsilon(MutableFst<Arc> *fst,
               bool connect = true,
               typename Arc::Weight weight_threshold = Arc::Weight::Zero(),
               typename Arc::StateId state_threshold = kNoStateId,
               float delta = kDelta) {
  typedef typename Arc::StateId StateId;
  typedef typename Arc::Weight Weight;
  typedef AutoQueue<StateId> StateQueue;
  typedef RmEpsilonOptions<Arc, StateQueue> Options;

  // The queue is built over the epsilon arcs only and shares the distance
  // vector that the epsilon-removal below fills in.
  vector<Weight> distance;
  StateQueue state_queue(*fst, &distance, EpsilonArcFilter<Arc>());

  // Bundle the arguments and defer to the fully-parameterized overload.
  Options opts(&state_queue, delta, connect, weight_threshold,
               state_threshold);
  RmEpsilon(fst, &distance, opts);
}
// Options for the delayed RmEpsilonFst: the cache options it inherits plus
// the delta that RmEpsilonFstImpl forwards to its RmEpsilonOptions.
struct RmEpsilonFstOptions : CacheOptions {
  float delta;

  // Default cache options with the given delta.
  explicit RmEpsilonFstOptions(float d = kDelta) : delta(d) {}

  // Copies the given cache options and sets the delta.
  RmEpsilonFstOptions(const CacheOptions &opts, float d = kDelta)
      : CacheOptions(opts), delta(d) {}
};
// Implementation of delayed RmEpsilonFst.
//
// Owns a copy of the input FST and expands states on demand: the first
// query about a state runs the epsilon-removal expansion for that state
// and stores its arcs and final weight in the cache provided by CacheImpl.
template <class A>
class RmEpsilonFstImpl : public CacheImpl<A> {
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;

  using CacheBaseImpl< CacheState<A> >::PushArc;
  using CacheBaseImpl< CacheState<A> >::HasArcs;
  using CacheBaseImpl< CacheState<A> >::HasFinal;
  using CacheBaseImpl< CacheState<A> >::HasStart;
  using CacheBaseImpl< CacheState<A> >::SetArcs;
  using CacheBaseImpl< CacheState<A> >::SetFinal;
  using CacheBaseImpl< CacheState<A> >::SetStart;

  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // Wraps a copy of 'fst'. The per-state expansion uses a FIFO queue, the
  // delta from 'opts', and connect = false.
  RmEpsilonFstImpl(const Fst<A>& fst, const RmEpsilonFstOptions &opts)
      : CacheImpl<A>(opts),
        fst_(fst.Copy()),
        delta_(opts.delta),
        rmeps_state_(
            *fst_,
            &distance_,
            RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
    SetType("rmepsilon");
    const uint64 input_props = fst.Properties(kFstProperties, false);
    SetProperties(RmEpsilonProperties(input_props, true), kCopyProperties);
    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copy constructor: copies the wrapped FST with Copy(true) and builds a
  // fresh expansion state, distance vector and queue for the new object.
  RmEpsilonFstImpl(const RmEpsilonFstImpl &impl)
      : CacheImpl<A>(impl),
        fst_(impl.fst_->Copy(true)),
        delta_(impl.delta_),
        rmeps_state_(
            *fst_,
            &distance_,
            RmEpsilonOptions<A, FifoQueue<StateId> >(&queue_, delta_, false)) {
    SetType("rmepsilon");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Releases the FST copy made in the constructors.
  ~RmEpsilonFstImpl() {
    delete fst_;
  }

  // The start state is that of the wrapped FST, cached on first use.
  StateId Start() {
    if (HasStart())
      return CacheImpl<A>::Start();
    SetStart(fst_->Start());
    return CacheImpl<A>::Start();
  }

  // Final weight of 's'; expands 's' first if it is not cached yet.
  Weight Final(StateId s) {
    if (!HasFinal(s))
      Expand(s);
    return CacheImpl<A>::Final(s);
  }

  // The three counters below expand 's' first if its arcs are not cached.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  uint64 Properties() const { return Properties(kFstProperties); }

  // Set error if found; return FST impl properties.
  uint64 Properties(uint64 mask) const {
    if (mask & kError) {
      // The error bit is raised when either the wrapped FST or the
      // expansion state reports an error.
      if (fst_->Properties(kError, false) || rmeps_state_.Error())
        SetProperties(kError, kError);
    }
    return FstImpl<A>::Properties(mask);
  }

  // Expands 's' if needed, then lets the cache set up the iterator data.
  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    if (!HasArcs(s))
      Expand(s);
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Runs the epsilon-removal expansion of 's' and stores the resulting
  // final weight and arcs in the cache. Arcs are pushed back-to-front,
  // draining the expansion buffer.
  void Expand(StateId s) {
    rmeps_state_.Expand(s);
    SetFinal(s, rmeps_state_.Final());
    vector<A> &pending = rmeps_state_.Arcs();
    for (; !pending.empty(); pending.pop_back())
      PushArc(s, pending.back());
    SetArcs(s);
  }

 private:
  const Fst<A> *fst_;  // Copy of the input FST; deleted in the destructor.
  float delta_;        // Delta passed to the expansion options.

  // distance_ and queue_ are declared before rmeps_state_, whose
  // constructor is handed pointers to both.
  vector<Weight> distance_;
  FifoQueue<StateId> queue_;
  RmEpsilonState<A, FifoQueue<StateId> > rmeps_state_;

  void operator=(const RmEpsilonFstImpl<A> &);  // disallow
};
// Removes epsilon-transitions (when both the input and output label
// are an epsilon) from a transducer. The result will be an equivalent
// FST that has no such epsilon transitions. This version is a
// delayed Fst.
//
// Complexity:
// - Time:
// - Unweighted: O(v^2 + v e)
// - General: exponential
// - Space: O(v e)
// where v = # of states visited, e = # of arcs visited. Constant time
// to visit an input state or arc is assumed and exclusive of caching.
//
// References:
// - Mehryar Mohri. Generic Epsilon-Removal and Input
// Epsilon-Normalization Algorithms for Weighted Transducers,
// "International Journal of Computer Science", 13(1):129-143 (2002).
//
// This class attaches interface to implementation and handles
// reference counting, delegating most methods to ImplToFst.
template <class A>
class RmEpsilonFst : public ImplToFst< RmEpsilonFstImpl<A> > {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;
  typedef RmEpsilonFstImpl<A> Impl;

  // The iterator specializations need access to the private GetImpl().
  friend class StateIterator< RmEpsilonFst<A> >;
  friend class ArcIterator< RmEpsilonFst<A> >;

  // Wraps 'fst' using default-constructed RmEpsilonFstOptions.
  RmEpsilonFst(const Fst<A> &fst)
      : ImplToFst<Impl>(new Impl(fst, RmEpsilonFstOptions())) {}

  // Wraps 'fst' using the supplied options.
  RmEpsilonFst(const Fst<A> &fst, const RmEpsilonFstOptions &opts)
      : ImplToFst<Impl>(new Impl(fst, opts)) {}

  // See Fst<>::Copy() for doc.
  RmEpsilonFst(const RmEpsilonFst<A> &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc.
  virtual RmEpsilonFst<A> *Copy(bool safe = false) const {
    return new RmEpsilonFst<A>(*this, safe);
  }

  // Defined out of line, after the StateIterator specialization.
  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  // Forwards to the implementation.
  virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const {
    GetImpl()->InitArcIterator(s, data);
  }

 private:
  // Makes visible to friends.
  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }

  void operator=(const RmEpsilonFst<A> &fst);  // disallow
};
// Specialization for RmEpsilonFst.
// Adds no behavior of its own: it derives from CacheStateIterator and
// only forwards the FST and its implementation to the base constructor.
template<class A>
class StateIterator< RmEpsilonFst<A> >
: public CacheStateIterator< RmEpsilonFst<A> > {
public:
// Uses fst.GetImpl(), which is private in RmEpsilonFst; access is granted
// by the friend declaration in that class.
explicit StateIterator(const RmEpsilonFst<A> &fst)
: CacheStateIterator< RmEpsilonFst<A> >(fst, fst.GetImpl()) {}
};
// Specialization for RmEpsilonFst.
// Iterates over the cached arcs of a state; the constructor makes sure
// the state has been expanded.
template <class A>
class ArcIterator< RmEpsilonFst<A> >
    : public CacheArcIterator< RmEpsilonFst<A> > {
 public:
  typedef typename A::StateId StateId;

  // Positions the iterator on state 's' of 'fst', expanding 's' in the
  // implementation if its arcs are not cached yet.
  ArcIterator(const RmEpsilonFst<A> &fst, StateId s)
      : CacheArcIterator< RmEpsilonFst<A> >(fst.GetImpl(), s) {
    RmEpsilonFstImpl<A> *impl = fst.GetImpl();
    if (impl->HasArcs(s))
      return;
    impl->Expand(s);
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};
// Out-of-line definition of RmEpsilonFst<A>::InitStateIterator: allocates
// a StateIterator specialization for this FST with 'new' and stores the
// pointer in data->base.
// NOTE(review): presumably whoever consumes 'data' takes ownership of
// 'base' and deletes it -- confirm against StateIteratorData.
template <class A> inline
void RmEpsilonFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
data->base = new StateIterator< RmEpsilonFst<A> >(*this);
}
// Useful alias when using StdArc: the delayed epsilon-removal FST
// instantiated for StdArc.
typedef RmEpsilonFst<StdArc> StdRmEpsilonFst;
} // namespace fst
#endif // FST_LIB_RMEPSILON_H__