blob: d08c0ea3b341d2e5931a7456c52bfb4cfef02ea3 [file] [log] [blame]
// replace.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file
// Functions and classes for the recursive replacement of Fsts.
//
#ifndef FST_LIB_REPLACE_H__
#define FST_LIB_REPLACE_H__
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <set>
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
#include <fst/cache.h>
#include <fst/expanded-fst.h>
#include <fst/fst.h>
#include <fst/matcher.h>
#include <fst/replace-util.h>
#include <fst/state-table.h>
#include <fst/test-properties.h>
namespace fst {
//
// REPLACE STATE TUPLES AND TABLES
//
// The replace state table has the form
//
// template <class A, class P>
// class ReplaceStateTable {
// public:
// typedef A Arc;
// typedef P PrefixId;
// typedef typename A::StateId StateId;
// typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
// typedef typename A::Label Label;
//
// // Required constructor
// ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
// Label root);
//
// // Required copy constructor that does not copy state
// ReplaceStateTable(const ReplaceStateTable<A,P> &table);
//
// // Lookup state ID by tuple. If it doesn't exist, then add it.
// StateId FindState(const StateTuple &tuple);
//
// // Lookup state tuple by ID.
// const StateTuple &Tuple(StateId id) const;
// };
// \struct ReplaceStateTuple
// \brief Tuple of information that uniquely defines a state in replace
template <class S, class P>
struct ReplaceStateTuple {
  typedef S StateId;
  typedef P PrefixId;

  // Builds an "unset" tuple: prefix id -1, no component FST, no state.
  ReplaceStateTuple()
      : prefix_id(-1),
        fst_id(kNoStateId),
        fst_state(kNoStateId) {}

  // Builds the tuple (stack prefix `pid`, component FST `fid`, state `sid`
  // inside that component FST).
  ReplaceStateTuple(PrefixId pid, StateId fid, StateId sid)
      : prefix_id(pid),
        fst_id(fid),
        fst_state(sid) {}

  PrefixId prefix_id;  // Index into the stack-prefix table.
  StateId fst_id;      // Component FST currently being walked.
  StateId fst_state;   // State inside that component FST; this is NOT the
                       // state id of the combined replace FST.
};
// Equality of replace state tuples.
template <class S, class P>
inline bool operator==(const ReplaceStateTuple<S, P> &x,
                       const ReplaceStateTuple<S, P> &y) {
  // Two tuples are equal iff all three fields match.
  if (x.prefix_id != y.prefix_id) return false;
  if (x.fst_id != y.fst_id) return false;
  return x.fst_state == y.fst_state;
}
// \class ReplaceRootSelector
// Functor returning true for tuples corresponding to states in the root FST
template <class S, class P>
class ReplaceRootSelector {
 public:
  // Tuples whose prefix id is 0 are treated as states of the root FST.
  bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
    const bool has_root_prefix = (tuple.prefix_id == 0);
    return has_root_prefix;
  }
};
// \class ReplaceFingerprint
// Fingerprint for general replace state tuples.
template <class S, class P>
class ReplaceFingerprint {
 public:
  // `size_array` is not owned and must outlive this functor. As built by
  // VectorHashReplaceStateTable, entry k is the cumulative state count of
  // component FSTs 1..k (the root contributes 0), and back() is the total.
  ReplaceFingerprint(const vector<uint64> *size_array)
      : cumulative_size_array_(size_array) {}

  // Maps a tuple to prefix_id * total + offset_of(fst_id) + fst_state.
  uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
    const vector<uint64> &sizes = *cumulative_size_array_;
    const uint64 prefix_offset = tuple.prefix_id * (sizes.back());
    const uint64 fst_offset = sizes.at(tuple.fst_id - 1);
    return prefix_offset + fst_offset + tuple.fst_state;
  }

 private:
  const vector<uint64> *cumulative_size_array_;
};
// \class ReplaceFstStateFingerprint
// Useful when the fst_state uniquely define the tuple.
template <class S, class P>
class ReplaceFstStateFingerprint {
 public:
  // Uses the component-FST state alone as the fingerprint. This is only
  // meaningful when fst_state by itself identifies the tuple.
  uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
    const uint64 fingerprint = tuple.fst_state;
    return fingerprint;
  }
};
// \class ReplaceHash
// A generic hash function for replace state tuples.
template <typename S, typename P>
class ReplaceHash {
 public:
  // Combines the fields as prefix_id + fst_id * kPrime0 + fst_state * kPrime1
  // (unsigned, wrap-around arithmetic).
  size_t operator()(const ReplaceStateTuple<S, P>& t) const {
    size_t mixed = t.fst_state * kPrime1;
    mixed += t.fst_id * kPrime0;
    return t.prefix_id + mixed;
  }

 private:
  static const size_t kPrime0;
  static const size_t kPrime1;
};

template <typename S, typename P>
const size_t ReplaceHash<S, P>::kPrime0 = 7853;

