// blob: 93ebe76d38ff13f03bc7ffa6a69190db9c899586 [file] [log] [blame]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// All Rights Reserved.
//
// Author : Johan Schalkwyk
//
// \file
// Classes to provide symbol-to-integer and integer-to-symbol mappings.
#ifndef FST_LIB_SYMBOL_TABLE_H__
#define FST_LIB_SYMBOL_TABLE_H__
#include <cstring>
#include <string>
#include <utility>
using std::pair; using std::make_pair;
#include <vector>
using std::vector;
#include <fst/compat.h>
#include <iostream>
#include <fstream>
#include <map>
DECLARE_bool(fst_compat_symbols);
namespace fst {
// WARNING: Reading via symbol table read options should
// not be used. This is a temporary work around for
// reading symbol ranges of previously stored symbol sets.
struct SymbolTableReadOptions {
SymbolTableReadOptions() { }
SymbolTableReadOptions(vector<pair<int64, int64> > string_hash_ranges_,
const string& source_)
: string_hash_ranges(string_hash_ranges_),
source(source_) { }
vector<pair<int64, int64> > string_hash_ranges;
string source;
};
// Reference-counted implementation object that stores the actual
// symbol <-> key mappings on behalf of SymbolTable.
//
// Symbol strings are held as raw character arrays in symbols_; key_map_ and
// symbol_map_ hold pointers to character arrays rather than copies. The
// destructor releases every array in symbols_ with delete[], so instances
// must never share those pointers; see the disallowed copy assignment below.
class SymbolTableImpl {
 public:
  // Creates an empty table called `name`.
  SymbolTableImpl(const string &name)
      : name_(name),
        available_key_(0),
        dense_key_limit_(0),
        check_sum_finalized_(false) {}

  // Deep copy: starts from an empty table with the same name and re-adds
  // every entry of impl.symbols_, in order, under the key it has in `impl`.
  // The check sums are marked as not finalized, and ref_count_ is
  // default-constructed rather than copied.
  explicit SymbolTableImpl(const SymbolTableImpl& impl)
      : name_(impl.name_),
        available_key_(0),
        dense_key_limit_(0),
        check_sum_finalized_(false) {
    for (size_t i = 0; i < impl.symbols_.size(); ++i) {
      AddSymbol(impl.symbols_[i], impl.Find(impl.symbols_[i]));
    }
  }

  // Releases every character array referenced from symbols_.
  // NOTE(review): the arrays are presumably allocated with new[] inside
  // AddSymbol(symbol, key), which is defined outside this header -- confirm.
  ~SymbolTableImpl() {
    for (size_t i = 0; i < symbols_.size(); ++i)
      delete[] symbols_[i];
  }

  // TODO(johans): Add flag to specify whether the symbol
  // should be indexed as string or int or both.
  //
  // Adds `symbol` under the explicit `key`. Defined outside this header.
  int64 AddSymbol(const string& symbol, int64 key);

  // Adds `symbol` under the next available key. If the symbol is already
  // present its existing key is returned and nothing is added; otherwise
  // available_key_ is post-incremented and its previous value is used as the
  // key.
  int64 AddSymbol(const string& symbol) {
    int64 key = Find(symbol);
    return (key == -1) ? AddSymbol(symbol, available_key_++) : key;
  }

  // Factory functions and serialization; all defined outside this header.
  static SymbolTableImpl* ReadText(istream &strm,
                                   const string &name,
                                   bool allow_negative = false);
  static SymbolTableImpl* Read(istream &strm,
                               const SymbolTableReadOptions& opts);
  bool Write(ostream &strm) const;

  //
  // Returns the string associated with the key, or an empty string if the
  // key is not in the table. Keys in [0, dense_key_limit_) index directly
  // into symbols_; every other key is looked up in key_map_.
  string Find(int64 key) const {
    if (key >= 0 && key < dense_key_limit_)
      return string(symbols_[key]);
    map<int64, const char*>::const_iterator it =
        key_map_.find(key);
    if (it == key_map_.end()) {
      return "";
    }
    return string(it->second);
  }

  //
  // Returns the key associated with the symbol. If the symbol does not
  // exist, returns -1 (the value of SymbolTable::kNoSymbol).
  int64 Find(const string& symbol) const {
    return Find(symbol.c_str());
  }

  //
  // Returns the key associated with the symbol. If the symbol does not
  // exist, returns -1 (the value of SymbolTable::kNoSymbol).
  int64 Find(const char* symbol) const {
    map<const char *, int64, StrCmp>::const_iterator it =
        symbol_map_.find(symbol);
    if (it == symbol_map_.end()) {
      return -1;
    }
    return it->second;
  }

