blob: 24b169f4d558dc2ea0de0eac10bef201d9e5e2c1 [file] [log] [blame]
// compose.h
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//
// \file
// Class to compute the composition of two FSTs
#ifndef FST_LIB_COMPOSE_H__
#define FST_LIB_COMPOSE_H__
#include <algorithm>
#include <ext/hash_map>
using __gnu_cxx::hash_map;
#include "fst/lib/cache.h"
#include "fst/lib/test-properties.h"
namespace fst {
// Enumeration of uint64 bits used to represent the user-defined
// properties of FST composition (in the template parameter to
// ComposeFstOptions<T>). The bits stand for extensions of generic FST
// composition. ComposeFstOptions<> (all the bits unset) is the "plain"
// compose without any extra extensions.
enum ComposeTypes {
  // ---- Special-symbol extensions (user-settable) --------------------------
  // RHO = "rest" symbol, PHI = "failure" symbol, SIGMA = "any" symbol.
  // FST1_* variants refer to the output side of fst1, FST2_* variants to the
  // input side of fst2. At most ONE of these six bits may be set;
  // ComposeFstImpl rejects every other combination with LOG(FATAL).
  COMPOSE_FST1_RHO   = 1ULL << 0,
  COMPOSE_FST2_RHO   = 1ULL << 1,
  COMPOSE_FST1_PHI   = 1ULL << 2,
  COMPOSE_FST2_PHI   = 1ULL << 3,
  COMPOSE_FST1_SIGMA = 1ULL << 4,
  COMPOSE_FST2_SIGMA = 1ULL << 5,

  // ---- Optimization control -----------------------------------------------
  // Forces the generic composition algorithm, bypassing every optimization.
  // Meant for internal testing only.
  COMPOSE_GENERIC = 1ULL << 32,

  // ---- Convenience masks (internal use only) ------------------------------
  COMPOSE_RHO   = COMPOSE_FST1_RHO   | COMPOSE_FST2_RHO,
  COMPOSE_PHI   = COMPOSE_FST1_PHI   | COMPOSE_FST2_PHI,
  COMPOSE_SIGMA = COMPOSE_FST1_SIGMA | COMPOSE_FST2_SIGMA,
  COMPOSE_SPECIAL_SYMBOLS = COMPOSE_RHO | COMPOSE_PHI | COMPOSE_SIGMA,

  // ---- Optimization bits, normally set *internally* by ComposeFst ---------
  COMPOSE_FST1_STRING = 1ULL << 33,  // fst1 is a string
  COMPOSE_FST2_STRING = 1ULL << 34,  // fst2 is a string
  COMPOSE_FST1_DET    = 1ULL << 35,  // fst1 is deterministic
  COMPOSE_FST2_DET    = 1ULL << 36,  // fst2 is deterministic

  // Upper 32 bits: everything that is not a user-facing extension bit.
  COMPOSE_INTERNAL_MASK = 0xffffffff00000000ULL
};
template <uint64 T = 0ULL>
struct ComposeFstOptions : public CacheOptions {
  // Default cache options. The template argument T carries the
  // ComposeTypes bits selecting composition extensions (0 = plain compose).
  ComposeFstOptions() { }

  // Same as above, but with the cache options copied from 'opts'.
  explicit ComposeFstOptions(const CacheOptions &opts)
      : CacheOptions(opts) {}
};
// Abstract base for the implementation of delayed ComposeFst. The
// concrete specializations are templated on the (uint64-valued)
// properties of the FSTs being composed.
template <class A>
class ComposeFstImplBase : public CacheImpl<A> {
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::Properties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;
  using CacheBaseImpl< CacheState<A> >::HasStart;
  using CacheBaseImpl< CacheState<A> >::HasFinal;
  using CacheBaseImpl< CacheState<A> >::HasArcs;

  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // Stores private copies of both inputs (deleted in the destructor). Sets
  // the FST type to "compose", the composed properties, and the input/output
  // symbol tables. Dies (LOG(FATAL)) if fst1's output symbol table is
  // incompatible with fst2's input symbol table.
  ComposeFstImplBase(const Fst<A> &fst1,
                     const Fst<A> &fst2,
                     const CacheOptions &opts)
      : CacheImpl<A>(opts), fst1_(fst1.Copy()), fst2_(fst2.Copy()) {
    SetType("compose");
    uint64 props1 = fst1.Properties(kFstProperties, false);
    uint64 props2 = fst2.Properties(kFstProperties, false);
    SetProperties(ComposeProperties(props1, props2), kCopyProperties);
    if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols()))
      LOG(FATAL) << "ComposeFst: output symbol table of 1st argument "
                 << "does not match input symbol table of 2nd argument";
    SetInputSymbols(fst1.InputSymbols());
    SetOutputSymbols(fst2.OutputSymbols());
  }

  virtual ~ComposeFstImplBase() {
    delete fst1_;
    delete fst2_;
  }

  // Returns the start state, computing and caching it on first use.
  StateId Start() {
    if (!HasStart()) {
      StateId start = ComputeStart();
      if (start != kNoStateId) {
        // 'this->' is required: SetStart lives in a dependent base class
        // and is not found by unqualified lookup in a template.
        this->SetStart(start);
      }
    }
    return CacheImpl<A>::Start();
  }

  // Returns the final weight of s, computing and caching it on first use.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      Weight final = ComputeFinal(s);
      this->SetFinal(s, final);  // dependent base: see Start()
    }
    return CacheImpl<A>::Final(s);
  }

  // Computes the outgoing arcs of s and stores them in the cache.
  virtual void Expand(StateId s) = 0;

  // The accessors below expand s on demand before reading the cache.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    if (!HasArcs(s))
      Expand(s);
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Access to flags encoding compose options/optimizations etc. (for
  // debugging).
  virtual uint64 ComposeFlags() const = 0;

 protected:
  // Returns the composed start state, or kNoStateId if there is none.
  virtual StateId ComputeStart() = 0;
  // Returns the final weight of composed state s.
  virtual Weight ComputeFinal(StateId s) = 0;

  const Fst<A> *fst1_;  // first input Fst (owned copy)
  const Fst<A> *fst2_;  // second input Fst (owned copy)

 private:
  // The destructor deletes fst1_/fst2_, so a member-wise copy would
  // double-delete them; copying is therefore disallowed.
  DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImplBase);
};
// The following class encapsulates implementation-dependent details
// of state tuple lookup, i.e. a bijective mapping from triples of two
// FST states and an epsilon filter state to the corresponding state
// IDs of the fst resulting from composition. The mapping must
// implement the [] operator in the style of STL associative
// containers (map, hash_map), i.e. table[x] must return a reference
// to the value associated with x. If x is an unassigned tuple, the
// operator must automatically associate x with value 0.
//
// NB: "table[x] == 0" for unassigned tuples x is required by the
// following off-by-one device used in the implementation of
// ComposeFstImpl. The value stored in the table is equal to tuple ID
// plus one, i.e. it is always a strictly positive number. Therefore,
// table[x] is equal to 0 if and only if x is an unassigned tuple (in
// which case the algorithm assigns a new ID to x, and sets table[x] -
// stored in a reference - to "new ID + 1"). This form of lookup is
// more efficient than calling "find(x)" and "insert(make_pair(x, new
// ID))" if x is an unassigned tuple.
//
// The generic implementation is a wrapper around a hash_map.
template <class A, uint64 T>
class ComposeStateTable {
 public:
  typedef typename A::StateId StateId;

  // A composition state: one state from each input FST plus the epsilon
  // filter state.
  struct StateTuple {
    StateTuple() {}
    StateTuple(StateId s1, StateId s2, int f)
        : state_id1(s1), state_id2(s2), filt(f) {}
    StateId state_id1;  // state Id on fst1
    StateId state_id2;  // state Id on fst2
    int filt;           // epsilon filter state
  };

  // Declared explicitly because DISALLOW_EVIL_CONSTRUCTORS below suppresses
  // the implicit default constructor.
  ComposeStateTable() {}

  // NB: if 'tuple' is not in 'table_', the pair (tuple, StateId()) is
  // inserted into 'table_' (standard STL container semantics). Since
  // StateId is a built-in type, the explicit default constructor call
  // StateId() returns 0.
  StateId &operator[](const StateTuple &tuple) {
    return table_[tuple];
  }

 private:
  // Equality predicate for hashing StateTuple(s).
  class StateTupleEqual {
   public:
    bool operator()(const StateTuple& x, const StateTuple& y) const {
      return x.state_id1 == y.state_id1 &&
             x.state_id2 == y.state_id2 &&
             x.filt == y.filt;
    }
  };

  static const int kPrime0 = 7853;
  static const int kPrime1 = 7867;

  // Hash function for StateTuple. The arithmetic is done in size_t, which
  // wraps modulo 2^N, instead of in the signed StateId/int types: there,
  // large state IDs would overflow, which is undefined behavior. For values
  // that did not overflow, the result is identical to the signed
  // computation converted to size_t.
  class StateTupleKey {
   public:
    size_t operator()(const StateTuple& x) const {
      return static_cast<size_t>(x.state_id1) +
             static_cast<size_t>(x.state_id2) * kPrime0 +
             static_cast<size_t>(x.filt) * kPrime1;
    }
  };

  // Lookup table mapping state tuples to (state ID + 1).
  typedef hash_map<StateTuple,
                   StateId,
                   StateTupleKey,
                   StateTupleEqual> StateTable;