template <typename S, typename P>
const size_t ReplaceHash<S, P>::kPrime1 = 7867;
template <class A, class T> class ReplaceFstMatcher;
// \class VectorHashReplaceStateTable
// A two-level state table for replace.
// Warning: calls CountStates to compute the number of states of each
// component Fst.
template <class A, class P = ssize_t>
class VectorHashReplaceStateTable {
public:
typedef A Arc;
typedef typename A::StateId StateId;
typedef typename A::Label Label;
typedef P PrefixId;
typedef ReplaceStateTuple<StateId, P> StateTuple;
typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
ReplaceRootSelector<StateId, P>,
ReplaceFstStateFingerprint<StateId, P>,
ReplaceFingerprint<StateId, P> > StateTable;
VectorHashReplaceStateTable(
const vector<pair<Label, const Fst<A>*> > &fst_tuples,
Label root) : root_size_(0) {
cumulative_size_array_.push_back(0);
for (size_t i = 0; i < fst_tuples.size(); ++i) {
if (fst_tuples[i].first == root) {
root_size_ = CountStates(*(fst_tuples[i].second));
cumulative_size_array_.push_back(cumulative_size_array_.back());
} else {
cumulative_size_array_.push_back(cumulative_size_array_.back() +
CountStates(*(fst_tuples[i].second)));
}
}
state_table_ = new StateTable(
new ReplaceRootSelector<StateId, P>,
new ReplaceFstStateFingerprint<StateId, P>,
new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
root_size_,
root_size_ + cumulative_size_array_.back());
}
VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
: root_size_(table.root_size_),
cumulative_size_array_(table.cumulative_size_array_) {
state_table_ = new StateTable(
new ReplaceRootSelector<StateId, P>,
new ReplaceFstStateFingerprint<StateId, P>,
new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
root_size_,
root_size_ + cumulative_size_array_.back());
}
~VectorHashReplaceStateTable() {
delete state_table_;
}
StateId FindState(const StateTuple &tuple) {
return state_table_->FindState(tuple);
}
const StateTuple &Tuple(StateId id) const {
return state_table_->Tuple(id);
}
private:
StateId root_size_;
vector<uint64> cumulative_size_array_;
StateTable *state_table_;
};
// \class DefaultReplaceStateTable
// Default replace state table
template <class A, class P = ssize_t>
class DefaultReplaceStateTable : public CompactHashStateTable<
    ReplaceStateTuple<typename A::StateId, P>,
    ReplaceHash<typename A::StateId, P> > {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef P PrefixId;
  typedef ReplaceStateTuple<StateId, P> StateTuple;
  typedef CompactHashStateTable<StateTuple,
                                ReplaceHash<StateId, PrefixId> > StateTable;

  using StateTable::FindState;
  using StateTable::Tuple;

  // The hash-based table needs neither the component FSTs nor the root
  // label; this signature exists to satisfy the replace state table
  // interface.
  DefaultReplaceStateTable(
      const vector<pair<Label, const Fst<A>*> > & /*fst_tuples*/,
      Label /*root*/) {}

  // "Copying" yields an empty table: the base is default-constructed, so no
  // stored state is carried over from the source.
  DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> & /*table*/)
      : StateTable() {}
};
//
// REPLACE FST CLASS
//
// By default ReplaceFst will copy the input label of the 'replace arc'.
// For acceptors we do not want this behaviour. Instead we need to
// create an epsilon arc when recursing into the appropriate Fst.
// The 'epsilon_on_replace' option can be used to toggle this behaviour.
template <class A, class T = DefaultReplaceStateTable<A> >
struct ReplaceFstOptions : CacheOptions {
  int64 root;               // Root rule (non-terminal label) for expansion.
  bool epsilon_on_replace;  // When true, arcs entering a sub-FST get an
                            // epsilon input label instead of the original.
  bool take_ownership;      // Take ownership of the input Fst(s).
  T* state_table;           // Optional state table; 0 means ReplaceFstImpl
                            // builds its own. A non-null table is deleted
                            // by ReplaceFstImpl.

  // All constructors default to: no state table, no ownership transfer and
  // epsilon_on_replace off, unless an argument says otherwise.
  ReplaceFstOptions(const CacheOptions &opts, int64 r)
      : CacheOptions(opts), root(r), epsilon_on_replace(false),
        take_ownership(false), state_table(0) {}

  explicit ReplaceFstOptions(int64 r)
      : root(r), epsilon_on_replace(false),
        take_ownership(false), state_table(0) {}

  ReplaceFstOptions(int64 r, bool epsilon_replace_arc)
      : root(r), epsilon_on_replace(epsilon_replace_arc),
        take_ownership(false), state_table(0) {}

  ReplaceFstOptions()
      : root(kNoLabel), epsilon_on_replace(false),
        take_ownership(false), state_table(0) {}
};
// \class ReplaceFstImpl
// \brief Implementation class for replace class Fst
//
// The replace implementation class supports a dynamic
// expansion of a recursive transition network represented as Fst
// with dynamically replaceable arcs.
//
template <class A, class T>
class ReplaceFstImpl : public CacheImpl<A> {
  friend class ReplaceFstMatcher<A, T>;

 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::WriteHeader;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;
  using FstImpl<A>::InputSymbols;
  using FstImpl<A>::OutputSymbols;

  using CacheImpl<A>::PushArc;
  using CacheImpl<A>::HasArcs;
  using CacheImpl<A>::HasFinal;
  using CacheImpl<A>::HasStart;
  using CacheImpl<A>::SetArcs;
  using CacheImpl<A>::SetFinal;
  using CacheImpl<A>::SetStart;

  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;
  typedef A Arc;
  typedef unordered_map<Label, Label> NonTerminalHash;

  typedef T StateTable;
  typedef typename T::PrefixId PrefixId;
  typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;