  // Returns the key of the symbol stored at index `pos` of symbols_, or -1
  // if `pos` is negative or not less than the number of symbols.
  int64 GetNthKey(ssize_t pos) const {
    // pos is known to be non-negative by the time it is cast, so the
    // conversion is value-preserving and the comparison is unsigned on both
    // sides (no -Wsign-compare warning).
    if ((pos < 0) || (static_cast<size_t>(pos) >= symbols_.size())) return -1;
    else return Find(symbols_[pos]);
  }

  const string& Name() const { return name_; }

  // Reference-count accessors; each forwards to ref_count_ and returns what
  // the corresponding RefCounter call returns.
  int IncrRefCount() const {
    return ref_count_.Incr();
  }
  int DecrRefCount() const {
    return ref_count_.Decr();
  }
  int RefCount() const {
    return ref_count_.count();
  }

  // Returns check_sum_string_ after giving MaybeRecomputeCheckSum() a chance
  // to refresh it. check_sum_mutex_ is held for the duration of the call.
  string CheckSum() const {
    MutexLock check_sum_lock(&check_sum_mutex_);
    MaybeRecomputeCheckSum();
    return check_sum_string_;
  }

  // Same as CheckSum(), but returns labeled_check_sum_string_.
  string LabeledCheckSum() const {
    MutexLock check_sum_lock(&check_sum_mutex_);
    MaybeRecomputeCheckSum();
    return labeled_check_sum_string_;
  }

  // Returns the key the single-argument AddSymbol() would assign next.
  int64 AvailableKey() const {
    return available_key_;
  }

  // Returns the number of entries in symbols_.
  size_t NumSymbols() const {
    return symbols_.size();
  }

 private:
  // Recomputes the checksums (both of them) if we've had changes since the last
  // computation (i.e., if check_sum_finalized_ is false).
  void MaybeRecomputeCheckSum() const;

  // Orders C strings by content (strcmp) rather than by pointer value, so
  // symbol_map_ can be searched with any pointer to an equal string.
  struct StrCmp {
    bool operator()(const char *s1, const char *s2) const {
      return strcmp(s1, s2) < 0;
    }
  };

  string name_;
  int64 available_key_;
  // Keys in [0, dense_key_limit_) are resolved by Find(int64) as direct
  // indices into symbols_. Only ever initialized to 0 within this header.
  int64 dense_key_limit_;
  vector<const char *> symbols_;
  map<int64, const char*> key_map_;
  map<const char *, int64, StrCmp> symbol_map_;
  mutable RefCounter ref_count_;
  mutable bool check_sum_finalized_;
  mutable CheckSummer check_sum_;
  mutable CheckSummer labeled_check_sum_;
  mutable string check_sum_string_;
  mutable string labeled_check_sum_string_;
  mutable Mutex check_sum_mutex_;

