| // map.h |
| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: riley@google.com (Michael Riley) |
| // |
| // \file |
| // Class to map over/transform states e.g., sort transitions |
| // Consider using when operation does not change the number of states. |
| |
| #ifndef FST_LIB_STATE_MAP_H__ |
| #define FST_LIB_STATE_MAP_H__ |
| |
| #include <algorithm> |
| #include <unordered_map> |
| using std::tr1::unordered_map; |
| using std::tr1::unordered_multimap; |
| #include <string> |
| #include <utility> |
| using std::pair; using std::make_pair; |
| |
| #include <fst/cache.h> |
| #include <fst/arc-map.h> |
| #include <fst/mutable-fst.h> |
| |
| |
| namespace fst { |
| |
| // StateMapper Interface - class determinies how states are mapped. |
| // Useful for implementing operations that do not change the number of states. |
| // |
| // class StateMapper { |
| // public: |
| // typedef A FromArc; |
| // typedef B ToArc; |
| // |
| // // Typical constructor |
| // StateMapper(const Fst<A> &fst); |
| // // Required copy constructor that allows updating Fst argument; |
| // // pass only if relevant and changed. |
| // StateMapper(const StateMapper &mapper, const Fst<A> *fst = 0); |
| // |
| // // Specifies initial state of result |
| // B::StateId Start() const; |
| // // Specifies state's final weight in result |
| // B::Weight Final(B::StateId s) const; |
| // |
| // // These methods iterate through a state's arcs in result |
| // // Specifies state to iterate over |
| // void SetState(B::StateId s); |
| // // End of arcs? |
| // bool Done() const; |
| // // Current arc |
| |
| // const B &Value() const; |
| // // Advance to next arc (when !Done) |
| // void Next(); |
| // |
| // // Specifies input symbol table action the mapper requires (see above). |
| // MapSymbolsAction InputSymbolsAction() const; |
| // // Specifies output symbol table action the mapper requires (see above). |
| // MapSymbolsAction OutputSymbolsAction() const; |
| // // This specifies the known properties of an Fst mapped by this |
| // // mapper. It takes as argument the input Fst's known properties. |
| // uint64 Properties(uint64 props) const; |
| // }; |
| // |
| // We include a various state map versions below. One dimension of |
| // variation is whether the mapping mutates its input, writes to a |
| // new result Fst, or is an on-the-fly Fst. Another dimension is how |
| // we pass the mapper. We allow passing the mapper by pointer |
| // for cases that we need to change the state of the user's mapper. |
| // We also include map versions that pass the mapper |
| // by value or const reference when this suffices. |
| |
| // Maps an arc type A using a mapper function object C, passed |
| // by pointer. This version modifies its Fst input. |
| template<class A, class C> |
| void StateMap(MutableFst<A> *fst, C* mapper) { |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) |
| fst->SetInputSymbols(0); |
| |
| if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) |
| fst->SetOutputSymbols(0); |
| |
| if (fst->Start() == kNoStateId) |
| return; |
| |
| uint64 props = fst->Properties(kFstProperties, false); |
| |
| fst->SetStart(mapper->Start()); |
| |
| for (StateId s = 0; s < fst->NumStates(); ++s) { |
| mapper->SetState(s); |
| fst->DeleteArcs(s); |
| for (; !mapper->Done(); mapper->Next()) |
| fst->AddArc(s, mapper->Value()); |
| fst->SetFinal(s, mapper->Final(s)); |
| } |
| |
| fst->SetProperties(mapper->Properties(props), kFstProperties); |
| } |
| |
| // Maps an arc type A using a mapper function object C, passed |
| // by value. This version modifies its Fst input. |
| template<class A, class C> |
| void StateMap(MutableFst<A> *fst, C mapper) { |
| StateMap(fst, &mapper); |
| } |
| |
| |
| // Maps an arc type A to an arc type B using mapper function |
| // object C, passed by pointer. This version writes the mapped |
| // input Fst to an output MutableFst. |
| template<class A, class B, class C> |
| void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C* mapper) { |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| ofst->DeleteStates(); |
| |
| if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) |
| ofst->SetInputSymbols(ifst.InputSymbols()); |
| else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) |
| ofst->SetInputSymbols(0); |
| |
| if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) |
| ofst->SetOutputSymbols(ifst.OutputSymbols()); |
| else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) |
| ofst->SetOutputSymbols(0); |
| |
| uint64 iprops = ifst.Properties(kCopyProperties, false); |
| |
| if (ifst.Start() == kNoStateId) { |
| if (iprops & kError) ofst->SetProperties(kError, kError); |
| return; |
| } |
| |
| // Add all states. |
| if (ifst.Properties(kExpanded, false)) |
| ofst->ReserveStates(CountStates(ifst)); |
| for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) |
| ofst->AddState(); |
| |
| ofst->SetStart(mapper->Start()); |
| |
| for (StateIterator< Fst<A> > siter(ifst); !siter.Done(); siter.Next()) { |
| StateId s = siter.Value(); |
| mapper->SetState(s); |
| for (; !mapper->Done(); mapper->Next()) |
| ofst->AddArc(s, mapper->Value()); |
| ofst->SetFinal(s, mapper->Final(s)); |
| } |
| |
| uint64 oprops = ofst->Properties(kFstProperties, false); |
| ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); |
| } |
| |
| // Maps an arc type A to an arc type B using mapper function |
| // object C, passed by value. This version writes the mapped input |
| // Fst to an output MutableFst. |
| template<class A, class B, class C> |
| void StateMap(const Fst<A> &ifst, MutableFst<B> *ofst, C mapper) { |
| StateMap(ifst, ofst, &mapper); |
| } |
| |
| typedef CacheOptions StateMapFstOptions; |
| |
| template <class A, class B, class C> class StateMapFst; |
| |
| // Implementation of delayed StateMapFst. |
| template <class A, class B, class C> |
| class StateMapFstImpl : public CacheImpl<B> { |
| public: |
| using FstImpl<B>::SetType; |
| using FstImpl<B>::SetProperties; |
| using FstImpl<B>::SetInputSymbols; |
| using FstImpl<B>::SetOutputSymbols; |
| |
| using VectorFstBaseImpl<typename CacheImpl<B>::State>::NumStates; |
| |
| using CacheImpl<B>::PushArc; |
| using CacheImpl<B>::HasArcs; |
| using CacheImpl<B>::HasFinal; |
| using CacheImpl<B>::HasStart; |
| using CacheImpl<B>::SetArcs; |
| using CacheImpl<B>::SetFinal; |
| using CacheImpl<B>::SetStart; |
| |
| friend class StateIterator< StateMapFst<A, B, C> >; |
| |
| typedef B Arc; |
| typedef typename B::Weight Weight; |
| typedef typename B::StateId StateId; |
| |
| StateMapFstImpl(const Fst<A> &fst, const C &mapper, |
| const StateMapFstOptions& opts) |
| : CacheImpl<B>(opts), |
| fst_(fst.Copy()), |
| mapper_(new C(mapper, fst_)), |
| own_mapper_(true) { |
| Init(); |
| } |
| |
| StateMapFstImpl(const Fst<A> &fst, C *mapper, |
| const StateMapFstOptions& opts) |
| : CacheImpl<B>(opts), |
| fst_(fst.Copy()), |
| mapper_(mapper), |
| own_mapper_(false) { |
| Init(); |
| } |
| |
| StateMapFstImpl(const StateMapFstImpl<A, B, C> &impl) |
| : CacheImpl<B>(impl), |
| fst_(impl.fst_->Copy(true)), |
| mapper_(new C(*impl.mapper_, fst_)), |
| own_mapper_(true) { |
| Init(); |
| } |
| |
| ~StateMapFstImpl() { |
| delete fst_; |
| if (own_mapper_) delete mapper_; |
| } |
| |
| StateId Start() { |
| if (!HasStart()) |
| SetStart(mapper_->Start()); |
| return CacheImpl<B>::Start(); |
| } |
| |
| Weight Final(StateId s) { |
| if (!HasFinal(s)) |
| SetFinal(s, mapper_->Final(s)); |
| return CacheImpl<B>::Final(s); |
| } |
| |
| size_t NumArcs(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<B>::NumArcs(s); |
| } |
| |
| size_t NumInputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<B>::NumInputEpsilons(s); |
| } |
| |
| size_t NumOutputEpsilons(StateId s) { |
| if (!HasArcs(s)) |
| Expand(s); |
| return CacheImpl<B>::NumOutputEpsilons(s); |
| } |
| |
| void InitStateIterator(StateIteratorData<A> *data) const { |
| fst_->InitStateIterator(data); |
| } |
| |
| void InitArcIterator(StateId s, ArcIteratorData<B> *data) { |
| if (!HasArcs(s)) |
| Expand(s); |
| CacheImpl<B>::InitArcIterator(s, data); |
| } |
| |
| uint64 Properties() const { return Properties(kFstProperties); } |
| |
| // Set error if found; return FST impl properties. |
| uint64 Properties(uint64 mask) const { |
| if ((mask & kError) && (fst_->Properties(kError, false) || |
| (mapper_->Properties(0) & kError))) |
| SetProperties(kError, kError); |
| return FstImpl<Arc>::Properties(mask); |
| } |
| |
| void Expand(StateId s) { |
| // Add exiting arcs. |
| for (mapper_->SetState(s); !mapper_->Done(); mapper_->Next()) |
| PushArc(s, mapper_->Value()); |
| SetArcs(s); |
| } |
| |
| private: |
| void Init() { |
| SetType("statemap"); |
| |
| if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) |
| SetInputSymbols(fst_->InputSymbols()); |
| else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) |
| SetInputSymbols(0); |
| |
| if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) |
| SetOutputSymbols(fst_->OutputSymbols()); |
| else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) |
| SetOutputSymbols(0); |
| |
| uint64 props = fst_->Properties(kCopyProperties, false); |
| SetProperties(mapper_->Properties(props)); |
| } |
| |
| const Fst<A> *fst_; |
| C* mapper_; |
| bool own_mapper_; |
| |
| void operator=(const StateMapFstImpl<A, B, C> &); // disallow |
| }; |
| |
| |
| // Maps an arc type A to an arc type B using Mapper function object |
| // C. This version is a delayed Fst. |
| template <class A, class B, class C> |
| class StateMapFst : public ImplToFst< StateMapFstImpl<A, B, C> > { |
| public: |
| friend class ArcIterator< StateMapFst<A, B, C> >; |
| |
| typedef B Arc; |
| typedef typename B::Weight Weight; |
| typedef typename B::StateId StateId; |
| typedef CacheState<B> State; |
| typedef StateMapFstImpl<A, B, C> Impl; |
| |
| StateMapFst(const Fst<A> &fst, const C &mapper, |
| const StateMapFstOptions& opts) |
| : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} |
| |
| StateMapFst(const Fst<A> &fst, C* mapper, const StateMapFstOptions& opts) |
| : ImplToFst<Impl>(new Impl(fst, mapper, opts)) {} |
| |
| StateMapFst(const Fst<A> &fst, const C &mapper) |
| : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} |
| |
| StateMapFst(const Fst<A> &fst, C* mapper) |
| : ImplToFst<Impl>(new Impl(fst, mapper, StateMapFstOptions())) {} |
| |
| // See Fst<>::Copy() for doc. |
| StateMapFst(const StateMapFst<A, B, C> &fst, bool safe = false) |
| : ImplToFst<Impl>(fst, safe) {} |
| |
| // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc. |
| virtual StateMapFst<A, B, C> *Copy(bool safe = false) const { |
| return new StateMapFst<A, B, C>(*this, safe); |
| } |
| |
| virtual void InitStateIterator(StateIteratorData<A> *data) const { |
| GetImpl()->InitStateIterator(data); |
| } |
| |
| virtual void InitArcIterator(StateId s, ArcIteratorData<B> *data) const { |
| GetImpl()->InitArcIterator(s, data); |
| } |
| |
| private: |
| // Makes visible to friends. |
| Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); } |
| |
| void operator=(const StateMapFst<A, B, C> &fst); // disallow |
| }; |
| |
| |
| // Specialization for StateMapFst. |
| template <class A, class B, class C> |
| class ArcIterator< StateMapFst<A, B, C> > |
| : public CacheArcIterator< StateMapFst<A, B, C> > { |
| public: |
| typedef typename A::StateId StateId; |
| |
| ArcIterator(const StateMapFst<A, B, C> &fst, StateId s) |
| : CacheArcIterator< StateMapFst<A, B, C> >(fst.GetImpl(), s) { |
| if (!fst.GetImpl()->HasArcs(s)) |
| fst.GetImpl()->Expand(s); |
| } |
| |
| private: |
| DISALLOW_COPY_AND_ASSIGN(ArcIterator); |
| }; |
| |
| // |
| // Utility Mappers |
| // |
| |
| // Mapper that returns its input. |
| template <class A> |
| class IdentityStateMapper { |
| public: |
| typedef A FromArc; |
| typedef A ToArc; |
| |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| explicit IdentityStateMapper(const Fst<A> &fst) : fst_(fst), aiter_(0) {} |
| |
| // Allows updating Fst argument; pass only if changed. |
| IdentityStateMapper(const IdentityStateMapper<A> &mapper, |
| const Fst<A> *fst = 0) |
| : fst_(fst ? *fst : mapper.fst_), aiter_(0) {} |
| |
| ~IdentityStateMapper() { delete aiter_; } |
| |
| StateId Start() const { return fst_.Start(); } |
| |
| Weight Final(StateId s) const { return fst_.Final(s); } |
| |
| void SetState(StateId s) { |
| if (aiter_) delete aiter_; |
| aiter_ = new ArcIterator< Fst<A> >(fst_, s); |
| } |
| |
| bool Done() const { return aiter_->Done(); } |
| const A &Value() const { return aiter_->Value(); } |
| void Next() { aiter_->Next(); } |
| |
| MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } |
| MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;} |
| |
| uint64 Properties(uint64 props) const { return props; } |
| |
| private: |
| const Fst<A> &fst_; |
| ArcIterator< Fst<A> > *aiter_; |
| }; |
| |
| template <class A> |
| class ArcSumMapper { |
| public: |
| typedef A FromArc; |
| typedef A ToArc; |
| |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| explicit ArcSumMapper(const Fst<A> &fst) : fst_(fst), i_(0) {} |
| |
| // Allows updating Fst argument; pass only if changed. |
| ArcSumMapper(const ArcSumMapper<A> &mapper, |
| const Fst<A> *fst = 0) |
| : fst_(fst ? *fst : mapper.fst_), i_(0) {} |
| |
| StateId Start() const { return fst_.Start(); } |
| Weight Final(StateId s) const { return fst_.Final(s); } |
| |
| void SetState(StateId s) { |
| i_ = 0; |
| arcs_.clear(); |
| arcs_.reserve(fst_.NumArcs(s)); |
| for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) |
| arcs_.push_back(aiter.Value()); |
| |
| // First sorts the exiting arcs by input label, output label |
| // and destination state and then sums weights of arcs with |
| // the same input label, output label, and destination state. |
| sort(arcs_.begin(), arcs_.end(), comp_); |
| size_t narcs = 0; |
| for (size_t i = 0; i < arcs_.size(); ++i) { |
| if (narcs > 0 && equal_(arcs_[i], arcs_[narcs - 1])) { |
| arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, |
| arcs_[i].weight); |
| } else { |
| arcs_[narcs++] = arcs_[i]; |
| } |
| } |
| arcs_.resize(narcs); |
| } |
| |
| bool Done() const { return i_ >= arcs_.size(); } |
| const A &Value() const { return arcs_[i_]; } |
| void Next() { ++i_; } |
| |
| MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } |
| MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } |
| |
| uint64 Properties(uint64 props) const { |
| return props & kArcSortProperties & |
| kDeleteArcsProperties & kWeightInvariantProperties; |
| } |
| |
| private: |
| struct Compare { |
| bool operator()(const A& x, const A& y) { |
| if (x.ilabel < y.ilabel) return true; |
| if (x.ilabel > y.ilabel) return false; |
| if (x.olabel < y.olabel) return true; |
| if (x.olabel > y.olabel) return false; |
| if (x.nextstate < y.nextstate) return true; |
| if (x.nextstate > y.nextstate) return false; |
| return false; |
| } |
| }; |
| |
| struct Equal { |
| bool operator()(const A& x, const A& y) { |
| return (x.ilabel == y.ilabel && |
| x.olabel == y.olabel && |
| x.nextstate == y.nextstate); |
| } |
| }; |
| |
| const Fst<A> &fst_; |
| Compare comp_; |
| Equal equal_; |
| vector<A> arcs_; |
| ssize_t i_; // current arc position |
| |
| void operator=(const ArcSumMapper<A> &); // disallow |
| }; |
| |
| template <class A> |
| class ArcUniqueMapper { |
| public: |
| typedef A FromArc; |
| typedef A ToArc; |
| |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| explicit ArcUniqueMapper(const Fst<A> &fst) : fst_(fst), i_(0) {} |
| |
| // Allows updating Fst argument; pass only if changed. |
| ArcUniqueMapper(const ArcSumMapper<A> &mapper, |
| const Fst<A> *fst = 0) |
| : fst_(fst ? *fst : mapper.fst_), i_(0) {} |
| |
| StateId Start() const { return fst_.Start(); } |
| Weight Final(StateId s) const { return fst_.Final(s); } |
| |
| void SetState(StateId s) { |
| i_ = 0; |
| arcs_.clear(); |
| arcs_.reserve(fst_.NumArcs(s)); |
| for (ArcIterator<Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) |
| arcs_.push_back(aiter.Value()); |
| |
| // First sorts the exiting arcs by input label, output label |
| // and destination state and then uniques identical arcs |
| sort(arcs_.begin(), arcs_.end(), comp_); |
| typename vector<A>::iterator unique_end = |
| unique(arcs_.begin(), arcs_.end(), equal_); |
| arcs_.resize(unique_end - arcs_.begin()); |
| } |
| |
| bool Done() const { return i_ >= arcs_.size(); } |
| const A &Value() const { return arcs_[i_]; } |
| void Next() { ++i_; } |
| |
| MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } |
| MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } |
| |
| uint64 Properties(uint64 props) const { |
| return props & kArcSortProperties & kDeleteArcsProperties; |
| } |
| |
| private: |
| struct Compare { |
| bool operator()(const A& x, const A& y) { |
| if (x.ilabel < y.ilabel) return true; |
| if (x.ilabel > y.ilabel) return false; |
| if (x.olabel < y.olabel) return true; |
| if (x.olabel > y.olabel) return false; |
| if (x.nextstate < y.nextstate) return true; |
| if (x.nextstate > y.nextstate) return false; |
| return false; |
| } |
| }; |
| |
| struct Equal { |
| bool operator()(const A& x, const A& y) { |
| return (x.ilabel == y.ilabel && |
| x.olabel == y.olabel && |
| x.nextstate == y.nextstate && |
| x.weight == y.weight); |
| } |
| }; |
| |
| const Fst<A> &fst_; |
| Compare comp_; |
| Equal equal_; |
| vector<A> arcs_; |
| ssize_t i_; // current arc position |
| |
| void operator=(const ArcUniqueMapper<A> &); // disallow |
| }; |
| |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_STATE_MAP_H__ |