  // Constructor for the replace class implementation.
  // \param fst_tuples array of (non-terminal label, Fst) pairs, one per
  //   non-terminal. Pair i is stored at fst_array_[i + 1]; index 0 is a null
  //   placeholder, so fst ids start at 1.
  // \param opts replace options. A non-null opts.state_table is adopted and
  //   deleted by the destructor; otherwise a StateTable is built from
  //   (fst_tuples, opts.root). With opts.take_ownership the input Fsts are
  //   stored directly (and deleted later), otherwise they are copied.
  ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
                 const ReplaceFstOptions<A, T> &opts)
      : CacheImpl<A>(opts),
        epsilon_on_replace_(opts.epsilon_on_replace),
        state_table_(opts.state_table ? opts.state_table :
                     new StateTable(fst_tuples, opts.root)) {
    SetType("replace");

    if (fst_tuples.size() > 0) {
      SetInputSymbols(fst_tuples[0].second->InputSymbols());
      SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
    }

    bool all_negative = true;  // all nonterminals are negative?
    bool dense_range = true;   // all nonterminals are positive
                               // and form a dense range containing 1?
    for (size_t i = 0; i < fst_tuples.size(); ++i) {
      Label nonterminal = fst_tuples[i].first;
      if (nonterminal >= 0)
        all_negative = false;
      // The sign is tested first so the unsigned size comparison only ever
      // sees positive labels.
      if (nonterminal <= 0 ||
          static_cast<size_t>(nonterminal) > fst_tuples.size())
        dense_range = false;
    }

    vector<uint64> inprops;
    bool all_ilabel_sorted = true;
    bool all_olabel_sorted = true;
    bool all_non_empty = true;
    fst_array_.push_back(0);  // Placeholder: fst ids start at 1.
    for (size_t i = 0; i < fst_tuples.size(); ++i) {
      Label label = fst_tuples[i].first;
      const Fst<A> *fst = fst_tuples[i].second;
      nonterminal_hash_[label] = fst_array_.size();
      nonterminal_set_.insert(label);
      fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
      if (fst->Start() == kNoStateId)
        all_non_empty = false;
      if (!fst->Properties(kILabelSorted, false))
        all_ilabel_sorted = false;
      if (!fst->Properties(kOLabelSorted, false))
        all_olabel_sorted = false;
      inprops.push_back(fst->Properties(kCopyProperties, false));
      if (i) {
        if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
          FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
                     << " does not match input symbols of base Fst (0'th fst)";
          SetProperties(kError, kError);
        }
        if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
          FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
                     << " does not match output symbols of base Fst "
                     << "(0'th fst)";
          SetProperties(kError, kError);
        }
      }
    }

    // Look the root up with find(): operator[] would insert a bogus
    // root -> 0 entry when the root is missing, which would make
    // IsNonTerminal() accept the root label and make ComputeArc() recurse
    // into the null placeholder fst_array_[0].
    typename NonTerminalHash::const_iterator root_it =
        nonterminal_hash_.find(opts.root);
    Label nonterminal =
        (root_it != nonterminal_hash_.end()) ? root_it->second : 0;
    if ((nonterminal == 0) && (fst_array_.size() > 1)) {
      FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
                 << opts.root << "' in the input tuple vector";
      SetProperties(kError, kError);
    }
    root_ = (nonterminal > 0) ? nonterminal : 1;

    SetProperties(ReplaceProperties(inprops, root_ - 1, epsilon_on_replace_,
                                    all_non_empty));
    // We assume that all terminals are positive. The resulting
    // ReplaceFst is known to be kILabelSorted when all sub-FSTs are
    // kILabelSorted and one of the 3 following conditions is satisfied:
    //  1. 'epsilon_on_replace' is false, or
    //  2. all non-terminals are negative, or
    //  3. all non-terminals are positive and form a dense range containing 1.
    if (all_ilabel_sorted &&
        (!epsilon_on_replace_ || all_negative || dense_range))
      SetProperties(kILabelSorted, kILabelSorted);
    // Similarly, the resulting ReplaceFst is known to be
    // kOLabelSorted when all sub-FSTs are kOLabelSorted and one of
    // the 2 following conditions is satisfied:
    //  1. all non-terminals are negative, or
    //  2. all non-terminals are positive and form a dense range containing 1.
    if (all_olabel_sorted && (all_negative || dense_range))
      SetProperties(kOLabelSorted, kOLabelSorted);

    // Enable optional caching as long as sorted and all non empty.
    if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
      always_cache_ = false;
    else
      always_cache_ = true;
    VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
            << (always_cache_ ? "true" : "false");
  }

  // Copy constructor. Component Fsts are copied with Copy(true) and the
  // state table via its copy constructor; the stack-prefix tables start
  // empty.
  ReplaceFstImpl(const ReplaceFstImpl& impl)
      : CacheImpl<A>(impl),
        epsilon_on_replace_(impl.epsilon_on_replace_),
        always_cache_(impl.always_cache_),
        state_table_(new StateTable(*(impl.state_table_))),
        nonterminal_set_(impl.nonterminal_set_),
        nonterminal_hash_(impl.nonterminal_hash_),
        root_(impl.root_) {
    SetType("replace");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
    fst_array_.reserve(impl.fst_array_.size());
    fst_array_.push_back(0);
    for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
      fst_array_.push_back(impl.fst_array_[i]->Copy(true));
    }
  }

  // Deletes the owned state table and every stored component Fst.
  ~ReplaceFstImpl() {
    VLOG(2) << "~ReplaceFstImpl: gc = "
            << (CacheImpl<A>::GetCacheGc() ? "true" : "false")
            << ", gc_size = " << CacheImpl<A>::GetCacheSize()
            << ", gc_limit = " << CacheImpl<A>::GetCacheLimit();

    delete state_table_;
    for (size_t i = 1; i < fst_array_.size(); ++i) {
      delete fst_array_[i];
    }
  }

  // Computes the dependency graph of the replace class and returns
  // true if the dependencies are cyclic. Cyclic dependencies will result
  // in an un-expandable replace fst.
  bool CyclicDependencies() const {
    ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_, root_);
    return replace_util.CyclicDependencies();
  }

  // Returns (computing and caching on first call) the start state. Returns
  // kNoStateId when no Fsts were given or the root Fst is empty.
  StateId Start() {
    if (!HasStart()) {
      if (fst_array_.size() == 1) {      // no fsts defined for replace
        SetStart(kNoStateId);
        return kNoStateId;
      } else {
        const Fst<A>* fst = fst_array_[root_];
        StateId fst_start = fst->Start();
        if (fst_start == kNoStateId)  // root Fst is empty
          return kNoStateId;

        PrefixId prefix = GetPrefixId(StackPrefix());
        StateId start = state_table_->FindState(
            StateTuple(prefix, root_, fst_start));
        SetStart(start);
        return start;
      }
    } else {
      return CacheImpl<A>::Start();
    }
  }

  // Returns the final weight of state s. A state is final only if its
  // component state is final and the call stack is empty (top level).
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      const StateTuple& tuple = state_table_->Tuple(s);
      const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
      const Fst<A>* fst = fst_array_[tuple.fst_id];
      StateId fst_state = tuple.fst_state;

      if (fst->Final(fst_state) != Weight::Zero() && stack.Depth() == 0)
        SetFinal(s, fst->Final(fst_state));
      else
        SetFinal(s, Weight::Zero());
    }
    return CacheImpl<A>::Final(s);
  }

  // Returns the number of arcs leaving s, expanding s only if it must.
  size_t NumArcs(StateId s) {
    if (HasArcs(s)) {  // If state cached, use the cached value.
      return CacheImpl<A>::NumArcs(s);
    } else if (always_cache_) {  // If always caching, expand and cache state.
      Expand(s);
      return CacheImpl<A>::NumArcs(s);
    } else {  // Otherwise compute the number of arcs without expanding.
      StateTuple tuple = state_table_->Tuple(s);
      if (tuple.fst_state == kNoStateId)
        return 0;

      const Fst<A>* fst = fst_array_[tuple.fst_id];
      size_t num_arcs = fst->NumArcs(tuple.fst_state);
      if (ComputeFinalArc(tuple, 0))
        num_arcs++;

      return num_arcs;
    }
  }

  // Returns whether a given label is a non terminal
  bool IsNonTerminal(Label l) const {
    // TODO(allauzen): be smarter and take advantage of
    // all_dense or all_negative.
    // Use also in ComputeArc, this would require changes to replace
    // so that recursing into an empty fst lead to a non co-accessible
    // state instead of deleting the arc as done currently.
    // Current use correct, since i/olabel sorted iff all_non_empty.
    typename NonTerminalHash::const_iterator it =
        nonterminal_hash_.find(l);
    return it != nonterminal_hash_.end();
  }

  // Returns the number of input-epsilon arcs leaving s.
  size_t NumInputEpsilons(StateId s) {
    if (HasArcs(s)) {
      // If state cached, use the cached value.
      return CacheImpl<A>::NumInputEpsilons(s);
    } else if (always_cache_ || !Properties(kILabelSorted)) {
      // If always caching or if the number of input epsilons is too expensive
      // to compute without caching (i.e. not ilabel sorted),
      // then expand and cache state.
      Expand(s);
      return CacheImpl<A>::NumInputEpsilons(s);
    } else {
      // Otherwise, compute the number of input epsilons without caching.
      StateTuple tuple = state_table_->Tuple(s);
      if (tuple.fst_state == kNoStateId)
        return 0;
      const Fst<A>* fst = fst_array_[tuple.fst_id];
      size_t num = 0;
      if (!epsilon_on_replace_) {
        // If epsilon_on_replace is false, all input epsilon arcs
        // are also input epsilon arcs in the underlying machine.
        num = fst->NumInputEpsilons(tuple.fst_state);
      } else {
        // Otherwise, one needs to consider that all non-terminal arcs
        // in the underlying machine also become input epsilon arcs.
        ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
        for (; !aiter.Done() &&
                 ((aiter.Value().ilabel == 0) ||
                  IsNonTerminal(aiter.Value().olabel));
             aiter.Next())
          ++num;
      }
      if (ComputeFinalArc(tuple, 0))
        num++;
      return num;
    }
  }

  // Returns the number of output-epsilon arcs leaving s.
  size_t NumOutputEpsilons(StateId s) {
    if (HasArcs(s)) {
      // If state cached, use the cached value.
      return CacheImpl<A>::NumOutputEpsilons(s);
    } else if (always_cache_ || !Properties(kOLabelSorted)) {
      // If always caching or if the number of output epsilons is too expensive
      // to compute without caching (i.e. not olabel sorted),
      // then expand and cache state.
      Expand(s);
      return CacheImpl<A>::NumOutputEpsilons(s);
    } else {
      // Otherwise, compute the number of output epsilons without caching.
      // Non-terminal arcs become output epsilons (see ComputeArc).
      StateTuple tuple = state_table_->Tuple(s);
      if (tuple.fst_state == kNoStateId)
        return 0;
      const Fst<A>* fst = fst_array_[tuple.fst_id];
      size_t num = 0;
      ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
      for (; !aiter.Done() &&
               ((aiter.Value().olabel == 0) ||
                IsNonTerminal(aiter.Value().olabel));
           aiter.Next())
        ++num;
      if (ComputeFinalArc(tuple, 0))
        num++;
      return num;
    }
  }

  uint64 Properties() const { return Properties(kFstProperties); }

  // Set error if found; return FST impl properties.
  uint64 Properties(uint64 mask) const {
    if (mask & kError) {
      for (size_t i = 1; i < fst_array_.size(); ++i) {
        if (fst_array_[i]->Properties(kError, false))
          SetProperties(kError, kError);
      }
    }
    return FstImpl<Arc>::Properties(mask);
  }

  // return the base arc iterator, if arcs have not been computed yet,
  // extend/recurse for new arcs.
  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    if (!HasArcs(s))
      Expand(s);
    CacheImpl<A>::InitArcIterator(s, data);
    // TODO(allauzen): Set behaviour of generic iterator
    // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
    // relies on current behaviour.
  }

  // Extend current state (walk arcs one level deep)
  void Expand(StateId s) {
    // Copied, not referenced: FindState() calls below may grow the state
    // table and invalidate references into it.
    StateTuple tuple = state_table_->Tuple(s);

    // If local fst is empty
    if (tuple.fst_state == kNoStateId) {
      SetArcs(s);
      return;
    }

    ArcIterator< Fst<A> > aiter(
        *(fst_array_[tuple.fst_id]), tuple.fst_state);
    Arc arc;

    // Create a final arc when needed
    if (ComputeFinalArc(tuple, &arc))
      PushArc(s, arc);

    // Expand all arcs leaving the state
    for (; !aiter.Done(); aiter.Next()) {
      if (ComputeArc(tuple, aiter.Value(), &arc))
        PushArc(s, arc);
    }

    SetArcs(s);
  }

  // Same as Expand(s), but iterates over caller-supplied arc data for the
  // underlying state. Uses PushArc + SetArcs like Expand(s); SetArcs does
  // the per-state bookkeeping.
  void Expand(StateId s, const StateTuple &tuple,
              const ArcIteratorData<A> &data) {
    // If local fst is empty
    if (tuple.fst_state == kNoStateId) {
      SetArcs(s);
      return;
    }

    ArcIterator< Fst<A> > aiter(data);
    Arc arc;

    // Create a final arc when needed
    if (ComputeFinalArc(tuple, &arc))
      PushArc(s, arc);

    // Expand all arcs leaving the state
    for (; !aiter.Done(); aiter.Next()) {
      if (ComputeArc(tuple, aiter.Value(), &arc))
        PushArc(s, arc);
    }

    SetArcs(s);
  }

  // Returns true if `tuple` needs a "return" arc: its component state is
  // final and the call stack is non-empty. If arcp != 0, fills in the arc
  // fields selected by `flags` (epsilon labels, final weight, and the
  // caller's continuation state). If arcp == 0, only reports whether a
  // final arc is required without computing it.
  bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
                       uint32 flags = kArcValueFlags) {
    const Fst<A>* fst = fst_array_[tuple.fst_id];
    StateId fst_state = tuple.fst_state;
    if (fst_state == kNoStateId)
      return false;

    // if state is final, pop up stack
    const StackPrefix& stack = stackprefix_array_[tuple.prefix_id];
    if (fst->Final(fst_state) != Weight::Zero() && stack.Depth()) {
      if (arcp) {
        arcp->ilabel = 0;
        arcp->olabel = 0;
        if (flags & kArcNextStateValue) {
          // Copy the top frame before popping: PopPrefix() may append to
          // stackprefix_array_, which would invalidate `stack`.
          const PrefixTuple top = stack.Top();
          PrefixId prefix_id = PopPrefix(stack);
          arcp->nextstate = state_table_->FindState(
              StateTuple(prefix_id, top.fst_id, top.nextstate));
        }
        if (flags & kArcWeightValue)
          arcp->weight = fst->Final(fst_state);
      }
      return true;
    } else {
      return false;
    }
  }

  // Compute the arc in the replace fst corresponding to a given
  // in the underlying machine. Returns false if the underlying arc
  // corresponds to no arc in the replace (a non-terminal whose Fst is
  // empty). Only the arc fields selected by `flags` are guaranteed valid.
  bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
                  uint32 flags = kArcValueFlags) {
    if (!epsilon_on_replace_ &&
        (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
      *arcp = arc;
      return true;
    }

    if (arc.olabel == 0) {  // expand local fst
      StateId nextstate = flags & kArcNextStateValue
          ? state_table_->FindState(
              StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
          : kNoStateId;
      *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
    } else {
      // check for non terminal
      typename NonTerminalHash::const_iterator it =
          nonterminal_hash_.find(arc.olabel);
      if (it != nonterminal_hash_.end()) {  // recurse into non terminal
        Label nonterminal = it->second;
        const Fst<A>* nt_fst = fst_array_[nonterminal];
        PrefixId nt_prefix = PushPrefix(stackprefix_array_[tuple.prefix_id],
                                        tuple.fst_id, arc.nextstate);

        // if start state is valid replace, else arc is implicitly
        // deleted
        StateId nt_start = nt_fst->Start();
        if (nt_start != kNoStateId) {
          StateId nt_nextstate = flags & kArcNextStateValue
              ? state_table_->FindState(
                  StateTuple(nt_prefix, nonterminal, nt_start))
              : kNoStateId;
          Label ilabel = (epsilon_on_replace_) ? 0 : arc.ilabel;
          *arcp = A(ilabel, 0, arc.weight, nt_nextstate);
        } else {
          return false;
        }
      } else {
        StateId nextstate = flags & kArcNextStateValue
            ? state_table_->FindState(
                StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
            : kNoStateId;
        *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
      }
    }
    return true;
  }

  // Returns the arc iterator flags supported by this Fst.
  uint32 ArcIteratorFlags() const {
    uint32 flags = kArcValueFlags;
    if (!always_cache_)
      flags |= kArcNoCache;
    return flags;
  }

  T* GetStateTable() const {
    return state_table_;
  }

  const Fst<A>* GetFst(Label fst_id) const {
    return fst_array_[fst_id];
  }

  bool EpsilonOnReplace() const { return epsilon_on_replace_; }

  // private helper classes
 private:
  static const size_t kPrime0;

  // \class PrefixTuple
  // \brief Tuple of fst_id and destination state (entry in stack prefix)
  struct PrefixTuple {
    PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}

    Label fst_id;
    StateId nextstate;
  };

  // \class StackPrefix
  // \brief Container for stack prefix.
  class StackPrefix {
   public:
    StackPrefix() {}

    // copy constructor
    StackPrefix(const StackPrefix& x) :
        prefix_(x.prefix_) {
    }

    void Push(StateId fst_id, StateId nextstate) {
      prefix_.push_back(PrefixTuple(fst_id, nextstate));
    }

    void Pop() {
      prefix_.pop_back();
    }

    const PrefixTuple& Top() const {
      return prefix_[prefix_.size()-1];
    }

    size_t Depth() const {
      return prefix_.size();
    }

    vector<PrefixTuple> prefix_;
  };

  // \class StackPrefixEqual
  // \brief Compare two stack prefix classes for equality
  class StackPrefixEqual {
   public:
    bool operator()(const StackPrefix& x, const StackPrefix& y) const {
      if (x.prefix_.size() != y.prefix_.size()) return false;
      for (size_t i = 0; i < x.prefix_.size(); ++i) {
        if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
            x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
      }
      return true;
    }
  };

  //
  // \class StackPrefixKey
  // \brief Hash function for stack prefix to prefix id
  class StackPrefixKey {
   public:
    size_t operator()(const StackPrefix& x) const {
      size_t sum = 0;
      for (size_t i = 0; i < x.prefix_.size(); ++i) {
        sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
      }
      return sum;
    }
  };

  typedef unordered_map<StackPrefix, PrefixId, StackPrefixKey, StackPrefixEqual>
  StackPrefixHash;

  // private methods
 private:
  // hash stack prefix (return unique index into stackprefix array).
  // Appends to stackprefix_array_ when the prefix is new, which can
  // invalidate references into that array.
  PrefixId GetPrefixId(const StackPrefix& prefix) {
    typename StackPrefixHash::iterator it = prefix_hash_.find(prefix);
    if (it == prefix_hash_.end()) {
      PrefixId prefix_id = stackprefix_array_.size();
      stackprefix_array_.push_back(prefix);
      prefix_hash_[prefix] = prefix_id;
      return prefix_id;
    } else {
      return it->second;
    }
  }

  // prefix id after a stack pop (the argument is taken by value)
  PrefixId PopPrefix(StackPrefix prefix) {
    prefix.Pop();
    return GetPrefixId(prefix);
  }

  // prefix id after a stack push (the argument is taken by value)
  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
    prefix.Push(fst_id, nextstate);
    return GetPrefixId(prefix);
  }

  // private data
 private:
  // runtime options
  bool epsilon_on_replace_;
  bool always_cache_;  // Optionally caching arc iterator disabled when true

  // state table (owned)
  StateTable *state_table_;

  // cross index of unique stack prefix
  // could potentially have one copy of prefix array
  StackPrefixHash prefix_hash_;
  vector<StackPrefix> stackprefix_array_;

  set<Label> nonterminal_set_;
  NonTerminalHash nonterminal_hash_;
  vector<const Fst<A>*> fst_array_;
  Label root_;

  void operator=(const ReplaceFstImpl<A, T> &);  // disallow
};


