blob: 3fbe3ba4e22bd39dabab8976e1bb0d9ab3c42388 [file] [log] [blame]
// minimize.h
// minimize.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file Functions and classes to minimize a finite state acceptor
//
#ifndef FST_LIB_MINIMIZE_H__
#define FST_LIB_MINIMIZE_H__
#include <cmath>
#include <algorithm>
#include <map>
#include <queue>
#include <vector>
using std::vector;
#include <fst/arcsort.h>
#include <fst/connect.h>
#include <fst/dfs-visit.h>
#include <fst/encode.h>
#include <fst/factor-weight.h>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/partition.h>
#include <fst/push.h>
#include <fst/queue.h>
#include <fst/reverse.h>
#include <fst/state-map.h>
namespace fst {
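// Comparator for creating a partition by sorting states on, in order:
// - final weight (compared by its hash value),
// - out-degree,
// - the sequence of (input label, destination block) pairs of the arcs.
// Which of these criteria apply is selected by the 'flags' argument.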
template <class A>
class StateComparator {
public:
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
static const uint32 kCompareFinal = 0x00000001;
static const uint32 kCompareOutDegree = 0x00000002;
static const uint32 kCompareArcs = 0x00000004;
static const uint32 kCompareAll = 0x00000007;
StateComparator(const Fst<A>& fst,
const Partition<typename A::StateId>& partition,
uint32 flags = kCompareAll)
: fst_(fst), partition_(partition), flags_(flags) {}
// compare state x with state y based on sort criteria
bool operator()(const StateId x, const StateId y) const {
// check for final state equivalence
if (flags_ & kCompareFinal) {
const size_t xfinal = fst_.Final(x).Hash();
const size_t yfinal = fst_.Final(y).Hash();
if (xfinal < yfinal) return true;
else if (xfinal > yfinal) return false;
}
if (flags_ & kCompareOutDegree) {
// check for # arcs
if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;
if (flags_ & kCompareArcs) {
// # arcs are equal, check for arc match
for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
!aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
const A& arc1 = aiter1.Value();
const A& arc2 = aiter2.Value();
if (arc1.ilabel < arc2.ilabel) return true;
if (arc1.ilabel > arc2.ilabel) return false;
if (partition_.class_id(arc1.nextstate) <
partition_.class_id(arc2.nextstate)) return true;
if (partition_.class_id(arc1.nextstate) >
partition_.class_id(arc2.nextstate)) return false;
}
}
}
return false;
}
private:
const Fst<A>& fst_;
const Partition<typename A::StateId>& partition_;
const uint32 flags_;
};
template <class A> const uint32 StateComparator<A>::kCompareFinal;
template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
template <class A> const uint32 StateComparator<A>::kCompareArcs;
template <class A> const uint32 StateComparator<A>::kCompareAll;
// Computes equivalence classes for cyclic Fsts. For cyclic minimization
// we use the classic Hopcroft minimization algorithm, whose complexity is
//
// O(E log N),
//
// where E is the number of edges in the machine and N is the number of states.
//
// The following paper describes the original algorithm:
// An N Log N algorithm for minimizing states in a finite automaton
// by John Hopcroft, January 1971
//
template <class A, class Queue>
class CyclicMinimizer {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::StateId ClassId;
  typedef typename A::Weight Weight;
  typedef ReverseArc<A> RevA;

  // Computes the equivalence classes of 'fst' on construction. 'fst' must
  // have at least one state: PrePartition() reads the first state without
  // checking (AcceptorMinimize() returns early on an empty FST).
  CyclicMinimizer(const ExpandedFst<A>& fst) {
    Initialize(fst);
    Compute(fst);
  }

  ~CyclicMinimizer() {
    // Split() drains the queue before returning, so it is normally empty
    // here; still release any iterators it holds rather than leak them.
    while (!aiter_queue_->empty()) {
      delete aiter_queue_->top();
      aiter_queue_->pop();
    }
    delete aiter_queue_;
  }

  // The computed partition; MergeStates() merges states sharing a class.
  const Partition<StateId>& partition() const {
    return P_;
  }

 private:
  // This object owns aiter_queue_ through a raw pointer, so a copy would
  // delete it twice. Copy construction and assignment are therefore declared
  // private and never defined.
  CyclicMinimizer(const CyclicMinimizer&);
  CyclicMinimizer& operator=(const CyclicMinimizer&);

  typedef ArcIterator<Fst<RevA> > ArcIter;

  // Heap ordering for open arc iterators. Only the input label of each
  // iterator's current arc is compared; 'partition_' is stored but unused.
  class ArcIterCompare {
   public:
    ArcIterCompare(const Partition<StateId>& partition)
        : partition_(partition) {}

    ArcIterCompare(const ArcIterCompare& comp)
        : partition_(comp.partition_) {}

    // std::priority_queue keeps its "largest" element on top; ranking larger
    // input labels as "smaller" makes the top the iterator whose current arc
    // has the smallest input label (a min-heap on ilabel).
    bool operator()(const ArcIter* x, const ArcIter* y) const {
      const RevA& xarc = x->Value();
      const RevA& yarc = y->Value();
      return (xarc.ilabel > yarc.ilabel);
    }

   private:
    const Partition<StateId>& partition_;
  };

  typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
      ArcIterQueue;

  // Builds the initial partition, grouping states only by the hash of their
  // final weight (StateComparator::kCompareFinal). Out-degree and arcs are
  // not compared here; refining on arcs is left to Split(). Every initial
  // class is enqueued on L_ as a splitter. Assumes 'fst' has a state.
  void PrePartition(const Fst<A>& fst) {
    VLOG(5) << "PrePartition";
    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
    EquivalenceMap equiv_map(comp);

    StateIterator<Fst<A> > siter(fst);
    StateId class_id = P_.AddClass();
    P_.Add(siter.Value(), class_id);
    equiv_map[siter.Value()] = class_id;
    L_.Enqueue(class_id);
    for (siter.Next(); !siter.Done(); siter.Next()) {
      const StateId s = siter.Value();
      typename EquivalenceMap::const_iterator it = equiv_map.find(s);
      if (it == equiv_map.end()) {
        // First state seen with this final weight hash: open a new class.
        class_id = P_.AddClass();
        P_.Add(s, class_id);
        equiv_map[s] = class_id;
        L_.Enqueue(class_id);
      } else {
        // An equivalent key already maps to the class; the map needs no
        // update.
        P_.Add(s, it->second);
      }
    }
    VLOG(5) << "Initial Partition: " << P_.num_classes();
  }

  // Builds the reversed transition function Tr_ (arcs sorted by input
  // label), the initial partition, and the arc iterator queue. Tr_ has one
  // more state than 'fst'; state s of 'fst' is state s + 1 of Tr_, which is
  // why Split() adds and subtracts 1.
  void Initialize(const Fst<A>& fst) {
    Reverse(fst, &Tr_);
    ILabelCompare<RevA> ilabel_comp;
    ArcSort(&Tr_, ilabel_comp);
    P_.Initialize(Tr_.NumStates() - 1);
    PrePartition(fst);
    ArcIterCompare comp(P_);
    aiter_queue_ = new ArcIterQueue(comp);
  }

  // Refines every class that has a transition into splitter class C. The
  // reversed arcs of C's states are consumed in increasing input-label order
  // via the iterator heap. Each time the label changes, the pending split is
  // committed with FinalizeSplit(&L_); presumably this enqueues the new
  // classes on L_ (see Partition).
  void Split(ClassId C) {
    // Open an iterator on Tr_ for each state in C that has predecessors.
    for (PartitionIterator<StateId> siter(P_, C);
         !siter.Done(); siter.Next()) {
      const StateId s = siter.Value();
      if (Tr_.NumArcs(s + 1))
        aiter_queue_->push(new ArcIter(Tr_, s + 1));
    }
    Label prev_label = -1;
    while (!aiter_queue_->empty()) {
      ArcIter* aiter = aiter_queue_->top();
      aiter_queue_->pop();
      if (aiter->Done()) {  // defensive: only live iterators are pushed
        delete aiter;
        continue;
      }
      const RevA& arc = aiter->Value();
      const StateId from_state = arc.nextstate - 1;
      const Label from_label = arc.ilabel;
      if (prev_label != from_label)
        P_.FinalizeSplit(&L_);
      const StateId from_class = P_.class_id(from_state);
      if (P_.class_size(from_class) > 1)
        P_.SplitOn(from_state);
      prev_label = from_label;
      aiter->Next();
      if (aiter->Done())
        delete aiter;
      else
        aiter_queue_->push(aiter);
    }
    P_.FinalizeSplit(&L_);
  }

  // Hopcroft main loop: takes splitter classes off L_ until none remain.
  // 'fst' is not used; the reversed copy Tr_ holds what Split() needs.
  void Compute(const Fst<A>& fst) {
    while (!L_.Empty()) {
      const ClassId C = L_.Head();
      L_.Dequeue();
      Split(C);
    }
  }

  // Partitioning of states into equivalence classes.
  Partition<StateId> P_;
  // Active splitter classes still to be processed; the order depends on the
  // Queue template argument.
  Queue L_;
  // Reverse transition function (see Initialize() for the state offset).
  VectorFst<RevA> Tr_;
  // Priority queue of open arc iterators for the states of the current
  // splitter class. Owned; deleted in the destructor.
  ArcIterQueue* aiter_queue_;
};
// Computes equivalence classes for acyclic Fsts. The implementation details
// for this algorithm are documented in the following paper:
//
// Minimization of acyclic deterministic automata in linear time
// Dominique Revuz
//
// Complexity O(|E|)
//
template <class A>
class AcyclicMinimizer {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::StateId ClassId;
  typedef typename A::Weight Weight;

  // Computes the equivalence classes of 'fst' on construction. Refine()
  // compares the arcs of two states position by position, which assumes each
  // state's arcs are sorted by input label (AcceptorMinimize() ArcSorts
  // before constructing this class).
  AcyclicMinimizer(const ExpandedFst<A>& fst) {
    Initialize(fst);
    Refine(fst);
  }

  // The computed partition. It is const-qualified, like
  // CyclicMinimizer::partition(), so it can be called on a const minimizer.
  const Partition<StateId>& partition() const {
    return partition_;
  }

 private:
  // DFS visitor computing each state's height, defined as the length of the
  // longest path from the state to a state with no outgoing arcs (height 0).
  // Back arcs are ignored, so heights are only meaningful for acyclic FSTs.
  class HeightVisitor {
   public:
    HeightVisitor() : max_height_(0), num_states_(0) {}

    // Invoked before the DFS visit.
    void InitVisit(const Fst<A>& fst) {}

    // Invoked when 's' is discovered (2nd arg is the DFS tree root). Grows
    // the height table to cover 's'; new entries are -1 ("not computed").
    bool InitState(StateId s, StateId root) {
      if (height_.size() <= static_cast<size_t>(s))
        height_.resize(s + 1, -1);
      if (static_cast<size_t>(s) >= num_states_) num_states_ = s + 1;
      return true;
    }

    // Invoked when a tree arc (to an undiscovered state) is examined. The
    // child's height reaches 's' later, in FinishState().
    bool TreeArc(StateId s, const A& arc) {
      return true;
    }

    // Invoked when a back arc (to an unfinished state) is examined; ignored.
    bool BackArc(StateId s, const A& arc) {
      return true;
    }

    // Invoked when a forward or cross arc (to a finished state) is examined.
    // The target's height is final, so it is folded into the height of 's'.
    bool ForwardOrCrossArc(StateId s, const A& arc) {
      if (height_[arc.nextstate] + 1 > height_[s])
        height_[s] = height_[arc.nextstate] + 1;
      return true;
    }

    // Invoked when 's' is finished (parent is kNoStateId for a tree root).
    // The height of 's' is now final. It is recorded in max_height_ and
    // propagated to the DFS parent.
    void FinishState(StateId s, StateId parent, const A* parent_arc) {
      if (height_[s] == -1) height_[s] = 0;  // no outgoing arc raised it
      // Every finished state counts toward max_height_, including DFS tree
      // roots, which have no parent to pass their height through.
      if (static_cast<size_t>(height_[s]) > max_height_)
        max_height_ = static_cast<size_t>(height_[s]);
      if (parent >= 0) {
        const StateId h = height_[s] + 1;
        if (h > height_[parent]) height_[parent] = h;
      }
    }

    // Invoked after the DFS visit.
    void FinishVisit() {}

    size_t max_height() const { return max_height_; }
    const vector<StateId>& height() const { return height_; }
    size_t num_states() const { return num_states_; }

   private:
    vector<StateId> height_;  // per-state height; -1 until computed
    size_t max_height_;       // largest height of any finished state
    size_t num_states_;       // 1 + largest discovered state id
  };

  // Creates the initial partition with one class per height: state s is
  // added to class height[s].
  void Initialize(const Fst<A>& fst) {
    HeightVisitor hvisitor;
    DfsVisit(fst, &hvisitor);
    partition_.Initialize(hvisitor.num_states());
    partition_.AllocateClasses(hvisitor.max_height() + 1);
    const vector<StateId>& hstates = hvisitor.height();
    for (size_t s = 0; s < hstates.size(); ++s)
      partition_.Add(s, hstates[s]);
  }

  // Splits each height class into sub-classes of states that StateComparator
  // (kCompareAll) considers equivalent. Heights are processed in increasing
  // order. Every arc leads to a strictly lower height, so destination classes
  // are already refined when a class is examined.
  void Refine(const Fst<A>& fst) {
    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    StateComparator<A> comp(fst, partition_);
    // Only the original height classes are refined. Classes appended by
    // AddClass() below are results of refinement and are not revisited.
    const size_t num_heights = partition_.num_classes();
    for (size_t h = 0; h < num_heights; ++h) {
      EquivalenceMap equiv_classes(comp);
      // The first state keeps class h. Each later state either matches an
      // earlier one, and so shares its class, or gets a fresh class.
      PartitionIterator<StateId> siter(partition_, h);
      equiv_classes[siter.Value()] = h;
      for (siter.Next(); !siter.Done(); siter.Next()) {
        const StateId s = siter.Value();
        if (equiv_classes.find(s) == equiv_classes.end())
          equiv_classes[s] = partition_.AddClass();
      }
      // Move states into their new classes. Move() can invalidate the
      // iterator, so the iterator is advanced before the current state is
      // moved out of the list.
      for (siter.Reset(); !siter.Done();) {
        const StateId s = siter.Value();
        const StateId old_class = partition_.class_id(s);
        const StateId new_class = equiv_classes[s];
        siter.Next();
        if (old_class != new_class)
          partition_.Move(s, new_class);
      }
    }
  }

  Partition<StateId> partition_;
};
// Given a partition and a mutable fst, merge states of the Fst in place
// (i.e. destructively). Merging works by taking the first state in
// a class of the partition to be the representative state for the class.
// Each arc is then reconnected to this state. All states in the class
// are merged by adding their arcs to the representative state.
template <class A>
void MergeStates(
const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
typedef typename A::StateId StateId;
vector<StateId> state_map(partition.num_classes());
for (size_t i = 0; i < partition.num_classes(); ++i) {
PartitionIterator<StateId> siter(partition, i);
state_map[i] = siter.Value(); // first state in partition;
}
// relabel destination states
for (size_t c = 0; c < partition.num_classes(); ++c) {
for (PartitionIterator<StateId> siter(partition, c);
!siter.Done(); siter.Next()) {
StateId s = siter.Value();
for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
!aiter.Done(); aiter.Next()) {
A arc = aiter.Value();
arc.nextstate = state_map[partition.class_id(arc.nextstate)];
if (s == state_map[c]) // first state just set destination
aiter.SetValue(arc);
else
fst->AddArc(state_map[c], arc);
}
}
}
fst->SetStart(state_map[partition.class_id(fst->Start())]);
Connect(fst);
}
template <class A>
void AcceptorMinimize(MutableFst<A>* fst) {
  typedef typename A::StateId StateId;
  // Both properties are required. Properties() reports which of the
  // requested bits hold, so the result must contain the whole mask. Applying
  // '!' to the result would only reject FSTs that are neither acceptors nor
  // unweighted.
  const uint64 required = kAcceptor | kUnweighted;
  if ((fst->Properties(required, true) & required) != required) {
    FSTERROR() << "FST is not an unweighted acceptor";
    fst->SetProperties(kError, kError);
    return;
  }
  // Connect first so disconnected states take no part in minimization.
  Connect(fst);
  if (fst->NumStates() == 0) return;

  if (fst->Properties(kAcyclic, true)) {
    // Acyclic minimization (Revuz). AcyclicMinimizer compares arcs position
    // by position, so the arcs are first sorted by input label.
    VLOG(2) << "Acyclic Minimization";
    ArcSort(fst, ILabelCompare<A>());
    AcyclicMinimizer<A> minimizer(*fst);
    MergeStates(minimizer.partition(), fst);
  } else {
    // Cyclic minimization (Hopcroft).
    VLOG(2) << "Cyclic Minimization";
    CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
    MergeStates(minimizer.partition(), fst);
  }
  // Collapse the duplicate arcs that MergeStates() can leave on
  // representative states.
  ArcUniqueMapper<A> mapper(*fst);
  StateMap(fst, mapper);
}
// In-place minimization of deterministic weighted automata and transducers.
// For transducers, if the 'sfst' argument is not null, the algorithm
// produces a compact factorization of the minimal transducer.
//
// In the acyclic case, we use an algorithm from Dominique Revuz that
// is linear in the number of arcs (edges) in the machine.
// Complexity = O(E)
//
// In the cyclic case, we use the classical Hopcroft minimization.
// Complexity = O(|E| log |N|)
//
template <class A>
void Minimize(MutableFst<A>* fst,
              MutableFst<A>* sfst = 0,
              float delta = kDelta) {
  typedef GallicArc<A, STRING_LEFT> GArc;
  typedef GallicFactor<typename A::Label, typename A::Weight, STRING_LEFT>
      GFactor;

  const uint64 props = fst->Properties(
      kAcceptor | kIDeterministic | kWeighted | kUnweighted, true);
  // Every branch below requires input determinism.
  if (!(props & kIDeterministic)) {
    FSTERROR() << "FST is not deterministic";
    fst->SetProperties(kError, kError);
    return;
  }

  if (props & kAcceptor) {
    if (!(props & kWeighted)) {
      // Unweighted acceptor: minimize directly.
      AcceptorMinimize(fst);
      return;
    }
    // Weighted acceptor: push and quantize the weights, then encode each
    // (label, weight) pair as a single label, so that an unweighted acceptor
    // minimization applies.
    Push(fst, REWEIGHT_TO_INITIAL, delta);
    ArcMap(fst, QuantizeMapper<A>(delta));
    EncodeMapper<A> codec(kEncodeLabels | kEncodeWeights, ENCODE);
    Encode(fst, &codec);
    AcceptorMinimize(fst);
    Decode(fst, codec);
    return;
  }

  // Transducer: move the output strings into Gallic weights, so the machine
  // becomes a weighted acceptor, then minimize that via the encoded path.
  VectorFst<GArc> gallic;
  ArcMap(*fst, &gallic, ToGallicMapper<A, STRING_LEFT>());
  fst->DeleteStates();
  gallic.SetProperties(kAcceptor, kAcceptor);
  Push(&gallic, REWEIGHT_TO_INITIAL, delta);
  ArcMap(&gallic, QuantizeMapper<GArc>(delta));
  EncodeMapper<GArc> codec(kEncodeLabels | kEncodeWeights, ENCODE);
  Encode(&gallic, &codec);
  AcceptorMinimize(&gallic);
  Decode(&gallic, codec);

  if (sfst != 0) {
    // Compact factorization: output strings become new symbols recorded in
    // 'sfst'.
    sfst->SetOutputSymbols(fst->OutputSymbols());
    GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
    ArcMap(gallic, fst, &mapper);
    fst->SetOutputSymbols(sfst->InputSymbols());
    return;
  }

  // Factor the Gallic weights back into output labels. The output symbol
  // table is saved across the ArcMap into 'fst' and restored afterwards.
  FactorWeightFst<GArc, GFactor> factored(gallic);
  SymbolTable* saved_osyms = 0;
  if (fst->OutputSymbols() != 0)
    saved_osyms = fst->OutputSymbols()->Copy();
  ArcMap(factored, fst, FromGallicMapper<A, STRING_LEFT>());
  fst->SetOutputSymbols(saved_osyms);
  delete saved_osyms;
}
} // namespace fst
#endif // FST_LIB_MINIMIZE_H__