blob: 82ddffa7d064c078f6ba2a944dd12e6db7ed85f8 [file] [log] [blame]
// randgen.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Classes and functions to generate random paths through an FST.
#ifndef FST_LIB_RANDGEN_H__
#define FST_LIB_RANDGEN_H__
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <map>
#include <fst/accumulator.h>
#include <fst/cache.h>
#include <fst/dfs-visit.h>
#include <fst/mutable-fst.h>
namespace fst {
//
// ARC SELECTORS - these function objects are used to select a random
// transition to take from an FST's state. They should return a number
// N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th
// transition is selected. If N == NumArcs(), then the final weight at
// that state is selected (i.e., the 'super-final' transition is selected).
// It can be assumed these will not be called unless either there
// are transitions leaving the state and/or the state is final.
//
// Arc selector that chooses uniformly at random among the transitions
// leaving a state. A state with a non-Zero final weight contributes one
// additional 'super-final' choice, reported as index NumArcs(s).
template <class A>
struct UniformArcSelector {
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  // Seeds the C library generator (srand) with 'seed'; the default seed is
  // the current time.
  UniformArcSelector(int seed = time(0)) { srand(seed); }

  // Returns the index of the selected transition at state 's' of 'fst'.
  size_t operator()(const Fst<A> &fst, StateId s) const {
    // Uniform draw in [0, 1): dividing by RAND_MAX + 1.0 keeps it below 1.
    const double unit = rand() / (RAND_MAX + 1.0);
    size_t num_choices = fst.NumArcs(s);
    const bool is_final = fst.Final(s) != Weight::Zero();
    if (is_final) num_choices += 1;
    // Scale the draw onto the choice indices and truncate.
    return static_cast<size_t>(unit * num_choices);
  }
};
// Randomly selects a transition w.r.t. the weights treated as negative
// log probabilities after normalizing for the total weight leaving
// the state. Weight::zero transitions are disregarded.
// Assumes Weight::Value() accesses the floating point
// representation of the weight.
template <class A>
class LogProbArcSelector {
public:
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
LogProbArcSelector(int seed = time(0)) { srand(seed); }
size_t operator()(const Fst<A> &fst, StateId s) const {
// Find total weight leaving state
double sum = 0.0;
for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
aiter.Next()) {
const A &arc = aiter.Value();
sum += exp(-to_log_weight_(arc.weight).Value());
}
sum += exp(-to_log_weight_(fst.Final(s)).Value());
double r = rand()/(RAND_MAX + 1.0);
double p = 0.0;
int n = 0;
for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done();
aiter.Next(), ++n) {
const A &arc = aiter.Value();
p += exp(-to_log_weight_(arc.weight).Value());
if (p > r * sum) return n;
}
return n;
}
private:
WeightConvert<Weight, Log64Weight> to_log_weight_;
};
// Convenience definitions: LogProbArcSelector instantiated for the StdArc
// and LogArc arc types.
typedef LogProbArcSelector<StdArc> StdArcSelector;
typedef LogProbArcSelector<LogArc> LogArcSelector;
// Same as LogProbArcSelector, but uses a CacheLogAccumulator supplied by the
// caller to obtain the cumulative weight computations instead of iterating
// over the arcs on every selection.
template <class A>
class FastLogProbArcSelector : public LogProbArcSelector<A> {
 public:
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  // Keep the two-argument (accumulator-free) selection available.
  using LogProbArcSelector<A>::operator();

  // Seeds the base selector with 'seed' and remembers it for Seed().
  FastLogProbArcSelector(int seed = time(0))
      : LogProbArcSelector<A>(seed), seed_(seed) {}

  // Selects a transition at state 's' of 'fst' using 'accumulator', which is
  // positioned at 's' by this call.
  size_t operator()(const Fst<A> &fst, StateId s,
                    CacheLogAccumulator<A> *accumulator) const {
    accumulator->SetState(s);
    ArcIterator< Fst<A> > aiter(fst, s);
    // Total weight leaving the state: the final weight plus all arcs.
    const double total_cost =
        to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0,
                                        fst.NumArcs(s))).Value();
    // Negative log of a uniform draw in [0, 1).
    // NOTE(review): rand() may return 0, which makes this +infinity; the
    // result then depends on how LowerBound treats that bound -- confirm.
    const double draw_cost = -log(rand() / (RAND_MAX + 1.0));
    return accumulator->LowerBound(draw_cost + total_cost, &aiter);
  }

  // Returns the seed passed at construction.
  int Seed() const { return seed_; }

 private:
  int seed_;
  WeightConvert<Weight, Log64Weight> to_log_weight_;
};
// Per-state bookkeeping for random path generation; RandGenFst maintains
// these records and hands them to the arc samplers.
template <typename A>
struct RandState {
  typedef typename A::StateId StateId;