  // Declared but intentionally left undefined. A compiler-generated copy
  // assignment would copy the raw pointers held in symbols_ (and in both
  // maps), after which two objects would delete[] the same arrays in their
  // destructors. Use the copy constructor to duplicate a table instead.
  void operator=(const SymbolTableImpl &impl);  // disallow
};
//
// \class SymbolTable
// \brief Symbol (string) to int and reverse mapping
//
// The SymbolTable implements the mappings of labels to strings and reverse.
// SymbolTables are used to describe the alphabet of the input and output
// labels for arcs in a Finite State Transducer.
//
// SymbolTables are reference counted and can therefore be shared across
// multiple machines. For example a language model grammar G, with a
// SymbolTable for the words in the language model can share this symbol
// table with the lexical representation L o G.
//
class SymbolTable {
public:
static const int64 kNoSymbol = -1;
// Construct symbol table with a unique name.
SymbolTable(const string& name) : impl_(new SymbolTableImpl(name)) {}
// Create a reference counted copy.
SymbolTable(const SymbolTable& table) : impl_(table.impl_) {
impl_->IncrRefCount();
}
// Derefence implentation object. When reference count hits 0, delete
// implementation.
virtual ~SymbolTable() {
if (!impl_->DecrRefCount()) delete impl_;
}
// Read an ascii representation of the symbol table from an istream. Pass a
// name to give the resulting SymbolTable.
static SymbolTable* ReadText(istream &strm,
const string& name,
bool allow_negative = false) {
SymbolTableImpl* impl = SymbolTableImpl::ReadText(strm,
name,
allow_negative);
if (!impl)
return 0;
else
return new SymbolTable(impl);
}
// read an ascii representation of the symbol table
static SymbolTable* ReadText(const string& filename,
bool allow_negative = false) {
ifstream strm(filename.c_str(), ifstream::in);
if (!strm) {
LOG(ERROR) << "SymbolTable::ReadText: Can't open file " << filename;
return 0;
}
return ReadText(strm, filename, allow_negative);
}
// WARNING: Reading via symbol table read options should
// not be used. This is a temporary work around.
static SymbolTable* Read(istream &strm,
const SymbolTableReadOptions& opts) {
SymbolTableImpl* impl = SymbolTableImpl::Read(strm, opts);
if (!impl)
return 0;
else
return new SymbolTable(impl);
}
// read a binary dump of the symbol table from a stream
static SymbolTable* Read(istream &strm, const string& source) {
SymbolTableReadOptions opts;
opts.source = source;
return Read(strm, opts);
}
// read a binary dump of the symbol table
static SymbolTable* Read(const string& filename) {
ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
if (!strm) {
LOG(ERROR) << "SymbolTable::Read: Can't open file " << filename;
return 0;
}
return Read(strm, filename);
}
//--------------------------------------------------------
// Derivable Interface (final)
//--------------------------------------------------------
// create a reference counted copy
virtual SymbolTable* Copy() const {
return new SymbolTable(*this);
}
// Add a symbol with given key to table. A symbol table also
// keeps track of the last available key (highest key value in
// the symbol table).
virtual int64 AddSymbol(const string& symbol, int64 key) {
MutateCheck();
return impl_->AddSymbol(symbol, key);
}
// Add a symbol to the table. The associated value key is automatically
// assigned by the symbol table.
virtual int64 AddSymbol(const string& symbol) {
MutateCheck();
return impl_->AddSymbol(symbol);
}
// Add another symbol table to this table. All key values will be offset
// by the current available key (highest key value in the symbol table).
// Note string symbols with the same key value with still have the same
// key value after the symbol table has been merged, but a different
// value. Adding symbol tables do not result in changes in the base table.
virtual void AddTable(const SymbolTable& table);
// return the name of the symbol table
virtual const string& Name() const {
return impl_->Name();
}
// Return the label-agnostic MD5 check-sum for this table. All new symbols
// added to the table will result in an updated checksum.
// DEPRECATED.
virtual string CheckSum() const {
return impl_->CheckSum();
}
// Same as CheckSum(), but this returns an label-dependent version.
virtual string LabeledCheckSum() const {
return impl_->LabeledCheckSum();
}
virtual bool Write(ostream &strm) const {
return impl_->Write(strm);
}
bool Write(const string& filename) const {
ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
if (!strm) {
LOG(ERROR) << "SymbolTable::Write: Can't open file " << filename;
return false;
}
return Write(strm);
}
// Dump an ascii text representation of the symbol table via a stream
virtual bool WriteText(ostream &strm) const;
// Dump an ascii text representation of the symbol table
bool WriteText(const string& filename) const {
ofstream strm(filename.c_str());
if (!strm) {
LOG(ERROR) << "SymbolTable::WriteText: Can't open file " << filename;
return false;
}
return WriteText(strm);
}
// Return the string associated with the key. If the key is out of
// range (<0, >max), log error and return an empty string.
virtual string Find(int64 key) const {
return impl_->Find(key);
}
// Return the key associated with the symbol. If the symbol
// does not exists, log error and return SymbolTable::kNoSymbol
virtual int64 Find(const string& symbol) const {
return impl_->Find(symbol);
}
// Return the key associated with the symbol. If the symbol
// does not exists, log error and return SymbolTable::kNoSymbol
virtual int64 Find(const char* symbol) const {
return impl_->Find(symbol);
}
// Return the current available key (i.e highest key number+1) in
// the symbol table
virtual int64 AvailableKey(void) const {
return impl_->AvailableKey();
}
// Return the current number of symbols in table (not necessarily
// equal to AvailableKey())
virtual size_t NumSymbols(void) const {
return impl_->NumSymbols();
}
virtual int64 GetNthKey(ssize_t pos) const {
return impl_->GetNthKey(pos);
}
private:
explicit SymbolTable(SymbolTableImpl* impl) : impl_(impl) {}
void MutateCheck() {
// Copy on write
if (impl_->RefCount() > 1) {
impl_->DecrRefCount();
impl_ = new SymbolTableImpl(*impl_);
}
}
const SymbolTableImpl* Impl() const {
return impl_;
}
private:
SymbolTableImpl* impl_;
void operator=(const SymbolTable &table); // disallow
};
//
// \class SymbolTableIterator
// \brief Iterator class for symbols in a symbol table
// Walks the symbols of a SymbolTable by position, from 0 up to the number of
// symbols the table reported when the iterator was constructed.
//
// The iterator keeps a reference to the table, so the table must outlive it.
// The symbol count is captured once in the constructor and is not refreshed
// by Reset(), so positions added to the table afterwards are not visited.
class SymbolTableIterator {
 public:
  // Positions the iterator on the first symbol of `table`. For an empty
  // table Done() is immediately true and Value() is whatever GetNthKey(0)
  // returned.
  SymbolTableIterator(const SymbolTable& table)
      : table_(table),
        pos_(0),
        nsymbols_(table.NumSymbols()),
        key_(table.GetNthKey(0)) { }