  // Actual table data.
  StateTable table_;

  DISALLOW_EVIL_CONSTRUCTORS(ComposeStateTable);
};
// State tuple lookup table for the composition of a string FST with a
// deterministic FST. The class maps state tuples to their unique IDs
// (i.e. states of the ComposeFst). Main optimization: due to the
// 1-to-1 correspondence between the states of the input string FST
// and those of the resulting (string) FST, a state tuple (s1, s2) is
// simply mapped to StateId s1. Hence, we use an STL vector as a
// lookup table. Template argument Fst1IsString specifies which FST is
// a string (this determines whether or not we index the lookup table
// by the first or by the second state).
template <class A, bool Fst1IsString>
class StringDetComposeStateTable {
 public:
  typedef typename A::StateId StateId;

  // Composition state. Only one component (chosen by Fst1IsString) serves
  // as the lookup key, and there is no real epsilon-filter state.
  struct StateTuple {
    typedef typename A::StateId StateId;

    StateTuple() {}

    // The filter argument exists only so this constructor matches
    // ComposeStateTable<>::StateTuple; its value is ignored.
    StateTuple(StateId s1, StateId s2, int /* f */)
        : state_id1(s1), state_id2(s2) {}

    StateId state_id1;  // state Id on fst1
    StateId state_id2;  // state Id on fst2
    static const int filt = 0;  // constant 'fake' filter (API compatibility)
  };

  StringDetComposeStateTable() {}

  // Map-like subscript: returns the slot for 'tuple'. If the key lies past
  // the end of the table, the table is grown first and the new slots are
  // zero-filled, so unseen tuples read as 0.
  StateId &operator[](const StateTuple &tuple) {
    const StateId key = Fst1IsString ? tuple.state_id1 : tuple.state_id2;
    const StateId table_size = static_cast<StateId>(data_.size());
    if (key >= table_size)
      data_.resize(key + 1);
    return data_[key];
  }

 private:
  // Slot k holds the value stored for the tuple keyed by k (ComposeFstImpl
  // stores "state ID + 1" there), or 0 if nothing was stored.
  vector<StateId> data_;

  DISALLOW_EVIL_CONSTRUCTORS(StringDetComposeStateTable);
};
// Specializations of ComposeStateTable for the string/det case.
// Both inherit from StringDetComposeStateTable.
// fst1 is the string and fst2 deterministic: key the table on fst1's state.
template <class A>
class ComposeStateTable<A, COMPOSE_FST2_DET | COMPOSE_FST1_STRING>
    : public StringDetComposeStateTable<A, true> {
};

// fst2 is the string and fst1 deterministic: key the table on fst2's state.
template <class A>
class ComposeStateTable<A, COMPOSE_FST1_DET | COMPOSE_FST2_STRING>
    : public StringDetComposeStateTable<A, false> {
};
// Parameterized implementation of FST composition for a pair of FSTs
// matching the property bit vector T. If possible,
// instantiation-specific switches in the code are based on the values
// of the bits in T, which are known at compile time, so unused code
// should be optimized away by the compiler.
template <class A, uint64 T>
class ComposeFstImpl : public ComposeFstImplBase<A> {
  typedef typename A::StateId StateId;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;

  // Which FST is searched with FindLabel. Expand() searches fst2 for
  // FIND_INPUT and fst1 otherwise, so FIND_BOTH currently behaves like
  // FIND_OUTPUT.
  enum FindType { FIND_INPUT = 1,   // find input label on fst2
                  FIND_OUTPUT = 2,  // find output label on fst1
                  FIND_BOTH = 3 };  // find choice state dependent

  // Tuple lookup table, specialized on the internal (optimization) bits.
  typedef ComposeStateTable<A, T & COMPOSE_INTERNAL_MASK> StateTupleTable;
  typedef typename StateTupleTable::StateTuple StateTuple;