template <class A, class T>
const size_t ReplaceFstImpl<A, T>::kPrime0 = 7853;
//
// \class ReplaceFst
// \brief Recursively replaces arcs in the root Fst with other Fsts.
// This version is a delayed Fst.
//
// ReplaceFst supports dynamic replacement of arcs in one Fst with
// another Fst. This replacement is recursive. ReplaceFst can be used
// to support a variety of delayed constructions such as recursive
// transition networks, union, or closure. It is constructed with an
// array of Fst(s). One Fst represents the root (or topology)
// machine. The root Fst refers to other Fsts by recursively replacing
// arcs labeled as non-terminals with the matching non-terminal
// Fst. Currently the ReplaceFst uses the output symbols of the arcs
// to determine whether the arc is a non-terminal arc or not. A
// non-terminal can be any label that is not a non-zero terminal label
// in the output alphabet.
//
// Note that the constructor uses a vector of pair<>. These correspond
// to the tuple of non-terminal Label and corresponding Fst. For example
// to implement the closure operation we need 2 Fsts. The first root
// Fst is a single Arc on the start State that self loops, it references
// the particular machine for which we are performing the closure operation.
//
// The ReplaceFst class supports an optionally caching arc iterator:
// ArcIterator< ReplaceFst<A> >
// The ReplaceFst needs to be built such that it is known to be ilabel
// or olabel sorted (see usage below).
//
// Observe that Matcher<Fst<A> > will use the optionally caching arc
// iterator when available (Fst is ilabel sorted and matching on the
// input, or Fst is olabel sorted and matching on the output).
// In order to obtain the most efficient behaviour, it is recommended
// to set 'epsilon_on_replace' to false (this means constructing acceptors
// as transducers with epsilons on the input side of nonterminal arcs)
// and matching on the input side.
//
// This class attaches interface to implementation and handles
// reference counting, delegating most methods to ImplToFst.
template <class A, class T = DefaultReplaceStateTable<A> >
class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T> > {
 public:
  friend class ArcIterator< ReplaceFst<A, T> >;
  friend class StateIterator< ReplaceFst<A, T> >;
  friend class ReplaceFstMatcher<A, T>;

  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;
  typedef ReplaceFstImpl<A, T> Impl;