  // Current input FST state.
  StateId state_id;
  // Number of samples to be drawn at this state.
  size_t nsamples;
  // Length of the path leading to this random state.
  size_t length;
  // Arc selection made at the previous state on this path.
  size_t select;
  // Previous random state on this path (0 when there is none).
  const RandState<A> *parent;

  // Creates an empty record: no state, no samples, no parent.
  RandState()
      : state_id(kNoStateId),
        nsamples(0),
        length(0),
        select(0),
        parent(0) {}

  // Creates a record for input state 's' with 'n' samples, reached by a path
  // of length 'l' through selection 'k' made at parent 'p'.
  RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p)
      : state_id(s),
        nsamples(n),
        length(l),
        select(k),
        parent(p) {}
};
// This class, given an arc selector, samples, with replacement,
// multiple random transitions from an FST's state. This is a generic
// version with a straight-forward use of the arc selector.
// Specializations may be defined for arc selectors for greater
// efficiency or special behavior.
//
// The FST and the arc selector are held by reference, so both must outlive
// the sampler.
template <class A, class S>
class ArcSampler {
 public:
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  // The 'max_length' may be interpreted (including ignored) by a
  // sampler as it chooses. This generic version interprets this literally.
  ArcSampler(const Fst<A> &fst, const S &arc_selector,
             int max_length = INT_MAX)
      : fst_(fst),
        arc_selector_(arc_selector),
        max_length_(max_length) {
    // Give the sample iterator a valid value so Done() is well-defined (and
    // true) before the first call to Sample().
    Reset();
  }

  // Allow updating Fst argument; pass only if changed. Samples collected by
  // 'sampler' are not copied; the new sampler starts with none.
  ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
      : fst_(fst ? *fst : sampler.fst_),
        arc_selector_(sampler.arc_selector_),
        max_length_(sampler.max_length_) {
    Reset();
  }

  // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is
  // the length of the path to 'rstate'. Returns true if samples were
  // collected. No samples may be collected if either there are no (including
  // 'super-final') transitions leaving that state or if the
  // 'max_length' has been deemed reached. Use the iterator members to
  // read the samples. The samples will be in their original order.
  bool Sample(const RandState<A> &rstate) {
    sample_map_.clear();
    // The cast spells out the conversion the comparison performs anyway.
    if ((fst_.NumArcs(rstate.state_id) == 0 &&
         fst_.Final(rstate.state_id) == Weight::Zero()) ||
        rstate.length == static_cast<size_t>(max_length_)) {
      Reset();
      return false;
    }
    // Tally how many times each transition index is selected.
    for (size_t i = 0; i < rstate.nsamples; ++i)
      ++sample_map_[arc_selector_(fst_, rstate.state_id)];
    Reset();
    return true;
  }

  // More samples?
  bool Done() const { return sample_iter_ == sample_map_.end(); }

  // Gets the next sample.
  void Next() { ++sample_iter_; }

  // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples.
  // If N < NumArcs(s), then the N-th transition is specified.
  // If N == NumArcs(s), then the final weight at that state is
  // specified (i.e., the 'super-final' transition is specified).
  // For the specified transition, K repetitions have been sampled.
  pair<size_t, size_t> Value() const { return *sample_iter_; }

  // Rewinds iteration to the first collected sample.
  void Reset() { sample_iter_ = sample_map_.begin(); }

  // The generic sampler never reports an error.
  bool Error() const { return false; }

 private:
  const Fst<A> &fst_;
  const S &arc_selector_;
  int max_length_;

  // Stores (N, K) as described for Value().
  map<size_t, size_t> sample_map_;
  map<size_t, size_t>::const_iterator sample_iter_;

  // disallow
  ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
};
// Specialization for FastLogProbArcSelector. It owns a CacheLogAccumulator
// that is passed to the selector on every draw. The FST and the arc selector
// are held by reference, so both must outlive the sampler.
template <class A>
class ArcSampler<A, FastLogProbArcSelector<A> > {
 public:
  typedef FastLogProbArcSelector<A> S;
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;
  typedef CacheLogAccumulator<A> C;

  // Creates a sampler over 'fst' with a fresh accumulator initialized on it.
  ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX)
      : fst_(fst),
        arc_selector_(arc_selector),
        max_length_(max_length),
        accumulator_(new C()) {
    accumulator_->Init(fst);
    // Give the sample iterator a valid value so Done() is well-defined (and
    // true) before the first call to Sample().
    Reset();
  }

  // Allow updating Fst argument; pass only if changed. With a new FST a
  // fresh accumulator is initialized on it; otherwise the accumulator of
  // 'sampler' is copied. Collected samples are not copied.
  ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0)
      : fst_(fst ? *fst : sampler.fst_),
        arc_selector_(sampler.arc_selector_),
        max_length_(sampler.max_length_) {
    if (fst) {
      accumulator_ = new C();
      accumulator_->Init(*fst);
    } else {  // shallow copy
      accumulator_ = new C(*sampler.accumulator_);
    }
    // Matches the generic ArcSampler copy constructor: start with a valid,
    // exhausted iterator instead of an uninitialized one.
    Reset();
  }

  ~ArcSampler() {
    delete accumulator_;
  }

  // Samples 'rstate.nsamples' transitions from 'rstate.state_id'. Returns
  // false, with no samples collected, if the state has neither arcs nor a
  // non-Zero final weight or if 'rstate.length' equals the maximum length.
  bool Sample(const RandState<A> &rstate) {
    sample_map_.clear();
    // The cast spells out the conversion the comparison performs anyway.
    if ((fst_.NumArcs(rstate.state_id) == 0 &&
         fst_.Final(rstate.state_id) == Weight::Zero()) ||
        rstate.length == static_cast<size_t>(max_length_)) {
      Reset();
      return false;
    }
    // Tally how many times each transition index is selected.
    for (size_t i = 0; i < rstate.nsamples; ++i)
      ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)];
    Reset();
    return true;
  }

  // True when all collected samples have been read.
  bool Done() const { return sample_iter_ == sample_map_.end(); }

  // Advances to the next sample.
  void Next() { ++sample_iter_; }

  // Returns (N, K): transition index N (N == NumArcs(s) denotes the
  // 'super-final' transition) was sampled K times.
  pair<size_t, size_t> Value() const { return *sample_iter_; }

  // Rewinds iteration to the first collected sample.
  void Reset() { sample_iter_ = sample_map_.begin(); }

  // Reports the accumulator's error state.
  bool Error() const { return accumulator_->Error(); }

 private:
  const Fst<A> &fst_;
  const S &arc_selector_;
  int max_length_;

  // Stores (N, K) as described for Value().
  map<size_t, size_t> sample_map_;
  map<size_t, size_t>::const_iterator sample_iter_;
  C *accumulator_;  // owned

  // disallow
  ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s);
};
// Options for random path generation with RandGenFst. The template argument
// is an arc sampler, typically class 'ArcSampler' above. Ownership of
// the sampler is taken by RandGenFst.
template <class S>
struct RandGenFstOptions : public CacheOptions {
  // How to sample transitions at a state.
  S *arc_sampler;
  // Number of paths to generate.
  size_t npath;
  // Output tree weighted by path count; otherwise output unweighted DAG.
  bool weighted;
  // Remove total weight when output is weighted.
  bool remove_total_weight;

  // 'copts' supplies the cache settings and 'samp' the sampler; the rest
  // default to one weighted path with the total weight kept.
  RandGenFstOptions(const CacheOptions &copts, S *samp,
                    size_t n = 1, bool w = true, bool rw = false)
      : CacheOptions(copts),
        arc_sampler(samp),
        npath(n),
        weighted(w),
        remove_total_weight(rw) {}
};
// Implementation of RandGenFst: a cached, lazily expanded FST over arc type
// B. Each of its states is backed by a RandState record describing a sampled
// position in an input FST over arc type A; outgoing arcs are produced on
// demand by Expand() using the arc sampler S.
template <class A, class B, class S>
class RandGenFstImpl : public CacheImpl<B> {
 public:
  using FstImpl<B>::SetType;
  using FstImpl<B>::SetProperties;
  using FstImpl<B>::SetInputSymbols;
  using FstImpl<B>::SetOutputSymbols;

  using CacheBaseImpl< CacheState<B> >::AddArc;
  using CacheBaseImpl< CacheState<B> >::HasArcs;
  using CacheBaseImpl< CacheState<B> >::HasFinal;
  using CacheBaseImpl< CacheState<B> >::HasStart;
  using CacheBaseImpl< CacheState<B> >::SetArcs;
  using CacheBaseImpl< CacheState<B> >::SetFinal;
  using CacheBaseImpl< CacheState<B> >::SetStart;

  typedef B Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;

  // Stores a copy of 'fst' and takes ownership of 'opts.arc_sampler'; both
  // are deleted by the destructor.
  RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
      : CacheImpl<B>(opts),
        fst_(fst.Copy()),
        arc_sampler_(opts.arc_sampler),
        npath_(opts.npath),
        weighted_(opts.weighted),
        remove_total_weight_(opts.remove_total_weight),
        superfinal_(kNoStateId) {
    SetType("randgen");

    uint64 props = fst.Properties(kFstProperties, false);
    SetProperties(RandGenProperties(props, weighted_), kCopyProperties);

    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copies the configuration of 'impl' with its own copy of the input FST
  // and a new sampler bound to that copy. The state table starts out empty.
  RandGenFstImpl(const RandGenFstImpl &impl)
      : CacheImpl<B>(impl),
        fst_(impl.fst_->Copy(true)),
        arc_sampler_(new S(*impl.arc_sampler_, fst_)),
        npath_(impl.npath_),
        weighted_(impl.weighted_),
        // Must be copied too: Expand() reads it when the output is weighted,
        // and leaving it out would leave it with an indeterminate value.
        remove_total_weight_(impl.remove_total_weight_),
        superfinal_(kNoStateId) {
    SetType("randgen");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Releases every RandState record, the input FST copy and the sampler.
  ~RandGenFstImpl() {
    for (size_t i = 0; i < state_table_.size(); ++i)
      delete state_table_[i];
    delete fst_;
    delete arc_sampler_;
  }

  // Returns the start state, creating its RandState (carrying all 'npath_'
  // samples) on first use. Returns kNoStateId if the input has no start.
  StateId Start() {
    if (!HasStart()) {
      StateId s = fst_->Start();
      if (s == kNoStateId)
        return kNoStateId;
      StateId start = state_table_.size();
      SetStart(start);
      RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0);
      state_table_.push_back(rstate);
    }
    return CacheImpl<B>::Start();
  }

  // Returns the final weight of 's', expanding the state first if needed.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      Expand(s);
    }
    return CacheImpl<B>::Final(s);
  }

  // Returns the number of arcs at 's', expanding the state first if needed.
  size_t NumArcs(StateId s) {
    if (!HasArcs(s)) {
      Expand(s);
    }
    return CacheImpl<B>::NumArcs(s);
  }

  // Returns the number of input-epsilon arcs at 's', expanding if needed.
  size_t NumInputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<B>::NumInputEpsilons(s);
  }

  // Returns the number of output-epsilon arcs at 's', expanding if needed.
  size_t NumOutputEpsilons(StateId s) {
    if (!HasArcs(s))
      Expand(s);
    return CacheImpl<B>::NumOutputEpsilons(s);
  }

  uint64 Properties() const { return Properties(kFstProperties); }

  // Set error if found; return FST impl properties.
  uint64 Properties(uint64 mask) const {
    if ((mask & kError) &&
        (fst_->Properties(kError, false) || arc_sampler_->Error())) {
      SetProperties(kError, kError);
    }
    return FstImpl<Arc>::Properties(mask);
  }

  // Initializes 'data' for iterating over the arcs of 's', expanding the
  // state first if needed.
  void InitArcIterator(StateId s, ArcIteratorData<B> *data) {
    if (!HasArcs(s))
      Expand(s);
    CacheImpl<B>::InitArcIterator(s, data);
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  void Expand(StateId s) {
    // The shared super-final state is final with weight One and has no arcs.
    if (s == superfinal_) {
      SetFinal(s, Weight::One());
      SetArcs(s);
      return;
    }

    // Non-final by default; may be overridden below in the weighted case.
    SetFinal(s, Weight::Zero());
    const RandState<A> &rstate = *state_table_[s];
    arc_sampler_->Sample(rstate);
    ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id);
    size_t narcs = fst_->NumArcs(rstate.state_id);
    for (; !arc_sampler_->Done(); arc_sampler_->Next()) {
      const pair<size_t, size_t> &sample_pair = arc_sampler_->Value();
      size_t pos = sample_pair.first;
      size_t count = sample_pair.second;
      // Fraction of this state's samples that chose this transition.
      double prob = static_cast<double>(count)/rstate.nsamples;
      if (pos < narcs) {  // regular transition
        aiter.Seek(pos);
        const A &aarc = aiter.Value();
        Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One();
        // The destination is a new state appended to the state table.
        B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size());
        AddArc(s, barc);
        RandState<A> *nrstate =
            new RandState<A>(aarc.nextstate, count, rstate.length + 1,
                             pos, &rstate);
        state_table_.push_back(nrstate);
      } else {  // super-final transition
        if (weighted_) {
          // Unless the total weight is removed, scale by the path count.
          Weight weight = remove_total_weight_ ?
              to_weight_(-log(prob)) : to_weight_(-log(prob * npath_));
          SetFinal(s, weight);
        } else {
          // Create the shared super-final state on first use.
          if (superfinal_ == kNoStateId) {
            superfinal_ = state_table_.size();
            RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0);
            state_table_.push_back(nrstate);
          }
          // One epsilon arc per sampled repetition.
          for (size_t n = 0; n < count; ++n) {
            B barc(0, 0, Weight::One(), superfinal_);
            AddArc(s, barc);
          }
        }
      }
    }
    SetArcs(s);
  }

 private:
  Fst<A> *fst_;                         // owned copy of the input FST
  S *arc_sampler_;                      // owned
  size_t npath_;                        // # of paths to generate
  vector<RandState<A> *> state_table_;  // owned; indexed by output state id
  bool weighted_;
  bool remove_total_weight_;
  StateId superfinal_;  // shared super-final state; kNoStateId until created
  WeightConvert<Log64Weight, Weight> to_weight_;

  void operator=(const RandGenFstImpl<A, B, S> &);  // disallow
};
// Fst class to randomly generate paths through an FST; details controlled
// by RandGenFstOptions. Output format is a tree weighted by the
// path count.
template <class A, class B, class S>
class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > {
 public:
  friend class ArcIterator< RandGenFst<A, B, S> >;
  friend class StateIterator< RandGenFst<A, B, S> >;