 public:
  // Verifies the sort properties required by the special-symbol bits in T
  // and chooses the search side (find_type_). Dies (LOG(FATAL)) if more
  // than one special-symbol bit is set, or if the required side is not
  // label-sorted.
  ComposeFstImpl(const Fst<A> &fst1,
                 const Fst<A> &fst2,
                 const CacheOptions &opts)
      : ComposeFstImplBase<A>(fst1, fst2, opts) {
    bool osorted = fst1.Properties(kOLabelSorted, false);
    bool isorted = fst2.Properties(kILabelSorted, false);
    switch (T & COMPOSE_SPECIAL_SYMBOLS) {
      case COMPOSE_FST1_RHO:
      case COMPOSE_FST1_PHI:
      case COMPOSE_FST1_SIGMA:
        if (!osorted || FLAGS_fst_verify_properties)
          osorted = fst1.Properties(kOLabelSorted, true);
        if (!osorted)
          LOG(FATAL) << "ComposeFst: 1st argument not output label "
                     << "sorted (special symbols present)";
        break;
      case COMPOSE_FST2_RHO:
      case COMPOSE_FST2_PHI:
      case COMPOSE_FST2_SIGMA:
        if (!isorted || FLAGS_fst_verify_properties)
          isorted = fst2.Properties(kILabelSorted, true);
        if (!isorted)
          LOG(FATAL) << "ComposeFst: 2nd argument not input label "
                     << "sorted (special symbols present)";
        break;
      case 0:
        // Grouping made explicit (&& already binds tighter than ||).
        if ((!isorted && !osorted) || FLAGS_fst_verify_properties) {
          osorted = fst1.Properties(kOLabelSorted, true);
          if (!osorted)
            isorted = fst2.Properties(kILabelSorted, true);
        }
        break;
      default:
        LOG(FATAL)
            << "ComposeFst: More than one special symbol used in composition";
    }
    if (isorted && (T & COMPOSE_FST2_SIGMA)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_SIGMA)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST2_PHI)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_PHI)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST2_RHO)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST1_RHO)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && (T & COMPOSE_FST1_STRING)) {
      find_type_ = FIND_INPUT;
    } else if (osorted && (T & COMPOSE_FST2_STRING)) {
      find_type_ = FIND_OUTPUT;
    } else if (isorted && osorted) {
      find_type_ = FIND_BOTH;
    } else if (isorted) {
      find_type_ = FIND_INPUT;
    } else if (osorted) {
      find_type_ = FIND_OUTPUT;
    } else {
      LOG(FATAL) << "ComposeFst: 1st argument not output label sorted "
                 << "and 2nd argument is not input label sorted";
    }
  }

  // Finds/creates an Fst state given a StateTuple. Only creates a new
  // state if StateTuple is not found in the state hash.
  //
  // The method exploits the following device: all pairs stored in the
  // associative container state_tuple_table_ are of the form (tuple,
  // id(tuple) + 1), i.e. state_tuple_table_[tuple] > 0 if tuple has
  // been stored previously. For unassigned tuples, the call to
  // state_tuple_table_[tuple] creates a new pair (tuple, 0). As a
  // result, state_tuple_table_[tuple] == 0 iff tuple is new.
  StateId FindState(const StateTuple& tuple) {
    StateId &assoc_value = state_tuple_table_[tuple];
    if (assoc_value == 0) {  // tuple wasn't present in lookup table:
                             // assign it a new ID.
      state_tuples_.push_back(tuple);
      assoc_value = state_tuples_.size();
    }
    return assoc_value - 1;  // NB: assoc_value = ID + 1
  }

  // Generates an arc for composition state s from matched input Fst arcs.
  // When find_input is true, arca comes from fst2 and arcb from fst1;
  // otherwise arca comes from fst1 and arcb from fst2. f is the filter
  // state of the destination tuple.
  void AddArc(StateId s, const A &arca, const A &arcb, int f,
              bool find_input) {
    A arc;
    if (find_input) {
      arc.ilabel = arcb.ilabel;
      arc.olabel = arca.olabel;
      arc.weight = Times(arcb.weight, arca.weight);
      StateTuple tuple(arcb.nextstate, arca.nextstate, f);
      arc.nextstate = FindState(tuple);
    } else {
      arc.ilabel = arca.ilabel;
      arc.olabel = arcb.olabel;
      arc.weight = Times(arca.weight, arcb.weight);
      StateTuple tuple(arca.nextstate, arcb.nextstate, f);
      arc.nextstate = FindState(tuple);
    }
    CacheImpl<A>::AddArc(s, arc);
  }

  // Arranges it so that the first arg to OrderedExpand is the Fst
  // that will be passed to FindLabel.
  void Expand(StateId s) {
    // Copy the fields out: OrderedExpand may append to state_tuples_,
    // which would invalidate a reference into it.
    const StateTuple &tuple = state_tuples_[s];
    StateId s1 = tuple.state_id1;
    StateId s2 = tuple.state_id2;
    int f = tuple.filt;
    if (find_type_ == FIND_INPUT)
      OrderedExpand(s, ComposeFstImplBase<A>::fst2_, s2,
                    ComposeFstImplBase<A>::fst1_, s1, f, true);
    else
      OrderedExpand(s, ComposeFstImplBase<A>::fst1_, s1,
                    ComposeFstImplBase<A>::fst2_, s2, f, false);
  }

  // Access to flags encoding compose options/optimizations etc. (for
  // debugging).
  virtual uint64 ComposeFlags() const { return T; }

 private:
  // This does the actual matching of labels in the composition. The
  // arguments are ordered so FindLabel is called with state SA of
  // FSTA for each arc leaving state SB of FSTB. The FIND_INPUT arg
  // determines whether the input or output label of arcs at SB is
  // the one to match on.
  void OrderedExpand(StateId s, const Fst<A> *fsta, StateId sa,
                     const Fst<A> *fstb, StateId sb, int f, bool find_input) {
    size_t numarcsa = fsta->NumArcs(sa);
    size_t numepsa = find_input ? fsta->NumInputEpsilons(sa) :
        fsta->NumOutputEpsilons(sa);
    bool finala = fsta->Final(sa) != Weight::Zero();
    ArcIterator< Fst<A> > aitera(*fsta, sa);
    // First handle special epsilons and sigmas on FSTA
    for (; !aitera.Done(); aitera.Next()) {
      const A &arca = aitera.Value();
      Label match_labela = find_input ? arca.ilabel : arca.olabel;
      if (match_labela > 0) {
        break;
      }
      if ((T & COMPOSE_SIGMA) != 0 && match_labela == kSigmaLabel) {
        // Found a sigma? Match it against all (non-special) symbols
        // on side b.
        for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
             !aiterb.Done();
             aiterb.Next()) {
          const A &arcb = aiterb.Value();
          Label labelb = find_input ? arcb.olabel : arcb.ilabel;
          if (labelb <= 0) continue;
          AddArc(s, arca, arcb, 0, find_input);
        }
      } else if (f == 0 && match_labela == 0) {
        A earcb(0, 0, Weight::One(), sb);
        AddArc(s, arca, earcb, 0, find_input);  // move forward on epsilon
      }
    }
    // Next handle non-epsilon matches, rho labels, and epsilons on FSTB
    for (ArcIterator< Fst<A> > aiterb(*fstb, sb);
         !aiterb.Done();
         aiterb.Next()) {
      const A &arcb = aiterb.Value();
      Label match_labelb = find_input ? arcb.olabel : arcb.ilabel;
      if (match_labelb) {  // Consider non-epsilon match
        if (FindLabel(&aitera, numarcsa, match_labelb, find_input)) {
          for (; !aitera.Done(); aitera.Next()) {
            const A &arca = aitera.Value();
            Label match_labela = find_input ? arca.ilabel : arca.olabel;
            if (match_labela != match_labelb)
              break;
            AddArc(s, arca, arcb, 0, find_input);  // move forward on match
          }
        } else if ((T & COMPOSE_SPECIAL_SYMBOLS) != 0) {
          // If there is no transition labelled 'match_labelb' in
          // fsta, try matching 'match_labelb' against special symbols
          // (Phi, Rho,...).
          for (aitera.Reset(); !aitera.Done(); aitera.Next()) {
            A arca = aitera.Value();
            Label labela = find_input ? arca.ilabel : arca.olabel;
            if (labela >= 0) {
              break;
            } else if (((T & COMPOSE_PHI) != 0) && (labela == kPhiLabel)) {
              // Case 1: if a failure transition exists, follow its
              // transitive closure until a) a transition labelled
              // 'match_labelb' is found, or b) the initial state of
              // fsta is reached.
              StateId sf = sa;  // Start of current failure transition.
              // Set once sub-case 1a has emitted an arc. Without it, the
              // 1b test below still held after a successful match (arca is
              // then the phi arc whose nextstate == sf), and a spurious
              // self-loop was added as well.
              bool matched = false;
              while (labela == kPhiLabel && sf != arca.nextstate) {
                sf = arca.nextstate;
                size_t numarcsf = fsta->NumArcs(sf);
                ArcIterator< Fst<A> > aiterf(*fsta, sf);
                if (FindLabel(&aiterf, numarcsf, match_labelb, find_input)) {
                  // Sub-case 1a: there exists a transition starting
                  // in sf and consuming symbol 'match_labelb'.
                  // NOTE(review): only the first matching arc is used and
                  // the phi arc weights are not multiplied in; kept as-is.
                  AddArc(s, aiterf.Value(), arcb, 0, find_input);
                  matched = true;
                  break;
                } else {
                  // No transition labelled 'match_labelb' found: try
                  // next failure transition (starting at 'sf').
                  for (aiterf.Reset(); !aiterf.Done(); aiterf.Next()) {
                    arca = aiterf.Value();
                    labela = find_input ? arca.ilabel : arca.olabel;
                    if (labela >= kPhiLabel) break;
                  }
                }
              }
              if (!matched && labela == kPhiLabel && sf == arca.nextstate) {
                // Sub-case 1b: failure transitions lead to start
                // state without finding a matching
                // transition. Therefore, we generate a loop in start
                // state of fsta.
                A loop(match_labelb, match_labelb, Weight::One(), sf);
                AddArc(s, loop, arcb, 0, find_input);
              }
            } else if (((T & COMPOSE_RHO) != 0) && (labela == kRhoLabel)) {
              // Case 2: 'match_labelb' can be matched against a
              // "rest" (rho) label in fsta.
              if (find_input) {
                arca.ilabel = match_labelb;
                if (arca.olabel == kRhoLabel)
                  arca.olabel = match_labelb;
              } else {
                arca.olabel = match_labelb;
                if (arca.ilabel == kRhoLabel)
                  arca.ilabel = match_labelb;
              }
              AddArc(s, arca, arcb, 0, find_input);  // move fwd on match
            }
          }
        }
      } else if (numepsa != numarcsa || finala) {  // Handle FSTB epsilon
        A earca(0, 0, Weight::One(), sa);
        AddArc(s, earca, arcb, numepsa > 0, find_input);  // move on epsilon
      }
    }
    // 'this->' is required: SetArcs lives in a dependent base class and is
    // not found by unqualified lookup in a template.
    this->SetArcs(s);
  }

  // Binary search for MATCH_LABEL among the first NUMARCS arcs of *AITER,
  // looking at input labels if FIND_INPUT, else output labels (these arcs
  // are assumed to be sorted on that side). On success, leaves *AITER at
  // the first matching arc and returns true; otherwise returns false.
  bool FindLabel(ArcIterator< Fst<A> > *aiter, size_t numarcs,
                 Label match_label, bool find_input) {
    size_t low = 0;
    size_t high = numarcs;
    while (low < high) {
      size_t mid = (low + high) / 2;
      aiter->Seek(mid);
      Label label = find_input ?
          aiter->Value().ilabel : aiter->Value().olabel;
      if (label > match_label) {
        high = mid;
      } else if (label < match_label) {
        low = mid + 1;
      } else {
        // find first matching label (when non-determinism)
        for (size_t i = mid; i > low; --i) {
          aiter->Seek(i - 1);
          label = find_input ? aiter->Value().ilabel : aiter->Value().olabel;
          if (label != match_label) {
            aiter->Seek(i);
            return true;
          }
        }
        return true;
      }
    }
    return false;
  }

  // The composed start state is the tuple of both start states with
  // filter state 0, or kNoStateId if either input has no start state.
  StateId ComputeStart() {
    StateId s1 = ComposeFstImplBase<A>::fst1_->Start();
    StateId s2 = ComposeFstImplBase<A>::fst2_->Start();
    if (s1 == kNoStateId || s2 == kNoStateId)
      return kNoStateId;
    StateTuple tuple(s1, s2, 0);
    return FindState(tuple);
  }

  // The final weight is the product of the components' final weights.
  Weight ComputeFinal(StateId s) {
    const StateTuple &tuple = state_tuples_[s];
    Weight final = Times(ComposeFstImplBase<A>::fst1_->Final(tuple.state_id1),
                         ComposeFstImplBase<A>::fst2_->Final(tuple.state_id2));
    return final;
  }

  FindType find_type_;  // find label on which side?

  // Maps from StateId to StateTuple.
  vector<StateTuple> state_tuples_;

  // Maps from StateTuple to StateId.
  StateTupleTable state_tuple_table_;

  DISALLOW_EVIL_CONSTRUCTORS(ComposeFstImpl);
};
// Computes the composition of two transducers. This version is a
// delayed Fst. If FST1 transduces string x to y with weight a and FST2
// transduces y to z with weight b, then their composition transduces
// string x to z with weight Times(a, b).
//
// The output labels of the first transducer or the input labels of
// the second transducer must be sorted. The weights need to form a
// commutative semiring (valid for TropicalWeight and LogWeight).
//
// Complexity:
// Assuming the first FST is unsorted and the second is sorted:
// - Time: O(v1 v2 d1 (log d2 + m2)),
// - Space: O(v1 v2)
// where vi = # of states visited, di = maximum out-degree, and mi the
// maximum multiplicity of the states visited for the ith
// FST. Constant time and space to visit an input state or arc is
// assumed and exclusive of caching.
//
// Caveats:
// - ComposeFst does not trim its output (since it is a delayed operation).
// - The efficiency of composition can be strongly affected by several factors:
// - the choice of which transducer is sorted - prefer sorting the FST
// that has the greater average out-degree.
// - the amount of non-determinism
// - the presence and location of epsilon transitions - avoid epsilon
// transitions on the output side of the first transducer or
// the input side of the second transducer or prefer placing
// them later in a path since they delay matching and can
// introduce non-coaccessible states and transitions.
template <class A>
class ComposeFst : public Fst<A> {
 public:
  friend class ArcIterator< ComposeFst<A> >;
  friend class CacheStateIterator< ComposeFst<A> >;
  friend class CacheArcIterator< ComposeFst<A> >;