  using ImplToFst<Impl>::Properties;

  // Builds the replace FST over (label, Fst) pairs with default options,
  // expanding from the non-terminal `root`.
  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
             Label root)
      : ImplToFst<Impl>(new Impl(fst_array, ReplaceFstOptions<A, T>(root))) {}

  // Builds the replace FST with explicit options.
  ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
             const ReplaceFstOptions<A, T> &opts)
      : ImplToFst<Impl>(new Impl(fst_array, opts)) {}

  // See Fst<>::Copy() for doc.
  ReplaceFst(const ReplaceFst<A, T>& fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
  virtual ReplaceFst<A, T> *Copy(bool safe = false) const {
    return new ReplaceFst<A, T>(*this, safe);
  }

  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    GetImpl()->InitArcIterator(s, data);
  }

  // Returns a ReplaceFstMatcher when the FST supports non-caching arc
  // iteration and is sorted on the side being matched; otherwise returns 0
  // (no specialized matcher).
  virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
    if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
        ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
         (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
      return new ReplaceFstMatcher<A, T>(*this, match_type);
    }
    VLOG(2) << "Not using replace matcher";
    return 0;
  }

  bool CyclicDependencies() const {
    return GetImpl()->CyclicDependencies();
  }

 private:
  // Makes visible to friends.
  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }

  // Disallowed. Declared for ReplaceFst<A, T> (not ReplaceFst<A>) so that
  // assignment is blocked for every state table type, not only the default.
  void operator=(const ReplaceFst<A, T> &fst);  // disallow
};
// Specialization for ReplaceFst.
template <class A, class T>
class StateIterator< ReplaceFst<A, T> >
    : public CacheStateIterator< ReplaceFst<A, T> > {
  typedef CacheStateIterator< ReplaceFst<A, T> > Base;

 public:
  // Iterates over the states of `fst` through its cache, using the
  // implementation exposed to this friend class.
  explicit StateIterator(const ReplaceFst<A, T> &fst)
      : Base(fst, fst.GetImpl()) {}

 private:
  DISALLOW_COPY_AND_ASSIGN(StateIterator);
};
// Specialization for ReplaceFst.
// Implements optional caching. It can be used as follows:
//
// ReplaceFst<A> replace;
// ArcIterator< ReplaceFst<A> > aiter(replace, s);
// // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
// aiter.SetFlags(kArcNoCache, kArcNoCache);
// // Use the arc iterator, no arc will be cached, no state will be expanded.
// // The various 'kArcValueFlags' can be used to decide which part
// // of arc values needs to be computed.
// aiter.SetFlags(kArcILabelValue, kArcValueFlags);
// // Only want the ilabel for this arc
// aiter.Value(); // Does not compute the destination state.
// aiter.Next();
// aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
// // Want both ilabel and nextstate for that arc
// aiter.Value(); // Does compute the destination state and inserts it
// // in the replace state table.
// // No Arc has been cached at that point.
//
template <class A, class T>
class ArcIterator< ReplaceFst<A, T> > {
public:
typedef A Arc;
typedef typename A::StateId StateId;
  // Builds an iterator over the arcs of replace state `s`.
  // If the FST does not support non-caching iteration (kArcNoCache), the
  // state is expanded and cached first. If the state's arcs are cached, the
  // iterator reads the cached array and all arc fields are valid. Otherwise
  // it points at the arcs of the underlying component FST, precomputes the
  // optional final ("return") arc without its destination state, and defers
  // the caching decision to Value() or SetFlags().
  ArcIterator(const ReplaceFst<A, T> &fst, StateId s)
      : fst_(fst), state_(s), pos_(0), offset_(0), flags_(0), arcs_(0),
        data_flags_(0), final_flags_(0) {
    // No references held yet; checked by the destructor.
    cache_data_.ref_count = 0;
    local_data_.ref_count = 0;

    // If FST does not support optional caching, force caching.
    if(!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
       !(fst_.GetImpl()->HasArcs(state_)))
      fst_.GetImpl()->Expand(state_);

    // If state is already cached, use cached arcs array.
    if (fst_.GetImpl()->HasArcs(state_)) {
      (fst_.GetImpl())->template CacheImpl<A>::InitArcIterator(state_,
                                                               &cache_data_);
      num_arcs_ = cache_data_.narcs;
      arcs_ = cache_data_.arcs;      // 'arcs_' is a ptr to the cached arcs.
      data_flags_ = kArcValueFlags;  // All the arc member values are valid.
    } else {  // Otherwise delay decision until Value() is called.
      tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
      if (tuple_.fst_state == kNoStateId) {
        // Empty component FST: the state has no arcs.
        num_arcs_ = 0;
      } else {
        // The decision to cache or not to cache has been deferred
        // until Value() or SetFlags() is called. However, the arc
        // iterator is set up now to be ready for non-caching in order
        // to keep the Value() method simple and efficient.
        const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
        fst->InitArcIterator(tuple_.fst_state, &local_data_);
        // 'arcs_' is a pointer to the arcs in the underlying machine.
        arcs_ = local_data_.arcs;
        // Compute the final arc (but not its destination state)
        // if a final arc is required.
        bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
            tuple_,
            &final_arc_,
            kArcValueFlags & ~kArcNextStateValue);
        // Set the arc value flags that hold for 'final_arc_'.
        final_flags_ = kArcValueFlags & ~kArcNextStateValue;
        // Compute the number of arcs.
        num_arcs_ = local_data_.narcs;
        if (has_final_arc)
          ++num_arcs_;
        // Set the offset between the underlying arc positions and
        // the positions in the arc iterator (1 when a final arc is
        // prepended, else 0).
        offset_ = num_arcs_ - local_data_.narcs;
        // Defers the decision to cache or not until Value() or
        // SetFlags() is called.
        data_flags_ = 0;
      }
    }
  }