  typedef B Arc;
  typedef S Sampler;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef CacheState<B> State;
  typedef RandGenFstImpl<A, B, S> Impl;

  // Builds a new implementation over 'fst' configured by 'opts'.
  RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts)
      : ImplToFst<Impl>(new Impl(fst, opts)) {}

  // See Fst<>::Copy() for doc.
  RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false)
      : ImplToFst<Impl>(fst, safe) {}

  // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc.
  virtual RandGenFst<A, B, S> *Copy(bool safe = false) const {
    RandGenFst<A, B, S> *duplicate = new RandGenFst<A, B, S>(*this, safe);
    return duplicate;
  }

  virtual inline void InitStateIterator(StateIteratorData<B> *data) const;

  // Forwards arc iterator initialization to the implementation.
  virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const {
    Impl *impl = GetImpl();
    impl->InitArcIterator(s, data);
  }

 private:
  // Makes visible to friends.
  Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }

  void operator=(const RandGenFst<A, B, S> &fst);  // Disallow
};
// StateIterator specialization for RandGenFst: a CacheStateIterator
// constructed from the FST and its implementation.
template <class A, class B, class S>
class StateIterator< RandGenFst<A, B, S> >
    : public CacheStateIterator< RandGenFst<A, B, S> > {
 public:
  explicit StateIterator(const RandGenFst<A, B, S> &fst)
      : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {}

 private:
  DISALLOW_COPY_AND_ASSIGN(StateIterator);
};
// ArcIterator specialization for RandGenFst: a CacheArcIterator that makes
// sure the requested state has been expanded.
template <class A, class B, class S>
class ArcIterator< RandGenFst<A, B, S> >
    : public CacheArcIterator< RandGenFst<A, B, S> > {
 public:
  typedef typename A::StateId StateId;

  // Iterates over the arcs of state 's' of 'fst'.
  ArcIterator(const RandGenFst<A, B, S> &fst, StateId s)
      : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) {
    typename RandGenFst<A, B, S>::Impl *impl = fst.GetImpl();
    // Compute the arcs of 's' unless they are already available.
    if (!impl->HasArcs(s)) impl->Expand(s);
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};
// Installs a newly allocated RandGenFst state iterator as 'data->base'.
template <class A, class B, class S> inline
void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const {
  typedef StateIterator< RandGenFst<A, B, S> > Iterator;
  data->base = new Iterator(*this);
}
// Options for random path generation. The arc selector is held by
// reference, so it must outlive the options object.
template <class S>
struct RandGenOptions {
  // How an arc is selected at a state.
  const S &arc_selector;
  // Maximum path length.
  int max_length;
  // Number of paths to generate.
  size_t npath;
  // Output is tree weighted by path count; otherwise output is an
  // unweighted union of paths.
  bool weighted;
  // Remove total weight when output is weighted.
  bool remove_total_weight;