  ~SymbolTableIterator() { }

  // Returns true once every position has been visited. Const so that it can
  // be called through a const iterator.
  bool Done() const {
    // pos_ is never negative (it starts at 0 and is only incremented), so
    // the cast is value-preserving.
    return static_cast<size_t>(pos_) == nsymbols_;
  }

  // Returns the key (int64) of the current symbol.
  int64 Value() const {
    return key_;
  }

  // Returns the string of the current symbol, looked up by key in the table.
  string Symbol() const {
    return table_.Find(key_);
  }

  // Advances to the next position. The cached key is refreshed only while
  // the new position is still in range; past the end it keeps the key of the
  // last symbol.
  void Next() {
    ++pos_;
    if (static_cast<size_t>(pos_) < nsymbols_) key_ = table_.GetNthKey(pos_);
  }

  // Moves back to position 0 and reloads the key for that position.
  void Reset() {
    pos_ = 0;
    key_ = table_.GetNthKey(0);
  }

 private:
  const SymbolTable& table_;
  ssize_t pos_;
  size_t nsymbols_;
  int64 key_;
};
// Tests compatibility between two sets of symbol tables
// Returns true when the two symbol tables may be used together:
//   * always, if the fst_compat_symbols flag is off;
//   * if both pointers are null;
//   * if both are non-null and their labeled check sums are equal.
// Returns false when exactly one table is present or the labeled check sums
// differ. When `warning` is true, each false result is accompanied by a
// LOG(WARNING) message explaining the reason.
inline bool CompatSymbols(const SymbolTable *syms1, const SymbolTable *syms2,
                          bool warning = true) {
  // Checking disabled: everything is compatible.
  if (!FLAGS_fst_compat_symbols) return true;

  // Neither side carries a table.
  if (!syms1 && !syms2) return true;

  // From here on at least one table is present.
  if (!syms2) {
    if (warning) {
      LOG(WARNING) <<
          "CompatSymbols: first symbol table present but second missing";
    }
    return false;
  }
  if (!syms1) {
    if (warning) {
      LOG(WARNING) <<
          "CompatSymbols: second symbol table present but first missing";
    }
    return false;
  }

  // Both present: compare the label-dependent check sums.
  if (syms1->LabeledCheckSum() == syms2->LabeledCheckSum()) return true;

  if (warning) {
    LOG(WARNING) << "CompatSymbols: Symbol table check sums do not match";
  }
  return false;
}
// Relabels a symbol table as specified by the input vector of pairs
// (old label, new label). The new symbol table only retains symbols
// for which a relabeling is *explicitly* specified.
// TODO(allauzen): consider adding options to allow for some form
// of implicit identity relabeling.
// Builds a newly allocated symbol table from `table` and a list of
// (old label, new label) pairs. For each pair, the string that `table`
// associates with the old label is added to the new table under the new
// label. The new table is named "relabeled_" + table->Name(), or left
// unnamed when `table` has an empty name. The caller owns the result.
template <class Label>
SymbolTable *RelabelSymbolTable(const SymbolTable *table,
                                const vector<pair<Label, Label> > &pairs) {
  // An empty source name stays empty; otherwise prefix it.
  string relabeled_name;
  if (!table->Name().empty())
    relabeled_name = string("relabeled_") + table->Name();

  SymbolTable *relabeled = new SymbolTable(relabeled_name);

  typedef typename vector<pair<Label, Label> >::const_iterator PairIterator;
  for (PairIterator it = pairs.begin(); it != pairs.end(); ++it)
    relabeled->AddSymbol(table->Find(it->first), it->second);

  return relabeled;
}
} // namespace fst
#endif // FST_LIB_SYMBOL_TABLE_H__