| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // All Rights Reserved. |
| // |
| // Author : Johan Schalkwyk |
| // |
| // \file |
| // Classes to provide symbol-to-integer and integer-to-symbol mappings. |
| |
| #ifndef FST_LIB_SYMBOL_TABLE_H__ |
| #define FST_LIB_SYMBOL_TABLE_H__ |
| |
| #include <cstring> |
| #include <string> |
| #include <utility> |
| using std::pair; using std::make_pair; |
| #include <vector> |
| using std::vector; |
| |
| |
| #include <fst/compat.h> |
| #include <iostream> |
| #include <fstream> |
| |
| |
| #include <map> |
| |
| DECLARE_bool(fst_compat_symbols); |
| |
| namespace fst { |
| |
| // WARNING: Reading via symbol table read options should |
| // not be used. This is a temporary work around for |
| // reading symbol ranges of previously stored symbol sets. |
| struct SymbolTableReadOptions { |
| SymbolTableReadOptions() { } |
| |
| SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_, |
| const string& source_) |
| : string_hash_ranges(string_hash_ranges_), |
| source(source_) { } |
| |
| vector<pair<int64, int64> > string_hash_ranges; |
| string source; |
| }; |
| |
| class SymbolTableImpl { |
| public: |
| SymbolTableImpl(const string &name) |
| : name_(name), |
| available_key_(0), |
| dense_key_limit_(0), |
| check_sum_finalized_(false) {} |
| |
| explicit SymbolTableImpl(const SymbolTableImpl& impl) |
| : name_(impl.name_), |
| available_key_(0), |
| dense_key_limit_(0), |
| check_sum_finalized_(false) { |
| for (size_t i = 0; i < impl.symbols_.size(); ++i) { |
| AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i])); |
| } |
| } |
| |
| ~SymbolTableImpl() { |
| for (size_t i = 0; i < symbols_.size(); ++i) |
| delete[] symbols_[i]; |
| } |
| |
| // TODO(johans): Add flag to specify whether the symbol |
| // should be indexed as string or int or both. |
| int64 AddSymbol(const string& symbol, int64 key); |
| |
| int64 AddSymbol(const string& symbol) { |
| int64 key = Find(symbol); |
| return (key == -1) ? AddSymbol(symbol, available_key_++) : key; |
| } |
| |
| static SymbolTableImpl* ReadText(istream &strm, |
| const string &name, |
| bool allow_negative = false); |
| |
| static SymbolTableImpl* Read(istream &strm, |
| const SymbolTableReadOptions& opts); |
| |
| bool Write(ostream &strm) const; |
| |
| // |
| // Return the string associated with the key. If the key is out of |
| // range (<0, >max), return an empty string. |
| string Find(int64 key) const { |
| if (key >=0 && key < dense_key_limit_) |
| return string(symbols_[key]); |
| |
| map<int64, const char*>::const_iterator it = |
| key_map_.find(key); |
| if (it == key_map_.end()) { |
| return ""; |
| } |
| return string(it->second); |
| } |
| |
| // |
| // Return the key associated with the symbol. If the symbol |
| // does not exists, return SymbolTable::kNoSymbol. |
| int64 Find(const string& symbol) const { |
| return Find(symbol.c_str()); |
| } |
| |
| // |
| // Return the key associated with the symbol. If the symbol |
| // does not exists, return SymbolTable::kNoSymbol. |
| int64 Find(const char* symbol) const { |
| map<const char *, int64, StrCmp>::const_iterator it = |
| symbol_map_.find(symbol); |
| if (it == symbol_map_.end()) { |
| return -1; |
| } |
| return it->second; |
| } |
| |
| int64 GetNthKey(ssize_t pos) const { |
| if ((pos < 0) || (pos >= symbols_.size())) return -1; |
| else return Find(symbols_[pos]); |
| } |
| |
| const string& Name() const { return name_; } |
| |
| int IncrRefCount() const { |
| return ref_count_.Incr(); |
| } |
| int DecrRefCount() const { |
| return ref_count_.Decr(); |
| } |
| int RefCount() const { |
| return ref_count_.count(); |
| } |
| |
| string CheckSum() const { |
| MutexLock check_sum_lock(&check_sum_mutex_); |
| MaybeRecomputeCheckSum(); |
| return check_sum_string_; |
| } |
| |
| string LabeledCheckSum() const { |
| MutexLock check_sum_lock(&check_sum_mutex_); |
| MaybeRecomputeCheckSum(); |
| return labeled_check_sum_string_; |
| } |
| |
| int64 AvailableKey() const { |
| return available_key_; |
| } |
| |
| size_t NumSymbols() const { |
| return symbols_.size(); |
| } |
| |
| private: |
| // Recomputes the checksums (both of them) if we've had changes since the last |
| // computation (i.e., if check_sum_finalized_ is false). |
| void MaybeRecomputeCheckSum() const; |
| |
| struct StrCmp { |
| bool operator()(const char *s1, const char *s2) const { |
| return strcmp(s1, s2) < 0; |
| } |
| }; |
| |
| string name_; |
| int64 available_key_; |
| int64 dense_key_limit_; |
| vector<const char *> symbols_; |
| map<int64, const char*> key_map_; |
| map<const char *, int64, StrCmp> symbol_map_; |
| |
| mutable RefCounter ref_count_; |
| mutable bool check_sum_finalized_; |
| mutable CheckSummer check_sum_; |
| mutable CheckSummer labeled_check_sum_; |
| mutable string check_sum_string_; |
| mutable string labeled_check_sum_string_; |
| mutable Mutex check_sum_mutex_; |
| }; |
| |
| // |
| // \class SymbolTable |
| // \brief Symbol (string) to int and reverse mapping |
| // |
| // The SymbolTable implements the mappings of labels to strings and reverse. |
| // SymbolTables are used to describe the alphabet of the input and output |
| // labels for arcs in a Finite State Transducer. |
| // |
| // SymbolTables are reference counted and can therefore be shared across |
| // multiple machines. For example a language model grammar G, with a |
| // SymbolTable for the words in the language model can share this symbol |
| // table with the lexical representation L o G. |
| // |
| class SymbolTable { |
| public: |
| static const int64 kNoSymbol = -1; |
| |
| // Construct symbol table with a unique name. |
| SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {} |
| |
| // Create a reference counted copy. |
| SymbolTable(const SymbolTable& table) : impl_(table.impl_) { |
| impl_->IncrRefCount(); |
| } |
| |
| // Derefence implentation object. When reference count hits 0, delete |
| // implementation. |
| virtual ~SymbolTable() { |
| if (!impl_->DecrRefCount()) delete impl_; |
| } |
| |
| // Read an ascii representation of the symbol table from an istream. Pass a |
| // name to give the resulting SymbolTable. |
| static SymbolTable* ReadText(istream &strm, |
| const string& name, |
| bool allow_negative = false) { |
| SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm, |
| name, |
| allow_negative); |
| if (!impl) |
| return 0; |
| else |
| return new SymbolTable(impl); |
| } |
| |
| // read an ascii representation of the symbol table |
| static SymbolTable* ReadText(const string& filename, |
| bool allow_negative = false) { |
| ifstream strm(filename.c_str(), ifstream::in); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename; |
| return 0; |
| } |
| return ReadText(strm, filename, allow_negative); |
| } |
| |
| |
| // WARNING: Reading via symbol table read options should |
| // not be used. This is a temporary work around. |
| static SymbolTable* Read(istream &strm, |
| const SymbolTableReadOptions& opts) { |
| SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts); |
| if (!impl) |
| return 0; |
| else |
| return new SymbolTable(impl); |
| } |
| |
| // read a binary dump of the symbol table from a stream |
| static SymbolTable* Read(istream &strm, const string& source) { |
| SymbolTableReadOptions opts; |
| opts.source = source; |
| return Read(strm, opts); |
| } |
| |
| // read a binary dump of the symbol table |
| static SymbolTable* Read(const string& filename) { |
| ifstream strm(filename.c_str(), ifstream::in | ifstream::binary); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename; |
| return 0; |
| } |
| return Read(strm, filename); |
| } |
| |
| //-------------------------------------------------------- |
| // Derivable Interface (final) |
| //-------------------------------------------------------- |
| // create a reference counted copy |
| virtual SymbolTable* Copy() const { |
| return new SymbolTable(*this); |
| } |
| |
| // Add a symbol with given key to table. A symbol table also |
| // keeps track of the last available key (highest key value in |
| // the symbol table). |
| virtual int64 AddSymbol(const string& symbol, int64 key) { |
| MutateCheck(); |
| return impl_->AddSymbol(symbol, key); |
| } |
| |
| // Add a symbol to the table. The associated value key is automatically |
| // assigned by the symbol table. |
| virtual int64 AddSymbol(const string& symbol) { |
| MutateCheck(); |
| return impl_->AddSymbol(symbol); |
| } |
| |
| // Add another symbol table to this table. All key values will be offset |
| // by the current available key (highest key value in the symbol table). |
| // Note string symbols with the same key value with still have the same |
| // key value after the symbol table has been merged, but a different |
| // value. Adding symbol tables do not result in changes in the base table. |
| virtual void AddTable(const SymbolTable& table); |
| |
| // return the name of the symbol table |
| virtual const string& Name() const { |
| return impl_->Name(); |
| } |
| |
| // Return the label-agnostic MD5 check-sum for this table. All new symbols |
| // added to the table will result in an updated checksum. |
| // DEPRECATED. |
| virtual string CheckSum() const { |
| return impl_->CheckSum(); |
| } |
| |
| // Same as CheckSum(), but this returns an label-dependent version. |
| virtual string LabeledCheckSum() const { |
| return impl_->LabeledCheckSum(); |
| } |
| |
| virtual bool Write(ostream &strm) const { |
| return impl_->Write(strm); |
| } |
| |
| bool Write(const string& filename) const { |
| ofstream strm(filename.c_str(), ofstream::out | ofstream::binary); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename; |
| return false; |
| } |
| return Write(strm); |
| } |
| |
| // Dump an ascii text representation of the symbol table via a stream |
| virtual bool WriteText(ostream &strm) const; |
| |
| // Dump an ascii text representation of the symbol table |
| bool WriteText(const string& filename) const { |
| ofstream strm(filename.c_str()); |
| if (!strm) { |
| LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename; |
| return false; |
| } |
| return WriteText(strm); |
| } |
| |
| // Return the string associated with the key. If the key is out of |
| // range (<0, >max), log error and return an empty string. |
| virtual string Find(int64 key) const { |
| return impl_->Find(key); |
| } |
| |
| // Return the key associated with the symbol. If the symbol |
| // does not exists, log error and return SymbolTable::kNoSymbol |
| virtual int64 Find(const string& symbol) const { |
| return impl_->Find(symbol); |
| } |
| |
| // Return the key associated with the symbol. If the symbol |
| // does not exists, log error and return SymbolTable::kNoSymbol |
| virtual int64 Find(const char* symbol) const { |
| return impl_->Find(symbol); |
| } |
| |
| // Return the current available key (i.e highest key number+1) in |
| // the symbol table |
| virtual int64 AvailableKey(void) const { |
| return impl_->AvailableKey(); |
| } |
| |
| // Return the current number of symbols in table (not necessarily |
| // equal to AvailableKey()) |
| virtual size_t NumSymbols(void) const { |
| return impl_->NumSymbols(); |
| } |
| |
| virtual int64 GetNthKey(ssize_t pos) const { |
| return impl_->GetNthKey(pos); |
| } |
| |
| private: |
| explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {} |
| |
| void MutateCheck() { |
| // Copy on write |
| if (impl_->RefCount() > 1) { |
| impl_->DecrRefCount(); |
| impl_ = new SymbolTableImpl(*impl_); |
| } |
| } |
| |
| const SymbolTableImpl* Impl() const { |
| return impl_; |
| } |
| |
| private: |
| SymbolTableImpl* impl_; |
| |
| void operator=(const SymbolTable &table); // disallow |
| }; |
| |
| |
| // |
| // \class SymbolTableIterator |
| // \brief Iterator class for symbols in a symbol table |
| class SymbolTableIterator { |
| public: |
| SymbolTableIterator(const SymbolTable& table) |
| : table_(table), |
| pos_(0), |
| nsymbols_(table.NumSymbols()), |
| key_(table.GetNthKey(0)) { } |
| |
| ~SymbolTableIterator() { } |
| |
| // is iterator done |
| bool Done(void) { |
| return (pos_ == nsymbols_); |
| } |
| |
| // return the Value() of the current symbol (int64 key) |
| int64 Value(void) { |
| return key_; |
| } |
| |
| // return the string of the current symbol |
| string Symbol(void) { |
| return table_.Find(key_); |
| } |
| |
| // advance iterator forward |
| void Next(void) { |
| ++pos_; |
| if (pos_ < nsymbols_) key_ = table_.GetNthKey(pos_); |
| } |
| |
| // reset iterator |
| void Reset(void) { |
| pos_ = 0; |
| key_ = table_.GetNthKey(0); |
| } |
| |
| private: |
| const SymbolTable& table_; |
| ssize_t pos_; |
| size_t nsymbols_; |
| int64 key_; |
| }; |
| |
| |
| // Tests compatibilty between two sets of symbol tables |
| inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2, |
| bool warning = true) { |
| if (!FLAGS_fst_compat_symbols) { |
| return true; |
| } else if (!syms1 && !syms2) { |
| return true; |
| } else if (syms1 && !syms2) { |
| if (warning) |
| LOG(WARNING) << |
| "CompatSymbols: first symbol table present but second missing"; |
| return false; |
| } else if (!syms1 && syms2) { |
| if (warning) |
| LOG(WARNING) << |
| "CompatSymbols: second symbol table present but first missing"; |
| return false; |
| } else if (syms1->LabeledCheckSum() != syms2->LabeledCheckSum()) { |
| if (warning) |
| LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match"; |
| return false; |
| } else { |
| return true; |
| } |
| } |
| |
| |
| // Relabels a symbol table as specified by the input vector of pairs |
| // (old label, new label). The new symbol table only retains symbols |
| // for which a relabeling is *explicitely* specified. |
| // TODO(allauzen): consider adding options to allow for some form |
| // of implicit identity relabeling. |
| template <class Label> |
| SymbolTable *RelabelSymbolTable(const SymbolTable *table, |
| const vector<pair<Label, Label> > &pairs) { |
| SymbolTable *new_table = new SymbolTable( |
| table->Name().empty() ? string() : |
| (string("relabeled_") + table->Name())); |
| |
| for (size_t i = 0; i < pairs.size(); ++i) |
| new_table->AddSymbol(table->Find(pairs[i].first), pairs[i].second); |
| |
| return new_table; |
| } |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_SYMBOL_TABLE_H__ |