  typedef A Arc;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<A> State;

  // Plain composition: no special symbols, default cache options.
  ComposeFst(const Fst<A> &fst1, const Fst<A> &fst2)
      : impl_(Init(fst1, fst2, ComposeFstOptions<>())) { }

  // Composition with the extensions selected by the ComposeTypes bits in T.
  template <uint64 T>
  ComposeFst(const Fst<A> &fst1,
             const Fst<A> &fst2,
             const ComposeFstOptions<T> &opts)
      : impl_(Init(fst1, fst2, opts)) { }

  // Shallow copy: shares the implementation and bumps its reference count.
  ComposeFst(const ComposeFst<A> &fst) : Fst<A>(fst), impl_(fst.impl_) {
    impl_->IncrRefCount();
  }

  virtual ~ComposeFst() { if (!impl_->DecrRefCount()) delete impl_; }

  virtual StateId Start() const { return impl_->Start(); }

  virtual Weight Final(StateId s) const { return impl_->Final(s); }

  virtual size_t NumArcs(StateId s) const { return impl_->NumArcs(s); }

  virtual size_t NumInputEpsilons(StateId s) const {
    return impl_->NumInputEpsilons(s);
  }

  virtual size_t NumOutputEpsilons(StateId s) const {
    return impl_->NumOutputEpsilons(s);
  }

  // With test == true, computes the properties in 'mask' via
  // TestProperties and records them in the implementation; otherwise
  // returns the stored properties.
  virtual uint64 Properties(uint64 mask, bool test) const {
    if (test) {
      // Named 'testprops' so it does not shadow the 'test' parameter.
      uint64 known;
      uint64 testprops = TestProperties(*this, mask, &known);
      impl_->SetProperties(testprops, known);
      return testprops & mask;
    } else {
      return impl_->Properties(mask);
    }
  }

  virtual const string& Type() const { return impl_->Type(); }

  virtual ComposeFst<A> *Copy() const {
    return new ComposeFst<A>(*this);
  }

  virtual const SymbolTable* InputSymbols() const {
    return impl_->InputSymbols();
  }

  virtual const SymbolTable* OutputSymbols() const {
    return impl_->OutputSymbols();
  }

  virtual inline void InitStateIterator(StateIteratorData<A> *data) const;

  virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
    impl_->InitArcIterator(s, data);
  }

  // Access to flags encoding compose options/optimizations etc. (for
  // debugging).
  uint64 ComposeFlags() const { return impl_->ComposeFlags(); }

 protected:
  ComposeFstImplBase<A> *Impl() { return impl_; }

 private:
  ComposeFstImplBase<A> *impl_;

  // Auxiliary method encapsulating the creation of a ComposeFst
  // implementation that is appropriate for the properties of fst1 and
  // fst2.
  template <uint64 T>
  static ComposeFstImplBase<A> *Init(
      const Fst<A> &fst1,
      const Fst<A> &fst2,
      const ComposeFstOptions<T> &opts) {
    // Filter for sort properties (forces a property check).
    uint64 sort_props_mask = kILabelSorted | kOLabelSorted;
    // Filter for optimization-related properties (does not force a
    // property-check).
    uint64 opt_props_mask =
        kString | kIDeterministic | kODeterministic | kNoIEpsilons |
        kNoOEpsilons;

    uint64 props1 = fst1.Properties(sort_props_mask, true);
    uint64 props2 = fst2.Properties(sort_props_mask, true);

    props1 |= fst1.Properties(opt_props_mask, false);
    props2 |= fst2.Properties(opt_props_mask, false);

    if (!(Weight::Properties() & kCommutative)) {
      props1 |= fst1.Properties(kUnweighted, true);
      props2 |= fst2.Properties(kUnweighted, true);
      if (!(props1 & kUnweighted) && !(props2 & kUnweighted))
        LOG(FATAL) << "ComposeFst: Weight needs to be a commutative semiring: "
                   << Weight::Type();
    }

    // Case 1: flag COMPOSE_GENERIC disables optimizations.
    if (T & COMPOSE_GENERIC) {
      return new ComposeFstImpl<A, T>(fst1, fst2, opts);
    }

    const uint64 kStringDetOptProps =
        kIDeterministic | kILabelSorted | kNoIEpsilons;
    const uint64 kDetStringOptProps =
        kODeterministic | kOLabelSorted | kNoOEpsilons;

    // The string/deterministic state tables key composed states on the
    // string FST's state alone. That is only valid if each string-side
    // state pairs with a single state of the other FST. Special symbols
    // break this: a rho or sigma arc on the string side can match several
    // arcs of the other FST, and a sigma arc on the deterministic side
    // matches in addition to an explicit label match. So both optimized
    // cases are restricted to compositions without special symbols. The
    // previous check for Case 3 only looked at the fst1 bits.

    // Case 2: fst1 is a string, fst2 is deterministic and epsilon-free.
    if ((props1 & kString) &&
        !(T & COMPOSE_SPECIAL_SYMBOLS) &&
        ((props2 & kStringDetOptProps) == kStringDetOptProps)) {
      return new ComposeFstImpl<A, T | COMPOSE_FST1_STRING | COMPOSE_FST2_DET>(
          fst1, fst2, opts);
    }
    // Case 3: fst1 is deterministic and epsilon-free, fst2 is string.
    if ((props2 & kString) &&
        !(T & COMPOSE_SPECIAL_SYMBOLS) &&
        ((props1 & kDetStringOptProps) == kDetStringOptProps)) {
      return new ComposeFstImpl<A, T | COMPOSE_FST2_STRING | COMPOSE_FST1_DET>(
          fst1, fst2, opts);
    }

    // Default case: no optimizations.
    return new ComposeFstImpl<A, T>(fst1, fst2, opts);
  }

  void operator=(const ComposeFst<A> &fst);  // disallow
};
// Specialization for ComposeFst.
template <class A>
class StateIterator< ComposeFst<A> >
    : public CacheStateIterator< ComposeFst<A> > {
 private:
  typedef CacheStateIterator< ComposeFst<A> > Base;

 public:
  // State enumeration is delegated entirely to the cache-backed iterator.
  explicit StateIterator(const ComposeFst<A> &fst) : Base(fst) {}
};
// Specialization for ComposeFst.
template <class A>
class ArcIterator< ComposeFst<A> >
    : public CacheArcIterator< ComposeFst<A> > {
 public:
  typedef typename A::StateId StateId;

  // Once the cache-backed base iterator is constructed, makes sure the
  // arcs of s have been computed (Expand) if they were not cached yet.
  ArcIterator(const ComposeFst<A> &fst, StateId s)
      : CacheArcIterator< ComposeFst<A> >(fst, s) {
    ComposeFstImplBase<A> *impl = fst.impl_;
    if (impl->HasArcs(s))
      return;
    impl->Expand(s);
  }

 private:
  DISALLOW_EVIL_CONSTRUCTORS(ArcIterator);
};
template <class A>
inline void ComposeFst<A>::InitStateIterator(
    StateIteratorData<A> *data) const {
  // Hands out a heap-allocated, cache-backed state iterator through 'data'
  // (ownership presumably passes to the generic StateIterator - verify).
  typedef StateIterator< ComposeFst<A> > Iter;
  data->base = new Iter(*this);
}
// Useful alias when using StdArc: the delayed composition ComposeFst
// instantiated for StdArc.
typedef ComposeFst<StdArc> StdComposeFst;
// Options for the non-delayed Compose() function.
struct ComposeOptions {
  // When true (the default), Compose() calls Connect() on its output.
  bool connect;