  // Only the selector is required; the defaults request a single unweighted
  // path of at most INT_MAX transitions with the total weight kept.
  RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1,
                 bool w = false, bool rw = false)
      : arc_selector(sel),
        max_length(len),
        npath(n),
        weighted(w),
        remove_total_weight(rw) {}
};
template <class IArc, class OArc>
class RandGenVisitor {
public:
typedef typename IArc::Weight Weight;
typedef typename IArc::StateId StateId;
RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {}
void InitVisit(const Fst<IArc> &ifst) {
ifst_ = &ifst;
ofst_->DeleteStates();
ofst_->SetInputSymbols(ifst.InputSymbols());
ofst_->SetOutputSymbols(ifst.OutputSymbols());
if (ifst.Properties(kError, false))
ofst_->SetProperties(kError, kError);
path_.clear();
}
bool InitState(StateId s, StateId root) { return true; }
bool TreeArc(StateId s, const IArc &arc) {
if (ifst_->Final(arc.nextstate) == Weight::Zero()) {
path_.push_back(arc);
} else {
OutputPath();
}
return true;
}
bool BackArc(StateId s, const IArc &arc) {
FSTERROR() << "RandGenVisitor: cyclic input";
ofst_->SetProperties(kError, kError);
return false;
}
bool ForwardOrCrossArc(StateId s, const IArc &arc) {
OutputPath();
return true;
}
void FinishState(StateId s, StateId p, const IArc *) {
if (p != kNoStateId && ifst_->Final(s) == Weight::Zero())
path_.pop_back();
}
void FinishVisit() {}
private:
void OutputPath() {
if (ofst_->Start() == kNoStateId) {
StateId start = ofst_->AddState();
ofst_->SetStart(start);
}
StateId src = ofst_->Start();
for (size_t i = 0; i < path_.size(); ++i) {
StateId dest = ofst_->AddState();
OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest);
ofst_->AddArc(src, arc);
src = dest;
}
ofst_->SetFinal(src, Weight::One());
}
const Fst<IArc> *ifst_;
MutableFst<OArc> *ofst_;
vector<OArc> path_;
DISALLOW_COPY_AND_ASSIGN(RandGenVisitor);
};
// Randomly generate paths through an FST; details controlled by
// RandGenOptions.
template<class IArc, class OArc, class Selector>
void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst,
const RandGenOptions<Selector> &opts) {
typedef ArcSampler<IArc, Selector> Sampler;
typedef RandGenFst<IArc, OArc, Sampler> RandFst;
typedef typename OArc::StateId StateId;
typedef typename OArc::Weight Weight;
Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length);
RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler,
opts.npath, opts.weighted,
opts.remove_total_weight);
RandFst rfst(ifst, fopts);
if (opts.weighted) {
*ofst = rfst;
} else {
RandGenVisitor<IArc, OArc> rand_visitor(ofst);
DfsVisit(rfst, &rand_visitor);
}
}
// Randomly generate a path through an FST with the uniform distribution
// over the transitions.
template<class IArc, class OArc>
void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) {
UniformArcSelector<IArc> uniform_selector;
RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector);
RandGen(ifst, ofst, opts);
}
} // namespace fst
#endif // FST_LIB_RANDGEN_H__