| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: sorenj@google.com (Jeffrey Sorensen) |
| // |
| #ifndef FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ |
| #define FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ |
| |
| #include <stddef.h> |
| #include <string.h> |
| #include <algorithm> |
| #include <string> |
| #include <vector> |
| using std::vector; |
| |
| #include <fst/compat.h> |
| #include <fst/fstlib.h> |
| #include <fst/mapped-file.h> |
| #include <fst/extensions/ngram/bitmap-index.h> |
| |
| // NGramFst implements an n-gram language model based upon the LOUDS data |
| // structure. Please refer to "Unary Data Structures for Language Models" |
| // http://research.google.com/pubs/archive/37218.pdf |
| |
| namespace fst { |
| template <class A> class NGramFst; |
| template <class A> class NGramFstMatcher; |
| |
| // Instance data containing mutable state for bookkeeping repeated access to |
| // the same state. |
| template <class A> |
| struct NGramFstInst { |
| typedef typename A::Label Label; |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| StateId state_; |
| size_t num_futures_; |
| size_t offset_; |
| size_t node_; |
| StateId node_state_; |
| vector<Label> context_; |
| StateId context_state_; |
| NGramFstInst() |
| : state_(kNoStateId), node_state_(kNoStateId), |
| context_state_(kNoStateId) { } |
| }; |
| |
| // Implementation class for LOUDS based NgramFst interface |
| template <class A> |
| class NGramFstImpl : public FstImpl<A> { |
| using FstImpl<A>::SetInputSymbols; |
| using FstImpl<A>::SetOutputSymbols; |
| using FstImpl<A>::SetType; |
| using FstImpl<A>::WriteHeader; |
| |
| friend class ArcIterator<NGramFst<A> >; |
| friend class NGramFstMatcher<A>; |
| |
| public: |
| using FstImpl<A>::InputSymbols; |
| using FstImpl<A>::SetProperties; |
| using FstImpl<A>::Properties; |
| |
| typedef A Arc; |
| typedef typename A::Label Label; |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| NGramFstImpl() : data_region_(0), data_(0), owned_(false) { |
| SetType("ngram"); |
| SetInputSymbols(NULL); |
| SetOutputSymbols(NULL); |
| SetProperties(kStaticProperties); |
| } |
| |
| NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out); |
| |
| ~NGramFstImpl() { |
| if (owned_) { |
| delete [] data_; |
| } |
| delete data_region_; |
| } |
| |
| static NGramFstImpl<A>* Read(istream &strm, // NOLINT |
| const FstReadOptions &opts) { |
| NGramFstImpl<A>* impl = new NGramFstImpl(); |
| FstHeader hdr; |
| if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0; |
| uint64 num_states, num_futures, num_final; |
| const size_t offset = sizeof(num_states) + sizeof(num_futures) + |
| sizeof(num_final); |
| // Peek at num_states and num_futures to see how much more needs to be read. |
| strm.read(reinterpret_cast<char *>(&num_states), sizeof(num_states)); |
| strm.read(reinterpret_cast<char *>(&num_futures), sizeof(num_futures)); |
| strm.read(reinterpret_cast<char *>(&num_final), sizeof(num_final)); |
| size_t size = Storage(num_states, num_futures, num_final); |
| MappedFile *data_region = MappedFile::Allocate(size); |
| char *data = reinterpret_cast<char *>(data_region->mutable_data()); |
| // Copy num_states, num_futures and num_final back into data. |
| memcpy(data, reinterpret_cast<char *>(&num_states), sizeof(num_states)); |
| memcpy(data + sizeof(num_states), reinterpret_cast<char *>(&num_futures), |
| sizeof(num_futures)); |
| memcpy(data + sizeof(num_states) + sizeof(num_futures), |
| reinterpret_cast<char *>(&num_final), sizeof(num_final)); |
| strm.read(data + offset, size - offset); |
| if (!strm) { |
| delete impl; |
| return NULL; |
| } |
| impl->Init(data, false, data_region); |
| return impl; |
| } |
| |
| bool Write(ostream &strm, // NOLINT |
| const FstWriteOptions &opts) const { |
| FstHeader hdr; |
| hdr.SetStart(Start()); |
| hdr.SetNumStates(num_states_); |
| WriteHeader(strm, opts, kFileVersion, &hdr); |
| strm.write(data_, StorageSize()); |
| return strm; |
| } |
| |
| StateId Start() const { |
| return 1; |
| } |
| |
| Weight Final(StateId state) const { |
| if (final_index_.Get(state)) { |
| return final_probs_[final_index_.Rank1(state)]; |
| } else { |
| return Weight::Zero(); |
| } |
| } |
| |
| size_t NumArcs(StateId state, NGramFstInst<A> *inst = NULL) const { |
| if (inst == NULL) { |
| const size_t next_zero = future_index_.Select0(state + 1); |
| const size_t this_zero = future_index_.Select0(state); |
| return next_zero - this_zero - 1; |
| } |
| SetInstFuture(state, inst); |
| return inst->num_futures_ + ((state == 0) ? 0 : 1); |
| } |
| |
| size_t NumInputEpsilons(StateId state) const { |
| // State 0 has no parent, thus no backoff. |
| if (state == 0) return 0; |
| return 1; |
| } |
| |
| size_t NumOutputEpsilons(StateId state) const { |
| return NumInputEpsilons(state); |
| } |
| |
| StateId NumStates() const { |
| return num_states_; |
| } |
| |
| void InitStateIterator(StateIteratorData<A>* data) const { |
| data->base = 0; |
| data->nstates = num_states_; |
| } |
| |
| static size_t Storage(uint64 num_states, uint64 num_futures, |
| uint64 num_final) { |
| uint64 b64; |
| Weight weight; |
| Label label; |
| size_t offset = sizeof(num_states) + sizeof(num_futures) + |
| sizeof(num_final); |
| offset += sizeof(b64) * ( |
| BitmapIndex::StorageSize(num_states * 2 + 1) + |
| BitmapIndex::StorageSize(num_futures + num_states + 1) + |
| BitmapIndex::StorageSize(num_states)); |
| offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label); |
| // Pad for alignemnt, see |
| // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding |
| offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); |
| offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) + |
| (num_futures + 1) * sizeof(weight); |
| return offset; |
| } |
| |
| void SetInstFuture(StateId state, NGramFstInst<A> *inst) const { |
| if (inst->state_ != state) { |
| inst->state_ = state; |
| const size_t next_zero = future_index_.Select0(state + 1); |
| const size_t this_zero = future_index_.Select0(state); |
| inst->num_futures_ = next_zero - this_zero - 1; |
| inst->offset_ = future_index_.Rank1(future_index_.Select0(state) + 1); |
| } |
| } |
| |
| void SetInstNode(NGramFstInst<A> *inst) const { |
| if (inst->node_state_ != inst->state_) { |
| inst->node_state_ = inst->state_; |
| inst->node_ = context_index_.Select1(inst->state_); |
| } |
| } |
| |
| void SetInstContext(NGramFstInst<A> *inst) const { |
| SetInstNode(inst); |
| if (inst->context_state_ != inst->state_) { |
| inst->context_state_ = inst->state_; |
| inst->context_.clear(); |
| size_t node = inst->node_; |
| while (node != 0) { |
| inst->context_.push_back(context_words_[context_index_.Rank1(node)]); |
| node = context_index_.Select1(context_index_.Rank0(node) - 1); |
| } |
| } |
| } |
| |
| // Access to the underlying representation |
| const char* GetData(size_t* data_size) const { |
| *data_size = StorageSize(); |
| return data_; |
| } |
| |
| void Init(const char* data, bool owned, MappedFile *file = 0); |
| |
| const vector<Label> &GetContext(StateId s, NGramFstInst<A> *inst) const { |
| SetInstFuture(s, inst); |
| SetInstContext(inst); |
| return inst->context_; |
| } |
| |
| size_t StorageSize() const { |
| return Storage(num_states_, num_futures_, num_final_); |
| } |
| |
| void GetStates(const vector<Label>& context, vector<StateId> *states) const; |
| |
| private: |
| StateId Transition(const vector<Label> &context, Label future) const; |
| |
| // Properties always true for this Fst class. |
| static const uint64 kStaticProperties = kAcceptor | kIDeterministic | |
| kODeterministic | kEpsilons | kIEpsilons | kOEpsilons | kILabelSorted | |
| kOLabelSorted | kWeighted | kCyclic | kInitialAcyclic | kNotTopSorted | |
| kAccessible | kCoAccessible | kNotString | kExpanded; |
| // Current file format version. |
| static const int kFileVersion = 4; |
| // Minimum file format version supported. |
| static const int kMinFileVersion = 4; |
| |
| MappedFile *data_region_; |
| const char* data_; |
| bool owned_; // True if we own data_ |
| uint64 num_states_, num_futures_, num_final_; |
| size_t root_num_children_; |
| const Label *root_children_; |
| size_t root_first_child_; |
| // borrowed references |
| const uint64 *context_, *future_, *final_; |
| const Label *context_words_, *future_words_; |
| const Weight *backoff_, *final_probs_, *future_probs_; |
| BitmapIndex context_index_; |
| BitmapIndex future_index_; |
| BitmapIndex final_index_; |
| |
| void operator=(const NGramFstImpl<A> &); // Disallow |
| }; |
| |
| template<typename A> |
| NGramFstImpl<A>::NGramFstImpl(const Fst<A> &fst, vector<StateId>* order_out) |
| : data_region_(0), data_(0), owned_(false) { |
| typedef A Arc; |
| typedef typename Arc::Label Label; |
| typedef typename Arc::Weight Weight; |
| typedef typename Arc::StateId StateId; |
| SetType("ngram"); |
| SetInputSymbols(fst.InputSymbols()); |
| SetOutputSymbols(fst.OutputSymbols()); |
| SetProperties(kStaticProperties); |
| |
| // Check basic requirements for an OpenGRM language model Fst. |
| int64 props = kAcceptor | kIDeterministic | kIEpsilons | kILabelSorted; |
| if (fst.Properties(props, true) != props) { |
| FSTERROR() << "NGramFst only accepts OpenGRM langauge models as input"; |
| SetProperties(kError, kError); |
| return; |
| } |
| |
| int64 num_states = CountStates(fst); |
| Label* context = new Label[num_states]; |
| |
| // Find the unigram state by starting from the start state, following |
| // epsilons. |
| StateId unigram = fst.Start(); |
| while (1) { |
| if (unigram == kNoStateId) { |
| FSTERROR() << "Could not identify unigram state."; |
| SetProperties(kError, kError); |
| return; |
| } |
| ArcIterator<Fst<A> > aiter(fst, unigram); |
| if (aiter.Done()) { |
| LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; |
| break; |
| } |
| if (aiter.Value().ilabel != 0) break; |
| unigram = aiter.Value().nextstate; |
| } |
| |
| // Each state's context is determined by the subtree it is under from the |
| // unigram state. |
| queue<pair<StateId, Label> > label_queue; |
| vector<bool> visited(num_states); |
| // Force an epsilon link to the start state. |
| label_queue.push(make_pair(fst.Start(), 0)); |
| for (ArcIterator<Fst<A> > aiter(fst, unigram); |
| !aiter.Done(); aiter.Next()) { |
| label_queue.push(make_pair(aiter.Value().nextstate, aiter.Value().ilabel)); |
| } |
| // investigate states in breadth first fashion to assign context words. |
| while (!label_queue.empty()) { |
| pair<StateId, Label> &now = label_queue.front(); |
| if (!visited[now.first]) { |
| context[now.first] = now.second; |
| visited[now.first] = true; |
| for (ArcIterator<Fst<A> > aiter(fst, now.first); |
| !aiter.Done(); aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| if (arc.ilabel != 0) { |
| label_queue.push(make_pair(arc.nextstate, now.second)); |
| } |
| } |
| } |
| label_queue.pop(); |
| } |
| visited.clear(); |
| |
| // The arc from the start state should be assigned an epsilon to put it |
| // in front of the all other labels (which makes Start state 1 after |
| // unigram which is state 0). |
| context[fst.Start()] = 0; |
| |
| // Build the tree of contexts fst by reversing the epsilon arcs from fst. |
| VectorFst<Arc> context_fst; |
| uint64 num_final = 0; |
| for (int i = 0; i < num_states; ++i) { |
| if (fst.Final(i) != Weight::Zero()) { |
| ++num_final; |
| } |
| context_fst.SetFinal(context_fst.AddState(), fst.Final(i)); |
| } |
| context_fst.SetStart(unigram); |
| context_fst.SetInputSymbols(fst.InputSymbols()); |
| context_fst.SetOutputSymbols(fst.OutputSymbols()); |
| int64 num_context_arcs = 0; |
| int64 num_futures = 0; |
| for (StateIterator<Fst<A> > siter(fst); !siter.Done(); siter.Next()) { |
| const StateId &state = siter.Value(); |
| num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state); |
| ArcIterator<Fst<A> > aiter(fst, state); |
| if (!aiter.Done()) { |
| const Arc &arc = aiter.Value(); |
| // this arc goes from state to arc.nextstate, so create an arc from |
| // arc.nextstate to state to reverse it. |
| if (arc.ilabel == 0) { |
| context_fst.AddArc(arc.nextstate, Arc(context[state], context[state], |
| arc.weight, state)); |
| num_context_arcs++; |
| } |
| } |
| } |
| if (num_context_arcs != context_fst.NumStates() - 1) { |
| FSTERROR() << "Number of contexts arcs != number of states - 1"; |
| SetProperties(kError, kError); |
| return; |
| } |
| if (context_fst.NumStates() != num_states) { |
| FSTERROR() << "Number of contexts != number of states"; |
| SetProperties(kError, kError); |
| return; |
| } |
| int64 context_props = context_fst.Properties(kIDeterministic | |
| kILabelSorted, true); |
| if (!(context_props & kIDeterministic)) { |
| FSTERROR() << "Input fst is not structured properly"; |
| SetProperties(kError, kError); |
| return; |
| } |
| if (!(context_props & kILabelSorted)) { |
| ArcSort(&context_fst, ILabelCompare<Arc>()); |
| } |
| |
| delete [] context; |
| |
| uint64 b64; |
| Weight weight; |
| Label label = kNoLabel; |
| const size_t storage = Storage(num_states, num_futures, num_final); |
| MappedFile *data_region = MappedFile::Allocate(storage); |
| char *data = reinterpret_cast<char *>(data_region->mutable_data()); |
| memset(data, 0, storage); |
| size_t offset = 0; |
| memcpy(data + offset, reinterpret_cast<char *>(&num_states), |
| sizeof(num_states)); |
| offset += sizeof(num_states); |
| memcpy(data + offset, reinterpret_cast<char *>(&num_futures), |
| sizeof(num_futures)); |
| offset += sizeof(num_futures); |
| memcpy(data + offset, reinterpret_cast<char *>(&num_final), |
| sizeof(num_final)); |
| offset += sizeof(num_final); |
| uint64* context_bits = reinterpret_cast<uint64*>(data + offset); |
| offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64); |
| uint64* future_bits = reinterpret_cast<uint64*>(data + offset); |
| offset += |
| BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64); |
| uint64* final_bits = reinterpret_cast<uint64*>(data + offset); |
| offset += BitmapIndex::StorageSize(num_states) * sizeof(b64); |
| Label* context_words = reinterpret_cast<Label*>(data + offset); |
| offset += (num_states + 1) * sizeof(label); |
| Label* future_words = reinterpret_cast<Label*>(data + offset); |
| offset += num_futures * sizeof(label); |
| offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); |
| Weight* backoff = reinterpret_cast<Weight*>(data + offset); |
| offset += (num_states + 1) * sizeof(weight); |
| Weight* final_probs = reinterpret_cast<Weight*>(data + offset); |
| offset += num_final * sizeof(weight); |
| Weight* future_probs = reinterpret_cast<Weight*>(data + offset); |
| int64 context_arc = 0, future_arc = 0, context_bit = 0, future_bit = 0, |
| final_bit = 0; |
| |
| // pseudo-root bits |
| BitmapIndex::Set(context_bits, context_bit++); |
| ++context_bit; |
| context_words[context_arc] = label; |
| backoff[context_arc] = Weight::Zero(); |
| context_arc++; |
| |
| ++future_bit; |
| if (order_out) { |
| order_out->clear(); |
| order_out->resize(num_states); |
| } |
| |
| queue<StateId> context_q; |
| context_q.push(context_fst.Start()); |
| StateId state_number = 0; |
| while (!context_q.empty()) { |
| const StateId &state = context_q.front(); |
| if (order_out) { |
| (*order_out)[state] = state_number; |
| } |
| |
| const Weight &final = context_fst.Final(state); |
| if (final != Weight::Zero()) { |
| BitmapIndex::Set(final_bits, state_number); |
| final_probs[final_bit] = final; |
| ++final_bit; |
| } |
| |
| for (ArcIterator<VectorFst<A> > aiter(context_fst, state); |
| !aiter.Done(); aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| context_words[context_arc] = arc.ilabel; |
| backoff[context_arc] = arc.weight; |
| ++context_arc; |
| BitmapIndex::Set(context_bits, context_bit++); |
| context_q.push(arc.nextstate); |
| } |
| ++context_bit; |
| |
| for (ArcIterator<Fst<A> > aiter(fst, state); !aiter.Done(); aiter.Next()) { |
| const Arc &arc = aiter.Value(); |
| if (arc.ilabel != 0) { |
| future_words[future_arc] = arc.ilabel; |
| future_probs[future_arc] = arc.weight; |
| ++future_arc; |
| BitmapIndex::Set(future_bits, future_bit++); |
| } |
| } |
| ++future_bit; |
| ++state_number; |
| context_q.pop(); |
| } |
| |
| if ((state_number != num_states) || |
| (context_bit != num_states * 2 + 1) || |
| (context_arc != num_states) || |
| (future_arc != num_futures) || |
| (future_bit != num_futures + num_states + 1) || |
| (final_bit != num_final)) { |
| FSTERROR() << "Structure problems detected during construction"; |
| SetProperties(kError, kError); |
| return; |
| } |
| |
| Init(data, false, data_region); |
| } |
| |
| template<typename A> |
| inline void NGramFstImpl<A>::Init(const char* data, bool owned, |
| MappedFile *data_region) { |
| if (owned_) { |
| delete [] data_; |
| } |
| delete data_region_; |
| data_region_ = data_region; |
| owned_ = owned; |
| data_ = data; |
| size_t offset = 0; |
| num_states_ = *(reinterpret_cast<const uint64*>(data_ + offset)); |
| offset += sizeof(num_states_); |
| num_futures_ = *(reinterpret_cast<const uint64*>(data_ + offset)); |
| offset += sizeof(num_futures_); |
| num_final_ = *(reinterpret_cast<const uint64*>(data_ + offset)); |
| offset += sizeof(num_final_); |
| uint64 bits; |
| size_t context_bits = num_states_ * 2 + 1; |
| size_t future_bits = num_futures_ + num_states_ + 1; |
| context_ = reinterpret_cast<const uint64*>(data_ + offset); |
| offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits); |
| future_ = reinterpret_cast<const uint64*>(data_ + offset); |
| offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); |
| final_ = reinterpret_cast<const uint64*>(data_ + offset); |
| offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); |
| context_words_ = reinterpret_cast<const Label*>(data_ + offset); |
| offset += (num_states_ + 1) * sizeof(*context_words_); |
| future_words_ = reinterpret_cast<const Label*>(data_ + offset); |
| offset += num_futures_ * sizeof(*future_words_); |
| offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1); |
| backoff_ = reinterpret_cast<const Weight*>(data_ + offset); |
| offset += (num_states_ + 1) * sizeof(*backoff_); |
| final_probs_ = reinterpret_cast<const Weight*>(data_ + offset); |
| offset += num_final_ * sizeof(*final_probs_); |
| future_probs_ = reinterpret_cast<const Weight*>(data_ + offset); |
| |
| context_index_.BuildIndex(context_, context_bits); |
| future_index_.BuildIndex(future_, future_bits); |
| final_index_.BuildIndex(final_, num_states_); |
| |
| const size_t node_rank = context_index_.Rank1(0); |
| root_first_child_ = context_index_.Select0(node_rank) + 1; |
| if (context_index_.Get(root_first_child_) == false) { |
| FSTERROR() << "Missing unigrams"; |
| SetProperties(kError, kError); |
| return; |
| } |
| const size_t last_child = context_index_.Select0(node_rank + 1) - 1; |
| root_num_children_ = last_child - root_first_child_ + 1; |
| root_children_ = context_words_ + context_index_.Rank1(root_first_child_); |
| } |
| |
| template<typename A> |
| inline typename A::StateId NGramFstImpl<A>::Transition( |
| const vector<Label> &context, Label future) const { |
| const Label *children = root_children_; |
| const Label *loc = lower_bound(children, children + root_num_children_, |
| future); |
| if (loc == children + root_num_children_ || *loc != future) { |
| return context_index_.Rank1(0); |
| } |
| size_t node = root_first_child_ + loc - children; |
| size_t node_rank = context_index_.Rank1(node); |
| size_t first_child = context_index_.Select0(node_rank) + 1; |
| if (context_index_.Get(first_child) == false) { |
| return context_index_.Rank1(node); |
| } |
| size_t last_child = context_index_.Select0(node_rank + 1) - 1; |
| for (int word = context.size() - 1; word >= 0; --word) { |
| children = context_words_ + context_index_.Rank1(first_child); |
| loc = lower_bound(children, children + last_child - first_child + 1, |
| context[word]); |
| if (loc == children + last_child - first_child + 1 || |
| *loc != context[word]) { |
| break; |
| } |
| node = first_child + loc - children; |
| node_rank = context_index_.Rank1(node); |
| first_child = context_index_.Select0(node_rank) + 1; |
| if (context_index_.Get(first_child) == false) break; |
| last_child = context_index_.Select0(node_rank + 1) - 1; |
| } |
| return context_index_.Rank1(node); |
| } |
| |
| template<typename A> |
| inline void NGramFstImpl<A>::GetStates( |
| const vector<Label> &context, |
| vector<typename A::StateId>* states) const { |
| states->clear(); |
| states->push_back(0); |
| typename vector<Label>::const_reverse_iterator cit = context.rbegin(); |
| const Label *children = root_children_; |
| const Label *loc = lower_bound(children, children + root_num_children_, *cit); |
| if (loc == children + root_num_children_ || *loc != *cit) return; |
| size_t node = root_first_child_ + loc - children; |
| states->push_back(context_index_.Rank1(node)); |
| if (context.size() == 1) return; |
| size_t node_rank = context_index_.Rank1(node); |
| size_t first_child = context_index_.Select0(node_rank) + 1; |
| ++cit; |
| if (context_index_.Get(first_child) != false) { |
| size_t last_child = context_index_.Select0(node_rank + 1) - 1; |
| while (cit != context.rend()) { |
| children = context_words_ + context_index_.Rank1(first_child); |
| loc = lower_bound(children, children + last_child - first_child + 1, |
| *cit); |
| if (loc == children + last_child - first_child + 1 || *loc != *cit) { |
| break; |
| } |
| ++cit; |
| node = first_child + loc - children; |
| states->push_back(context_index_.Rank1(node)); |
| node_rank = context_index_.Rank1(node); |
| first_child = context_index_.Select0(node_rank) + 1; |
| if (context_index_.Get(first_child) == false) break; |
| last_child = context_index_.Select0(node_rank + 1) - 1; |
| } |
| } |
| } |
| |
| /*****************************************************************************/ |
| template<class A> |
| class NGramFst : public ImplToExpandedFst<NGramFstImpl<A> > { |
| friend class ArcIterator<NGramFst<A> >; |
| friend class NGramFstMatcher<A>; |
| |
| public: |
| typedef A Arc; |
| typedef typename A::StateId StateId; |
| typedef typename A::Label Label; |
| typedef typename A::Weight Weight; |
| typedef NGramFstImpl<A> Impl; |
| |
| explicit NGramFst(const Fst<A> &dst) |
| : ImplToExpandedFst<Impl>(new Impl(dst, NULL)) {} |
| |
| NGramFst(const Fst<A> &fst, vector<StateId>* order_out) |
| : ImplToExpandedFst<Impl>(new Impl(fst, order_out)) {} |
| |
| // Because the NGramFstImpl is a const stateless data structure, there |
| // is never a need to do anything beside copy the reference. |
| NGramFst(const NGramFst<A> &fst, bool safe = false) |
| : ImplToExpandedFst<Impl>(fst, false) {} |
| |
| NGramFst() : ImplToExpandedFst<Impl>(new Impl()) {} |
| |
| // Non-standard constructor to initialize NGramFst directly from data. |
| NGramFst(const char* data, bool owned) : ImplToExpandedFst<Impl>(new Impl()) { |
| GetImpl()->Init(data, owned, NULL); |
| } |
| |
| // Get method that gets the data associated with Init(). |
| const char* GetData(size_t* data_size) const { |
| return GetImpl()->GetData(data_size); |
| } |
| |
| const vector<Label> GetContext(StateId s) const { |
| return GetImpl()->GetContext(s, &inst_); |
| } |
| |
| // Consumes as much as possible of context from right to left, returns the |
| // the states corresponding to the increasingly conditioned input sequence. |
| void GetStates(const vector<Label>& context, vector<StateId> *state) const { |
| return GetImpl()->GetStates(context, state); |
| } |
| |
| virtual size_t NumArcs(StateId s) const { |
| return GetImpl()->NumArcs(s, &inst_); |
| } |
| |
| virtual NGramFst<A>* Copy(bool safe = false) const { |
| return new NGramFst(*this, safe); |
| } |
| |
| static NGramFst<A>* Read(istream &strm, const FstReadOptions &opts) { |
| Impl* impl = Impl::Read(strm, opts); |
| return impl ? new NGramFst<A>(impl) : 0; |
| } |
| |
| static NGramFst<A>* Read(const string &filename) { |
| if (!filename.empty()) { |
| ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); |
| if (!strm) { |
| LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename; |
| return 0; |
| } |
| return Read(strm, FstReadOptions(filename)); |
| } else { |
| return Read(cin, FstReadOptions("standard input")); |
| } |
| } |
| |
| virtual bool Write(ostream &strm, const FstWriteOptions &opts) const { |
| return GetImpl()->Write(strm, opts); |
| } |
| |
| virtual bool Write(const string &filename) const { |
| return Fst<A>::WriteFile(filename); |
| } |
| |
| virtual inline void InitStateIterator(StateIteratorData<A>* data) const { |
| GetImpl()->InitStateIterator(data); |
| } |
| |
| virtual inline void InitArcIterator( |
| StateId s, ArcIteratorData<A>* data) const; |
| |
| virtual MatcherBase<A>* InitMatcher(MatchType match_type) const { |
| return new NGramFstMatcher<A>(*this, match_type); |
| } |
| |
| size_t StorageSize() const { |
| return GetImpl()->StorageSize(); |
| } |
| |
| private: |
| explicit NGramFst(Impl* impl) : ImplToExpandedFst<Impl>(impl) {} |
| |
| Impl* GetImpl() const { |
| return |
| ImplToExpandedFst<Impl, ExpandedFst<A> >::GetImpl(); |
| } |
| |
| void SetImpl(Impl* impl, bool own_impl = true) { |
| ImplToExpandedFst<Impl, Fst<A> >::SetImpl(impl, own_impl); |
| } |
| |
| mutable NGramFstInst<A> inst_; |
| }; |
| |
| template <class A> inline void |
| NGramFst<A>::InitArcIterator(StateId s, ArcIteratorData<A>* data) const { |
| GetImpl()->SetInstFuture(s, &inst_); |
| GetImpl()->SetInstNode(&inst_); |
| data->base = new ArcIterator<NGramFst<A> >(*this, s); |
| } |
| |
| /*****************************************************************************/ |
| template <class A> |
| class NGramFstMatcher : public MatcherBase<A> { |
| public: |
| typedef A Arc; |
| typedef typename A::Label Label; |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| NGramFstMatcher(const NGramFst<A> &fst, MatchType match_type) |
| : fst_(fst), inst_(fst.inst_), match_type_(match_type), |
| current_loop_(false), |
| loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { |
| if (match_type_ == MATCH_OUTPUT) { |
| swap(loop_.ilabel, loop_.olabel); |
| } |
| } |
| |
| NGramFstMatcher(const NGramFstMatcher<A> &matcher, bool safe = false) |
| : fst_(matcher.fst_), inst_(matcher.inst_), |
| match_type_(matcher.match_type_), current_loop_(false), |
| loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { |
| if (match_type_ == MATCH_OUTPUT) { |
| swap(loop_.ilabel, loop_.olabel); |
| } |
| } |
| |
| virtual NGramFstMatcher<A>* Copy(bool safe = false) const { |
| return new NGramFstMatcher<A>(*this, safe); |
| } |
| |
| virtual MatchType Type(bool test) const { |
| return match_type_; |
| } |
| |
| virtual const Fst<A> &GetFst() const { |
| return fst_; |
| } |
| |
| virtual uint64 Properties(uint64 props) const { |
| return props; |
| } |
| |
| private: |
| virtual void SetState_(StateId s) { |
| fst_.GetImpl()->SetInstFuture(s, &inst_); |
| current_loop_ = false; |
| } |
| |
| virtual bool Find_(Label label) { |
| const Label nolabel = kNoLabel; |
| done_ = true; |
| if (label == 0 || label == nolabel) { |
| if (label == 0) { |
| current_loop_ = true; |
| loop_.nextstate = inst_.state_; |
| } |
| // The unigram state has no epsilon arc. |
| if (inst_.state_ != 0) { |
| arc_.ilabel = arc_.olabel = 0; |
| fst_.GetImpl()->SetInstNode(&inst_); |
| arc_.nextstate = fst_.GetImpl()->context_index_.Rank1( |
| fst_.GetImpl()->context_index_.Select1( |
| fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1)); |
| arc_.weight = fst_.GetImpl()->backoff_[inst_.state_]; |
| done_ = false; |
| } |
| } else { |
| const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_; |
| const Label *end = start + inst_.num_futures_; |
| const Label* search = lower_bound(start, end, label); |
| if (search != end && *search == label) { |
| size_t state = search - start; |
| arc_.ilabel = arc_.olabel = label; |
| arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state]; |
| fst_.GetImpl()->SetInstContext(&inst_); |
| arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label); |
| done_ = false; |
| } |
| } |
| return !Done_(); |
| } |
| |
| virtual bool Done_() const { |
| return !current_loop_ && done_; |
| } |
| |
| virtual const Arc& Value_() const { |
| return (current_loop_) ? loop_ : arc_; |
| } |
| |
| virtual void Next_() { |
| if (current_loop_) { |
| current_loop_ = false; |
| } else { |
| done_ = true; |
| } |
| } |
| |
| const NGramFst<A>& fst_; |
| NGramFstInst<A> inst_; |
| MatchType match_type_; // Supplied by caller |
| bool done_; |
| Arc arc_; |
| bool current_loop_; // Current arc is the implicit loop |
| Arc loop_; |
| }; |
| |
| /*****************************************************************************/ |
| template<class A> |
| class ArcIterator<NGramFst<A> > : public ArcIteratorBase<A> { |
| public: |
| typedef A Arc; |
| typedef typename A::Label Label; |
| typedef typename A::StateId StateId; |
| typedef typename A::Weight Weight; |
| |
| ArcIterator(const NGramFst<A> &fst, StateId state) |
| : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) { |
| inst_ = fst.inst_; |
| impl_->SetInstFuture(state, &inst_); |
| impl_->SetInstNode(&inst_); |
| } |
| |
| bool Done() const { |
| return i_ >= ((inst_.node_ == 0) ? inst_.num_futures_ : |
| inst_.num_futures_ + 1); |
| } |
| |
| const Arc &Value() const { |
| bool eps = (inst_.node_ != 0 && i_ == 0); |
| StateId state = (inst_.node_ == 0) ? i_ : i_ - 1; |
| if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) { |
| arc_.ilabel = |
| arc_.olabel = eps ? 0 : impl_->future_words_[inst_.offset_ + state]; |
| lazy_ &= ~(kArcILabelValue | kArcOLabelValue); |
| } |
| if (flags_ & lazy_ & kArcNextStateValue) { |
| if (eps) { |
| arc_.nextstate = impl_->context_index_.Rank1( |
| impl_->context_index_.Select1( |
| impl_->context_index_.Rank0(inst_.node_) - 1)); |
| } else { |
| if (lazy_ & kArcNextStateValue) { |
| impl_->SetInstContext(&inst_); // first time only. |
| } |
| arc_.nextstate = |
| impl_->Transition(inst_.context_, |
| impl_->future_words_[inst_.offset_ + state]); |
| } |
| lazy_ &= ~kArcNextStateValue; |
| } |
| if (flags_ & lazy_ & kArcWeightValue) { |
| arc_.weight = eps ? impl_->backoff_[inst_.state_] : |
| impl_->future_probs_[inst_.offset_ + state]; |
| lazy_ &= ~kArcWeightValue; |
| } |
| return arc_; |
| } |
| |
| void Next() { |
| ++i_; |
| lazy_ = ~0; |
| } |
| |
| size_t Position() const { return i_; } |
| |
| void Reset() { |
| i_ = 0; |
| lazy_ = ~0; |
| } |
| |
| void Seek(size_t a) { |
| if (i_ != a) { |
| i_ = a; |
| lazy_ = ~0; |
| } |
| } |
| |
| uint32 Flags() const { |
| return flags_; |
| } |
| |
| void SetFlags(uint32 f, uint32 m) { |
| flags_ &= ~m; |
| flags_ |= (f & kArcValueFlags); |
| } |
| |
| private: |
| virtual bool Done_() const { return Done(); } |
| virtual const Arc& Value_() const { return Value(); } |
| virtual void Next_() { Next(); } |
| virtual size_t Position_() const { return Position(); } |
| virtual void Reset_() { Reset(); } |
| virtual void Seek_(size_t a) { Seek(a); } |
| uint32 Flags_() const { return Flags(); } |
| void SetFlags_(uint32 f, uint32 m) { SetFlags(f, m); } |
| |
| mutable Arc arc_; |
| mutable uint32 lazy_; |
| const NGramFstImpl<A> *impl_; |
| mutable NGramFstInst<A> inst_; |
| |
| size_t i_; |
| uint32 flags_; |
| |
| DISALLOW_COPY_AND_ASSIGN(ArcIterator); |
| }; |
| |
| /*****************************************************************************/ |
| // Specialization for NGramFst; see generic version in fst.h |
| // for sample usage (but use the ProdLmFst type!). This version |
| // should inline. |
| template <class A> |
| class StateIterator<NGramFst<A> > : public StateIteratorBase<A> { |
| public: |
| typedef typename A::StateId StateId; |
| |
| explicit StateIterator(const NGramFst<A> &fst) |
| : s_(0), num_states_(fst.NumStates()) { } |
| |
| bool Done() const { return s_ >= num_states_; } |
| StateId Value() const { return s_; } |
| void Next() { ++s_; } |
| void Reset() { s_ = 0; } |
| |
| private: |
| virtual bool Done_() const { return Done(); } |
| virtual StateId Value_() const { return Value(); } |
| virtual void Next_() { Next(); } |
| virtual void Reset_() { Reset(); } |
| |
| StateId s_, num_states_; |
| |
| DISALLOW_COPY_AND_ASSIGN(StateIterator); |
| }; |
| } // namespace fst |
| #endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ |