| // randgen.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: riley@google.com (Michael Riley) |
| // |
| // \file |
| // Classes and functions to generate random paths through an FST. |
| |
| #ifndef FST_LIB_RANDGEN_H__ |
| #define FST_LIB_RANDGEN_H__ |
| |
| #include <cmath> |
| #include <cstdlib> |
| #include <ctime> |
| #include <map> |
| |
| #include <fst/accumulator.h> |
| #include <fst/cache.h> |
| #include <fst/dfs-visit.h> |
| #include <fst/mutable-fst.h> |
| |
| namespace fst { |
| |
| // |
| // ARC SELECTORS - these function objects are used to select a random |
| // transition to take from an FST's state. They should return a number |
| // N s.t. 0 <= N <= NumArcs(). If N < NumArcs(), then the N-th |
| // transition is selected. If N == NumArcs(), then the final weight at |
| // that state is selected (i.e., the 'super-final' transition is selected). |
// It can be assumed these will not be called unless the state has
// outgoing transitions, is final, or both.
| // |
| |
| // Randomly selects a transition using the uniform distribution. |
| template <class A> |
| struct UniformArcSelector { |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| UniformArcSelector(int seed = time(0)) { srand(seed); } |
| |
| size_t operator()(const Fst<A> &fst, StateId s) const { |
| double r = rand()/(RAND_MAX + 1.0); |
| size_t n = fst.NumArcs(s); |
| if (fst.Final(s) != Weight::Zero()) |
| ++n; |
| return static_cast<size_t>(r * n); |
| } |
| }; |
| |
| |
| // Randomly selects a transition w.r.t. the weights treated as negative |
| // log probabilities after normalizing for the total weight leaving |
| // the state. Weight::zero transitions are disregarded. |
| // Assumes Weight::Value() accesses the floating point |
| // representation of the weight. |
| template <class A> |
| class LogProbArcSelector { |
| public: |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| LogProbArcSelector(int seed = time(0)) { srand(seed); } |
| |
| size_t operator()(const Fst<A> &fst, StateId s) const { |
| // Find total weight leaving state |
| double sum = 0.0; |
| for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); |
| aiter.Next()) { |
| const A &arc = aiter.Value(); |
| sum += exp(-to_log_weight_(arc.weight).Value()); |
| } |
| sum += exp(-to_log_weight_(fst.Final(s)).Value()); |
| |
| double r = rand()/(RAND_MAX + 1.0); |
| double p = 0.0; |
| int n = 0; |
| for (ArcIterator< Fst<A> > aiter(fst, s); !aiter.Done(); |
| aiter.Next(), ++n) { |
| const A &arc = aiter.Value(); |
| p += exp(-to_log_weight_(arc.weight).Value()); |
| if (p > r * sum) return n; |
| } |
| return n; |
| } |
| |
| private: |
| WeightConvert<Weight, Log64Weight> to_log_weight_; |
| }; |
| |
| // Convenience definitions |
| typedef LogProbArcSelector<StdArc> StdArcSelector; |
| typedef LogProbArcSelector<LogArc> LogArcSelector; |
| |
| |
| // Same as LogProbArcSelector but use CacheLogAccumulator to cache |
| // the cummulative weight computations. |
| template <class A> |
| class FastLogProbArcSelector : public LogProbArcSelector<A> { |
| public: |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| using LogProbArcSelector<A>::operator(); |
| |
| FastLogProbArcSelector(int seed = time(0)) |
| : LogProbArcSelector<A>(seed), |
| seed_(seed) {} |
| |
| size_t operator()(const Fst<A> &fst, StateId s, |
| CacheLogAccumulator<A> *accumulator) const { |
| accumulator->SetState(s); |
| ArcIterator< Fst<A> > aiter(fst, s); |
| // Find total weight leaving state |
| double sum = to_log_weight_(accumulator->Sum(fst.Final(s), &aiter, 0, |
| fst.NumArcs(s))).Value(); |
| double r = -log(rand()/(RAND_MAX + 1.0)); |
| return accumulator->LowerBound(r + sum, &aiter); |
| } |
| |
| int Seed() const { return seed_; } |
| private: |
| int seed_; |
| WeightConvert<Weight, Log64Weight> to_log_weight_; |
| }; |
| |
| // Random path state info maintained by RandGenFst and passed to samplers. |
| template <typename A> |
| struct RandState { |
| typedef typename A::StateId StateId; |
| |
| StateId state_id; // current input FST state |
| size_t nsamples; // # of samples to be sampled at this state |
| size_t length; // length of path to this random state |
| size_t select; // previous sample arc selection |
| const RandState<A> *parent; // previous random state on this path |
| |
| RandState(StateId s, size_t n, size_t l, size_t k, const RandState<A> *p) |
| : state_id(s), nsamples(n), length(l), select(k), parent(p) {} |
| |
| RandState() |
| : state_id(kNoStateId), nsamples(0), length(0), select(0), parent(0) {} |
| }; |
| |
| // This class, given an arc selector, samples, with raplacement, |
| // multiple random transitions from an FST's state. This is a generic |
| // version with a straight-forward use of the arc selector. |
| // Specializations may be defined for arc selectors for greater |
| // efficiency or special behavior. |
| template <class A, class S> |
| class ArcSampler { |
| public: |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| // The 'max_length' may be interpreted (including ignored) by a |
| // sampler as it chooses. This generic version interprets this literally. |
| ArcSampler(const Fst<A> &fst, const S &arc_selector, |
| int max_length = INT_MAX) |
| : fst_(fst), |
| arc_selector_(arc_selector), |
| max_length_(max_length) {} |
| |
| // Allow updating Fst argument; pass only if changed. |
| ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) |
| : fst_(fst ? *fst : sampler.fst_), |
| arc_selector_(sampler.arc_selector_), |
| max_length_(sampler.max_length_) { |
| Reset(); |
| } |
| |
| // Samples 'rstate.nsamples' from state 'state_id'. The 'rstate.length' is |
| // the length of the path to 'rstate'. Returns true if samples were |
| // collected. No samples may be collected if either there are no (including |
| // 'super-final') transitions leaving that state or if the |
| // 'max_length' has been deemed reached. Use the iterator members to |
| // read the samples. The samples will be in their original order. |
| bool Sample(const RandState<A> &rstate) { |
| sample_map_.clear(); |
| if ((fst_.NumArcs(rstate.state_id) == 0 && |
| fst_.Final(rstate.state_id) == Weight::Zero()) || |
| rstate.length == max_length_) { |
| Reset(); |
| return false; |
| } |
| |
| for (size_t i = 0; i < rstate.nsamples; ++i) |
| ++sample_map_[arc_selector_(fst_, rstate.state_id)]; |
| Reset(); |
| return true; |
| } |
| |
| // More samples? |
| bool Done() const { return sample_iter_ == sample_map_.end(); } |
| |
| // Gets the next sample. |
| void Next() { ++sample_iter_; } |
| |
| // Returns a pair (N, K) where 0 <= N <= NumArcs(s) and 0 < K <= nsamples. |
| // If N < NumArcs(s), then the N-th transition is specified. |
| // If N == NumArcs(s), then the final weight at that state is |
| // specified (i.e., the 'super-final' transition is specified). |
| // For the specified transition, K repetitions have been sampled. |
| pair<size_t, size_t> Value() const { return *sample_iter_; } |
| |
| void Reset() { sample_iter_ = sample_map_.begin(); } |
| |
| bool Error() const { return false; } |
| |
| private: |
| const Fst<A> &fst_; |
| const S &arc_selector_; |
| int max_length_; |
| |
| // Stores (N, K) as described for Value(). |
| map<size_t, size_t> sample_map_; |
| map<size_t, size_t>::const_iterator sample_iter_; |
| |
| // disallow |
| ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); |
| }; |
| |
| |
| // Specialization for FastLogProbArcSelector. |
| template <class A> |
| class ArcSampler<A, FastLogProbArcSelector<A> > { |
| public: |
| typedef FastLogProbArcSelector<A> S; |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| typedef CacheLogAccumulator<A> C; |
| |
| ArcSampler(const Fst<A> &fst, const S &arc_selector, int max_length = INT_MAX) |
| : fst_(fst), |
| arc_selector_(arc_selector), |
| max_length_(max_length), |
| accumulator_(new C()) { |
| accumulator_->Init(fst); |
| } |
| |
| ArcSampler(const ArcSampler<A, S> &sampler, const Fst<A> *fst = 0) |
| : fst_(fst ? *fst : sampler.fst_), |
| arc_selector_(sampler.arc_selector_), |
| max_length_(sampler.max_length_) { |
| if (fst) { |
| accumulator_ = new C(); |
| accumulator_->Init(*fst); |
| } else { // shallow copy |
| accumulator_ = new C(*sampler.accumulator_); |
| } |
| } |
| |
| ~ArcSampler() { |
| delete accumulator_; |
| } |
| |
| bool Sample(const RandState<A> &rstate) { |
| sample_map_.clear(); |
| if ((fst_.NumArcs(rstate.state_id) == 0 && |
| fst_.Final(rstate.state_id) == Weight::Zero()) || |
| rstate.length == max_length_) { |
| Reset(); |
| return false; |
| } |
| |
| for (size_t i = 0; i < rstate.nsamples; ++i) |
| ++sample_map_[arc_selector_(fst_, rstate.state_id, accumulator_)]; |
| Reset(); |
| return true; |
| } |
| |
| bool Done() const { return sample_iter_ == sample_map_.end(); } |
| void Next() { ++sample_iter_; } |
| pair<size_t, size_t> Value() const { return *sample_iter_; } |
| void Reset() { sample_iter_ = sample_map_.begin(); } |
| |
| bool Error() const { return accumulator_->Error(); } |
| |
| private: |
| const Fst<A> &fst_; |
| const S &arc_selector_; |
| int max_length_; |
| |
| // Stores (N, K) as described for Value(). |
| map<size_t, size_t> sample_map_; |
| map<size_t, size_t>::const_iterator sample_iter_; |
| C *accumulator_; |
| |
| // disallow |
| ArcSampler<A, S> & operator=(const ArcSampler<A, S> &s); |
| }; |
| |
| |
| // Options for random path generation with RandGenFst. The template argument |
| // is an arc sampler, typically class 'ArcSampler' above. Ownership of |
| // the sampler is taken by RandGenFst. |
| template <class S> |
| struct RandGenFstOptions : public CacheOptions { |
| S *arc_sampler; // How to sample transitions at a state |
| size_t npath; // # of paths to generate |
| bool weighted; // Output tree weighted by path count; o.w. |
| // output unweighted DAG |
| bool remove_total_weight; // Remove total weight when output is weighted. |
| |
| RandGenFstOptions(const CacheOptions &copts, S *samp, |
| size_t n = 1, bool w = true, bool rw = false) |
| : CacheOptions(copts), |
| arc_sampler(samp), |
| npath(n), |
| weighted(w), |
| remove_total_weight(rw) {} |
| }; |
| |
| |
| // Implementation of RandGenFst. |
| template <class A, class B, class S> |
| class RandGenFstImpl : public CacheImpl<B> { |
| public: |
| using FstImpl<B>::SetType; |
| using FstImpl<B>::SetProperties; |
| using FstImpl<B>::SetInputSymbols; |
| using FstImpl<B>::SetOutputSymbols; |
| |
| using CacheBaseImpl< CacheState<B> >::AddArc; |
| using CacheBaseImpl< CacheState<B> >::HasArcs; |
| using CacheBaseImpl< CacheState<B> >::HasFinal; |
| using CacheBaseImpl< CacheState<B> >::HasStart; |
| using CacheBaseImpl< CacheState<B> >::SetArcs; |
| using CacheBaseImpl< CacheState<B> >::SetFinal; |
| using CacheBaseImpl< CacheState<B> >::SetStart; |
| |
| typedef B Arc; |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| typedef typename A::StateId StateId; |
| |
| RandGenFstImpl(const Fst<A> &fst, const RandGenFstOptions<S> &opts) |
| : CacheImpl<B>(opts), |
| fst_(fst.Copy()), |
| arc_sampler_(opts.arc_sampler), |
| npath_(opts.npath), |
| weighted_(opts.weighted), |
| remove_total_weight_(opts.remove_total_weight), |
| superfinal_(kNoLabel) { |
| SetType("randgen"); |
| |
| uint64 props = fst.Properties(kFstProperties, false); |
| SetProperties(RandGenProperties(props, weighted_), kCopyProperties); |
| |
| SetInputSymbols(fst.InputSymbols()); |
| SetOutputSymbols(fst.OutputSymbols()); |
| } |
| |
| RandGenFstImpl(const RandGenFstImpl &impl) |
| : CacheImpl<B>(impl), |
| fst_(impl.fst_->Copy(true)), |
| arc_sampler_(new S(*impl.arc_sampler_, fst_)), |
| npath_(impl.npath_), |
| weighted_(impl.weighted_), |
| superfinal_(kNoLabel) { |
| SetType("randgen"); |
| SetProperties(impl.Properties(), kCopyProperties); |
| SetInputSymbols(impl.InputSymbols()); |
| SetOutputSymbols(impl.OutputSymbols()); |
| } |
| |
| ~RandGenFstImpl() { |
| for (int i = 0; i < state_table_.size(); ++i) |
| delete state_table_[i]; |
| delete fst_; |
| delete arc_sampler_; |
| } |
| |
| StateId Start() { |
| if (!HasStart()) { |
| StateId s = fst_->Start(); |
| if (s == kNoStateId) |
| return kNoStateId; |
| StateId start = state_table_.size(); |
| SetStart(start); |
| RandState<A> *rstate = new RandState<A>(s, npath_, 0, 0, 0); |
| state_table_.push_back(rstate); |
| } |
| return CacheImpl<B>::Start(); |
| } |
| |
| Weight Final(StateId s) { |
| if (!HasFinal(s)) { |
| Expand(s); |
| } |
| return CacheImpl<B>::Final(s); |
| } |
| |
| size_t NumArcs(StateId s) { |
| if (!HasArcs(s)) { |
| Expand(s); |
| } |
| return CacheImpl<B>::NumArcs(s); |
| } |
| |
| size_t NumInputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<B>::NumInputEpsilons(s); |
| } |
| |
| size_t NumOutputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<B>::NumOutputEpsilons(s); |
| } |
| |
| uint64 Properties() const { return Properties(kFstProperties); } |
| |
| // Set error if found; return FST impl properties. |
| uint64 Properties(uint64 mask) const { |
| if ((mask & kError) && |
| (fst_->Properties(kError, false) || arc_sampler_->Error())) { |
| SetProperties(kError, kError); |
| } |
| return FstImpl<Arc>::Properties(mask); |
| } |
| |
| void InitArcIterator(StateId s, ArcIteratorData<B> *data) { |
| if (!HasArcs(s)) |
| Expand(s); |
| CacheImpl<B>::InitArcIterator(s, data); |
| } |
| |
| // Computes the outgoing transitions from a state, creating new destination |
| // states as needed. |
| void Expand(StateId s) { |
| if (s == superfinal_) { |
| SetFinal(s, Weight::One()); |
| SetArcs(s); |
| return; |
| } |
| |
| SetFinal(s, Weight::Zero()); |
| const RandState<A> &rstate = *state_table_[s]; |
| arc_sampler_->Sample(rstate); |
| ArcIterator< Fst<A> > aiter(*fst_, rstate.state_id); |
| size_t narcs = fst_->NumArcs(rstate.state_id); |
| for (;!arc_sampler_->Done(); arc_sampler_->Next()) { |
| const pair<size_t, size_t> &sample_pair = arc_sampler_->Value(); |
| size_t pos = sample_pair.first; |
| size_t count = sample_pair.second; |
| double prob = static_cast<double>(count)/rstate.nsamples; |
| if (pos < narcs) { // regular transition |
| aiter.Seek(sample_pair.first); |
| const A &aarc = aiter.Value(); |
| Weight weight = weighted_ ? to_weight_(-log(prob)) : Weight::One(); |
| B barc(aarc.ilabel, aarc.olabel, weight, state_table_.size()); |
| AddArc(s, barc); |
| RandState<A> *nrstate = |
| new RandState<A>(aarc.nextstate, count, rstate.length + 1, |
| pos, &rstate); |
| state_table_.push_back(nrstate); |
| } else { // super-final transition |
| if (weighted_) { |
| Weight weight = remove_total_weight_ ? |
| to_weight_(-log(prob)) : to_weight_(-log(prob * npath_)); |
| SetFinal(s, weight); |
| } else { |
| if (superfinal_ == kNoLabel) { |
| superfinal_ = state_table_.size(); |
| RandState<A> *nrstate = new RandState<A>(kNoStateId, 0, 0, 0, 0); |
| state_table_.push_back(nrstate); |
| } |
| for (size_t n = 0; n < count; ++n) { |
| B barc(0, 0, Weight::One(), superfinal_); |
| AddArc(s, barc); |
| } |
| } |
| } |
| } |
| SetArcs(s); |
| } |
| |
| private: |
| Fst<A> *fst_; |
| S *arc_sampler_; |
| size_t npath_; |
| vector<RandState<A> *> state_table_; |
| bool weighted_; |
| bool remove_total_weight_; |
| StateId superfinal_; |
| WeightConvert<Log64Weight, Weight> to_weight_; |
| |
| void operator=(const RandGenFstImpl<A, B, S> &); // disallow |
| }; |
| |
| |
| // Fst class to randomly generate paths through an FST; details controlled |
| // by RandGenOptionsFst. Output format is a tree weighted by the |
| // path count. |
| template <class A, class B, class S> |
| class RandGenFst : public ImplToFst< RandGenFstImpl<A, B, S> > { |
| public: |
| friend class ArcIterator< RandGenFst<A, B, S> >; |
| friend class StateIterator< RandGenFst<A, B, S> >; |
| typedef B Arc; |
| typedef S Sampler; |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| typedef typename A::StateId StateId; |
| typedef CacheState<B> State; |
| typedef RandGenFstImpl<A, B, S> Impl; |
| |
| RandGenFst(const Fst<A> &fst, const RandGenFstOptions<S> &opts) |
| : ImplToFst<Impl>(new Impl(fst, opts)) {} |
| |
| // See Fst<>::Copy() for doc. |
| RandGenFst(const RandGenFst<A, B, S> &fst, bool safe = false) |
| : ImplToFst<Impl>(fst, safe) {} |
| |
| // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. |
| virtual RandGenFst<A, B, S> *Copy(bool safe = false) const { |
| return new RandGenFst<A, B, S>(*this, safe); |
| } |
| |
| virtual inline void InitStateIterator(StateIteratorData<B> *data) const; |
| |
| virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { |
| GetImpl()->InitArcIterator(s, data); |
| } |
| |
| private: |
| // Makes visible to friends. |
| Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } |
| |
| void operator=(const RandGenFst<A, B, S> &fst); // Disallow |
| }; |
| |
| |
| |
| // Specialization for RandGenFst. |
| template <class A, class B, class S> |
| class StateIterator< RandGenFst<A, B, S> > |
| : public CacheStateIterator< RandGenFst<A, B, S> > { |
| public: |
| explicit StateIterator(const RandGenFst<A, B, S> &fst) |
| : CacheStateIterator< RandGenFst<A, B, S> >(fst, fst.GetImpl()) {} |
| |
| private: |
| DISALLOW_COPY_AND_ASSIGN(StateIterator); |
| }; |
| |
| |
| // Specialization for RandGenFst. |
| template <class A, class B, class S> |
| class ArcIterator< RandGenFst<A, B, S> > |
| : public CacheArcIterator< RandGenFst<A, B, S> > { |
| public: |
| typedef typename A::StateId StateId; |
| |
| ArcIterator(const RandGenFst<A, B, S> &fst, StateId s) |
| : CacheArcIterator< RandGenFst<A, B, S> >(fst.GetImpl(), s) { |
| if (!fst.GetImpl()->HasArcs(s)) |
| fst.GetImpl()->Expand(s); |
| } |
| |
| private: |
| DISALLOW_COPY_AND_ASSIGN(ArcIterator); |
| }; |
| |
| |
| template <class A, class B, class S> inline |
| void RandGenFst<A, B, S>::InitStateIterator(StateIteratorData<B> *data) const |
| { |
| data->base = new StateIterator< RandGenFst<A, B, S> >(*this); |
| } |
| |
// Options for the RandGen() functions. 'S' is the arc selector type; the
// selector is held by reference and must outlive the options object.
template <class S>
struct RandGenOptions {
  // How an arc is selected at a state.
  const S &arc_selector;
  // Maximum path length.
  int max_length;
  // Number of paths to generate.
  size_t npath;
  // If true, the output is a tree weighted by path count; otherwise the
  // output is an unweighted union of paths.
  bool weighted;
  // Remove total weight when output is weighted.
  bool remove_total_weight;

  // Defaults: unbounded length (INT_MAX), a single path, unweighted output,
  // total weight kept.
  RandGenOptions(const S &sel, int len = INT_MAX, size_t n = 1,
                 bool w = false, bool rw = false)
      : arc_selector(sel),
        max_length(len),
        npath(n),
        weighted(w),
        remove_total_weight(rw) {
  }
};
| |
| |
| template <class IArc, class OArc> |
| class RandGenVisitor { |
| public: |
| typedef typename IArc::Weight Weight; |
| typedef typename IArc::StateId StateId; |
| |
| RandGenVisitor(MutableFst<OArc> *ofst) : ofst_(ofst) {} |
| |
| void InitVisit(const Fst<IArc> &ifst) { |
| ifst_ = &ifst; |
| |
| ofst_->DeleteStates(); |
| ofst_->SetInputSymbols(ifst.InputSymbols()); |
| ofst_->SetOutputSymbols(ifst.OutputSymbols()); |
| if (ifst.Properties(kError, false)) |
| ofst_->SetProperties(kError, kError); |
| path_.clear(); |
| } |
| |
| bool InitState(StateId s, StateId root) { return true; } |
| |
| bool TreeArc(StateId s, const IArc &arc) { |
| if (ifst_->Final(arc.nextstate) == Weight::Zero()) { |
| path_.push_back(arc); |
| } else { |
| OutputPath(); |
| } |
| return true; |
| } |
| |
| bool BackArc(StateId s, const IArc &arc) { |
| FSTERROR() << "RandGenVisitor: cyclic input"; |
| ofst_->SetProperties(kError, kError); |
| return false; |
| } |
| |
| bool ForwardOrCrossArc(StateId s, const IArc &arc) { |
| OutputPath(); |
| return true; |
| } |
| |
| void FinishState(StateId s, StateId p, const IArc *) { |
| if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) |
| path_.pop_back(); |
| } |
| |
| void FinishVisit() {} |
| |
| private: |
| void OutputPath() { |
| if (ofst_->Start() == kNoStateId) { |
| StateId start = ofst_->AddState(); |
| ofst_->SetStart(start); |
| } |
| |
| StateId src = ofst_->Start(); |
| for (size_t i = 0; i < path_.size(); ++i) { |
| StateId dest = ofst_->AddState(); |
| OArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); |
| ofst_->AddArc(src, arc); |
| src = dest; |
| } |
| ofst_->SetFinal(src, Weight::One()); |
| } |
| |
| const Fst<IArc> *ifst_; |
| MutableFst<OArc> *ofst_; |
| vector<OArc> path_; |
| |
| DISALLOW_COPY_AND_ASSIGN(RandGenVisitor); |
| }; |
| |
| |
| // Randomly generate paths through an FST; details controlled by |
| // RandGenOptions. |
| template<class IArc, class OArc, class Selector> |
| void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst, |
| const RandGenOptions<Selector> &opts) { |
| typedef ArcSampler<IArc, Selector> Sampler; |
| typedef RandGenFst<IArc, OArc, Sampler> RandFst; |
| typedef typename OArc::StateId StateId; |
| typedef typename OArc::Weight Weight; |
| |
| Sampler* arc_sampler = new Sampler(ifst, opts.arc_selector, opts.max_length); |
| RandGenFstOptions<Sampler> fopts(CacheOptions(true, 0), arc_sampler, |
| opts.npath, opts.weighted, |
| opts.remove_total_weight); |
| RandFst rfst(ifst, fopts); |
| if (opts.weighted) { |
| *ofst = rfst; |
| } else { |
| RandGenVisitor<IArc, OArc> rand_visitor(ofst); |
| DfsVisit(rfst, &rand_visitor); |
| } |
| } |
| |
| // Randomly generate a path through an FST with the uniform distribution |
| // over the transitions. |
| template<class IArc, class OArc> |
| void RandGen(const Fst<IArc> &ifst, MutableFst<OArc> *ofst) { |
| UniformArcSelector<IArc> uniform_selector; |
| RandGenOptions< UniformArcSelector<IArc> > opts(uniform_selector); |
| RandGen(ifst, ofst, opts); |
| } |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_RANDGEN_H__ |