~ArcIterator() {
if (cache_data_.ref_count)
--(*cache_data_.ref_count);
if (local_data_.ref_count)
--(*local_data_.ref_count);
}
void ExpandAndCache() const {
// TODO(allauzen): revisit this
// fst_.GetImpl()->Expand(state_, tuple_, local_data_);
// (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
// &cache_data_);
//
fst_.InitArcIterator(state_, &cache_data_); // Expand and cache state.
arcs_ = cache_data_.arcs; // 'arcs_' is a pointer to the cached arcs.
data_flags_ = kArcValueFlags; // All the arc member values are valid.
offset_ = 0; // No offset
}
void Init() {
if (flags_ & kArcNoCache) { // If caching is disabled
// 'arcs_' is a pointer to the arcs in the underlying machine.
arcs_ = local_data_.arcs;
// Set the arcs value flags that hold for 'arcs_'.
data_flags_ = kArcWeightValue;
if (!fst_.GetImpl()->EpsilonOnReplace())
data_flags_ |= kArcILabelValue;
// Set the offset between the underlying arc positions and
// the positions in the arc iterator.
offset_ = num_arcs_ - local_data_.narcs;
} else { // Otherwise, expand and cache
ExpandAndCache();
}
}
bool Done() const { return pos_ >= num_arcs_; }
const A& Value() const {
// If 'data_flags_' was set to 0, non-caching was not requested
if (!data_flags_) {
// TODO(allauzen): revisit this.
if (flags_ & kArcNoCache) {
// Should never happen.
FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
}
ExpandAndCache(); // Expand and cache.
}
if (pos_ - offset_ >= 0) { // The requested arc is not the 'final' arc.
const A& arc = arcs_[pos_ - offset_];
if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
// If the value flags for 'arc' match the recquired value flags
// then return 'arc'.
return arc;
} else {
// Otherwise, compute the corresponding arc on-the-fly.
fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
return arc_;
}
} else { // The requested arc is the 'final' arc.
if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
// If the arc value flags that hold for the final arc
// do not match the requested value flags, then
// 'final_arc_' needs to be updated.
fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
flags_ & kArcValueFlags);
final_flags_ = flags_ & kArcValueFlags;
}
return final_arc_;
}
}
void Next() { ++pos_; }
size_t Position() const { return pos_; }
void Reset() { pos_ = 0; }
void Seek(size_t pos) { pos_ = pos; }
uint32 Flags() const { return flags_; }
void SetFlags(uint32 f, uint32 mask) {
// Update the flags taking into account what flags are supported
// by the Fst.
flags_ &= ~mask;
flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
// If non-caching is not requested (and caching has not already
// been performed), then flush 'data_flags_' to request caching
// during the next call to Value().
if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
if (!fst_.GetImpl()->HasArcs(state_))
data_flags_ = 0;
}
// If 'data_flags_' has been flushed but non-caching is requested
// before calling Value(), then set up the iterator for non-caching.
if ((f & kArcNoCache) && (!data_flags_))
Init();
}
private:
const ReplaceFst<A, T> &fst_; // Reference to the FST
StateId state_; // State in the FST
mutable typename T::StateTuple tuple_; // Tuple corresponding to state_
ssize_t pos_; // Current position
mutable ssize_t offset_; // Offset between position in iterator and in arcs_
ssize_t num_arcs_; // Number of arcs at state_
uint32 flags_; // Behavorial flags for the arc iterator
mutable Arc arc_; // Memory to temporarily store computed arcs
mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache
mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local fst
mutable const A* arcs_; // Array of arcs
mutable uint32 data_flags_; // Arc value flags valid for data in arcs_
mutable Arc final_arc_; // Final arc (when required)
mutable uint32 final_flags_; // Arc value flags valid for final_arc_
DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};
template <class A, class T>
class ReplaceFstMatcher : public MatcherBase<A> {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;