  // One non-explicit constructor with a default argument. It covers the
  // default constructor and the implicit conversion from bool.
  ComposeOptions(bool c = true) : connect(c) {}
};
// Computes the composition of two transducers. This version writes
// the composed FST into a MutableFst. If FST1 transduces string x to
// y with weight a and FST2 transduces y to z with weight b, then
// their composition transduces string x to z with weight
// Times(a, b).
//
// The output labels of the first transducer or the input labels of
// the second transducer must be sorted. The weights need to form a
// commutative semiring (valid for TropicalWeight and LogWeight).
//
// Complexity:
// Assuming the first FST is unsorted and the second is sorted:
// - Time: O(V1 V2 D1 (log D2 + M2)),
// - Space: O(V1 V2 D1 M2)
// where Vi = # of states, Di = maximum out-degree, and Mi is
// the maximum multiplicity for the ith FST.
//
// Caveats:
// - Compose trims its output.
// - The efficiency of composition can be strongly affected by several factors:
// - the choice of which transducer is sorted - prefer sorting the FST
// that has the greater average out-degree.
// - the amount of non-determinism
// - the presence and location of epsilon transitions - avoid epsilon
// transitions on the output side of the first transducer or
// the input side of the second transducer or prefer placing
// them later in a path since they delay matching and can
// introduce non-coaccessible states and transitions.
template<class Arc>
void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
MutableFst<Arc> *ofst,
const ComposeOptions &opts = ComposeOptions()) {
ComposeFstOptions<> nopts;
nopts.gc_limit = 0; // Cache only the last state for fastest copy.
*ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
if (opts.connect)
Connect(ofst);
}
} // namespace fst
#endif // FST_LIB_COMPOSE_H__