  // Creates a matcher over 'fst' that matches on the side selected by
  // 'match_type'. One LocalMatcher is built per component Fst.
  ReplaceFstMatcher(const ReplaceFst<A, T> &fst, fst::MatchType match_type)
      : fst_(fst),
        impl_(fst_.GetImpl()),
        current_matcher_(0),
        s_(fst::kNoStateId),
        label_(fst::kNoLabel),
        match_type_(match_type),
        done_(false),
        current_loop_(false),
        final_arc_(false),
        loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
    // The implicit epsilon loop must carry the 0 label on the matched side.
    if (match_type_ == fst::MATCH_OUTPUT)
      std::swap(loop_.ilabel, loop_.olabel);
    InitMatchers();
  }

  // Copy constructor: builds fresh component matchers; the iteration state
  // of 'matcher' is not copied. 'safe' is accepted for interface
  // compatibility and is not used.
  ReplaceFstMatcher(const ReplaceFstMatcher<A, T> &matcher, bool safe = false)
      : fst_(matcher.fst_),
        impl_(fst_.GetImpl()),
        current_matcher_(0),
        s_(fst::kNoStateId),
        label_(fst::kNoLabel),
        match_type_(matcher.match_type_),
        done_(false),
        current_loop_(false),
        final_arc_(false),  // Was previously left uninitialized here.
        loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
    if (match_type_ == fst::MATCH_OUTPUT)
      std::swap(loop_.ilabel, loop_.olabel);
    InitMatchers();
  }

  // Create a local matcher for each component Fst of replace.
  // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsMatcher
  // is used to match each non-terminal arc, since these non-terminals
  // turn into epsilons on recursion.
  void InitMatchers() {
    const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
    matcher_.resize(fst_array.size(), 0);
    for (size_t i = 0; i < fst_array.size(); ++i) {
      if (fst_array[i]) {
        matcher_[i] =
            new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
        typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
        for (; it != impl_->nonterminal_set_.end(); ++it) {
          matcher_[i]->AddMultiEpsLabel(*it);
        }
      }
    }
  }

  virtual ReplaceFstMatcher<A, T> *Copy(bool safe = false) const {
    return new ReplaceFstMatcher<A, T>(*this, safe);
  }

  // Deletes the component matchers created by InitMatchers().
  virtual ~ReplaceFstMatcher() {
    for (size_t i = 0; i < matcher_.size(); ++i)
      delete matcher_[i];
  }

  // Reports match_type_ when the FST is sorted on the matched side,
  // MATCH_NONE when it is known not to be, and MATCH_UNKNOWN otherwise.
  virtual MatchType Type(bool test) const {
    if (match_type_ == MATCH_NONE)
      return match_type_;
    uint64 true_prop = match_type_ == MATCH_INPUT ?
        kILabelSorted : kOLabelSorted;
    uint64 false_prop = match_type_ == MATCH_INPUT ?
        kNotILabelSorted : kNotOLabelSorted;
    uint64 props = fst_.Properties(true_prop | false_prop, test);
    if (props & true_prop)
      return match_type_;
    else if (props & false_prop)
      return MATCH_NONE;
    else
      return MATCH_UNKNOWN;
  }

  virtual const Fst<A> &GetFst() const {
    return fst_;
  }

  virtual uint64 Properties(uint64 props) const {
    return props;
  }

 private:
  // Set the state from which our matching happens.
  virtual void SetState_(StateId s) {
    if (s_ == s) return;
    s_ = s;
    tuple_ = impl_->GetStateTable()->Tuple(s_);
    loop_.nextstate = s_;
    // Any in-progress iteration belonged to the previous state.
    current_loop_ = false;
    final_arc_ = false;
    if (tuple_.fst_state == kNoStateId) {
      // No component state: there is nothing to match, and the component
      // matcher (which would belong to some other state) must not be used.
      done_ = true;
      current_matcher_ = 0;
      return;
    }
    done_ = false;
    // Get current matcher. Used for non epsilon matching
    current_matcher_ = matcher_[tuple_.fst_id];
    current_matcher_->SetState(tuple_.fst_state);
  }

  // Search for label, from previous set state. If label == 0, first
  // hallucinate an epsilon loop, else use the underlying matcher to
  // search for the label or epsilons.
  // - Note since the ReplaceFST recursion on non-terminal arcs causes
  //   epsilon transitions to be created we use the MultiEpsMatcher
  //   to search for possible matches of non terminals.
  // - If the component Fst reaches a final state we also need to add
  //   the exiting final arc.
  // Every call starts a fresh iteration: flags left over from a previous,
  // possibly unfinished, iteration are cleared first.
  virtual bool Find_(Label label) {
    label_ = label;
    // Compute loop directly, saving Replace::ComputeArc.
    current_loop_ = (label_ == 0);
    final_arc_ = false;
    if (done_) {
      // No component state: only the implicit epsilon loop can match.
      return current_loop_;
    }
    if (label_ == 0 || label_ == kNoLabel) {
      // Search for matching multi epsilons and for the exiting final arc.
      final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
      const bool found_component = current_matcher_->Find(kNoLabel);
      return found_component || final_arc_ || current_loop_;
    }
    // Search on sub machine directly using sub machine matcher.
    return current_matcher_->Find(label_);
  }

  // Iteration order: implicit loop, then exiting final arc, then the
  // component matcher's arcs.
  virtual bool Done_() const {
    if (current_loop_ || final_arc_) return false;
    return done_ || current_matcher_->Done();
  }

  // Returns the current arc; computed arcs are stored in 'arc_', which is
  // overwritten by the next call.
  virtual const Arc& Value_() const {
    if (current_loop_) {
      return loop_;
    }
    if (final_arc_) {
      impl_->ComputeFinalArc(tuple_, &arc_);
      return arc_;
    }
    const Arc& component_arc = current_matcher_->Value();
    impl_->ComputeArc(tuple_, component_arc, &arc_);
    return arc_;
  }

  virtual void Next_() {
    if (current_loop_) {
      current_loop_ = false;
      return;
    }
    if (final_arc_) {
      final_arc_ = false;
      return;
    }
    if (!done_)
      current_matcher_->Next();
  }

  const ReplaceFst<A, T>& fst_;
  ReplaceFstImpl<A, T> *impl_;
  LocalMatcher* current_matcher_;         // Matcher for the current component
  vector<LocalMatcher*> matcher_;         // One matcher per component Fst
  StateId s_;                             // Current state
  Label label_;                           // Current label
  MatchType match_type_;                  // Supplied by caller
  mutable bool done_;                     // Current state has no component state
  mutable bool current_loop_;             // Current arc is the implicit loop
  mutable bool final_arc_;                // Current arc for exiting recursion
  mutable typename T::StateTuple tuple_;  // Tuple corresponding to s_
  mutable Arc arc_;                       // Storage for computed arcs
  Arc loop_;                              // Implicit epsilon self-loop
};
template <class A, class T> inline
void ReplaceFst<A, T>::InitStateIterator(StateIteratorData<A> *data) const {
  // Install a newly allocated StateIterator over this ReplaceFst as the
  // iterator implementation held by 'data'.
  typedef StateIterator< ReplaceFst<A, T> > ReplaceStateIterator;
  ReplaceStateIterator *state_iter = new ReplaceStateIterator(*this);
  data->base = state_iter;
}
// Convenience typedef: ReplaceFst over StdArc, using the default state table.
typedef ReplaceFst<StdArc> StdReplaceFst;
// // Recursivively replaces arcs in the root Fst with other Fsts.
// This version writes the result of replacement to an output MutableFst.
//
// Replace supports replacement of arcs in one Fst with another
// Fst. This replacement is recursive. Replace takes an array of
// Fst(s). One Fst represents the root (or topology) machine. The root
// Fst refers to other Fsts by recursively replacing arcs labeled as
// non-terminals with the matching non-terminal Fst. Currently Replace
// uses the output symbols of the arcs to determine whether the arc is
// a non-terminal arc or not. A non-terminal can be any label that is
// not a non-zero terminal label in the output alphabet. Note that
// input argument is a vector of pair<>. These correspond to the tuple
// of non-terminal Label and corresponding Fst.
template<class Arc>
void Replace(const vector<pair<typename Arc::Label,
const Fst<Arc>* > >& ifst_array,
MutableFst<Arc> *ofst, typename Arc::Label root,
bool epsilon_on_replace) {
ReplaceFstOptions<Arc> opts(root, epsilon_on_replace);
opts.gc_limit = 0; // Cache only the last state for fastest copy.
*ofst = ReplaceFst<Arc>(ifst_array, opts);
}
// Same as the overload above with epsilon_on_replace set to false.
template<class Arc>
void Replace(const vector<pair<typename Arc::Label,
                               const Fst<Arc>* > >& ifst_array,
             MutableFst<Arc> *ofst, typename Arc::Label root) {
  const bool kDefaultEpsilonOnReplace = false;
  Replace(ifst_array, ofst, root, kDefaultEpsilonOnReplace);
}
} // namespace fst
#endif // FST_LIB_REPLACE_H__