blob: 7245b459d9c9e94263fcc8b149bf3f17b0e013f6 [file] [log] [blame]
// encode.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file
// Classes to encode and decode an FST.
#ifndef FST_LIB_ENCODE_H__
#define FST_LIB_ENCODE_H__
#include <climits>
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <string>
#include <vector>
using std::vector;
#include <fst/arc-map.h>
#include <fst/rmfinalepsilon.h>
namespace fst {
// Flags selecting which arc fields are folded into the encoded label.
static const uint32 kEncodeLabels = 0x0001;
static const uint32 kEncodeWeights = 0x0002;
// Mask covering the two public flags above (kEncodeLabels | kEncodeWeights).
static const uint32 kEncodeFlags = 0x0003; // All non-internal flags
// Bookkeeping bits recording whether an EncodeTable holds a copy of the
// input/output symbol table; they are written to the stream with the flags.
static const uint32 kEncodeHasISymbols = 0x0004; // For internal use
static const uint32 kEncodeHasOSymbols = 0x0008; // For internal use
// Direction in which an EncodeMapper operates.
enum EncodeType { ENCODE = 1, DECODE = 2 };
// Identifies stream data as an encode table (and its endianness)
static const int32 kEncodeMagicNumber = 2129983209;
// The following class encapsulates implementation details for the
// encoding and decoding of label/weight tuples used for encoding
// and decoding of Fsts. The EncodeTable is bidirectional. I.E it
// stores both the Tuple of encode labels and weights to a unique
// label, and the reverse.
template <class A> class EncodeTable {
public:
typedef typename A::Label Label;
typedef typename A::Weight Weight;
// Encoded data consists of arc input/output labels and arc weight
struct Tuple {
Tuple() {}
Tuple(Label ilabel_, Label olabel_, Weight weight_)
: ilabel(ilabel_), olabel(olabel_), weight(weight_) {}
Tuple(const Tuple& tuple)
: ilabel(tuple.ilabel), olabel(tuple.olabel), weight(tuple.weight) {}
Label ilabel;
Label olabel;
Weight weight;
};
// Comparison object for hashing EncodeTable Tuple(s).
class TupleEqual {
public:
bool operator()(const Tuple* x, const Tuple* y) const {
return (x->ilabel == y->ilabel &&
x->olabel == y->olabel &&
x->weight == y->weight);
}
};
// Hash function for EncodeTabe Tuples. Based on the encode flags
// we either hash the labels, weights or combination of them.
class TupleKey {
public:
TupleKey()
: encode_flags_(kEncodeLabels | kEncodeWeights) {}
TupleKey(const TupleKey& key)
: encode_flags_(key.encode_flags_) {}
explicit TupleKey(uint32 encode_flags)
: encode_flags_(encode_flags) {}
size_t operator()(const Tuple* x) const {
size_t hash = x->ilabel;
const int lshift = 5;
const int rshift = CHAR_BIT * sizeof(size_t) - 5;
if (encode_flags_ & kEncodeLabels)
hash = hash << lshift ^ hash >> rshift ^ x->olabel;
if (encode_flags_ & kEncodeWeights)
hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash();
return hash;
}
private:
int32 encode_flags_;
};
typedef unordered_map<const Tuple*,
Label,
TupleKey,
TupleEqual> EncodeHash;
explicit EncodeTable(uint32 encode_flags)
: flags_(encode_flags),
encode_hash_(1024, TupleKey(encode_flags)),
isymbols_(0), osymbols_(0) {}
~EncodeTable() {
for (size_t i = 0; i < encode_tuples_.size(); ++i) {
delete encode_tuples_[i];
}
delete isymbols_;
delete osymbols_;
}
// Given an arc encode either input/ouptut labels or input/costs or both
Label Encode(const A &arc) {
const Tuple tuple(arc.ilabel,
flags_ & kEncodeLabels ? arc.olabel : 0,
flags_ & kEncodeWeights ? arc.weight : Weight::One());
typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
if (it == encode_hash_.end()) {
encode_tuples_.push_back(new Tuple(tuple));
encode_hash_[encode_tuples_.back()] = encode_tuples_.size();
return encode_tuples_.size();
} else {
return it->second;
}
}
// Given an arc, look up its encoded label. Returns kNoLabel if not found.
Label GetLabel(const A &arc) const {
const Tuple tuple(arc.ilabel,
flags_ & kEncodeLabels ? arc.olabel : 0,
flags_ & kEncodeWeights ? arc.weight : Weight::One());
typename EncodeHash::const_iterator it = encode_hash_.find(&tuple);
if (it == encode_hash_.end()) {
return kNoLabel;
} else {
return it->second;
}
}
// Given an encode arc Label decode back to input/output labels and costs
const Tuple* Decode(Label key) const {
if (key < 1 || key > encode_tuples_.size()) {
LOG(ERROR) << "EncodeTable::Decode: unknown decode key: " << key;
return 0;
}
return encode_tuples_[key - 1];
}
size_t Size() const { return encode_tuples_.size(); }
bool Write(ostream &strm, const string &source) const;
static EncodeTable<A> *Read(istream &strm, const string &source);
const uint32 flags() const { return flags_ & kEncodeFlags; }
int RefCount() const { return ref_count_.count(); }
int IncrRefCount() { return ref_count_.Incr(); }
int DecrRefCount() { return ref_count_.Decr(); }
SymbolTable *InputSymbols() const { return isymbols_; }
SymbolTable *OutputSymbols() const { return osymbols_; }
void SetInputSymbols(const SymbolTable* syms) {
if (isymbols_) delete isymbols_;
if (syms) {
isymbols_ = syms->Copy();
flags_ |= kEncodeHasISymbols;
} else {
isymbols_ = 0;
flags_ &= ~kEncodeHasISymbols;
}
}
void SetOutputSymbols(const SymbolTable* syms) {
if (osymbols_) delete osymbols_;
if (syms) {
osymbols_ = syms->Copy();
flags_ |= kEncodeHasOSymbols;
} else {
osymbols_ = 0;
flags_ &= ~kEncodeHasOSymbols;
}
}
private:
uint32 flags_;
vector<Tuple*> encode_tuples_;
EncodeHash encode_hash_;
RefCounter ref_count_;
SymbolTable *isymbols_; // Pre-encoded ilabel symbol table
SymbolTable *osymbols_; // Pre-encoded olabel symbol table
DISALLOW_COPY_AND_ASSIGN(EncodeTable);
};
// Serializes the table to strm: magic number, raw flags (including the
// internal symbol-table bits), tuple count, each tuple's labels and weight,
// and finally whichever symbol tables the flags say are present.
// Returns false, after logging an error that names source, if the stream is
// in a failed state once everything has been flushed.
template <class A> inline
bool EncodeTable<A>::Write(ostream &strm, const string &source) const {
  WriteType(strm, kEncodeMagicNumber);
  WriteType(strm, flags_);

  int64 num_tuples = encode_tuples_.size();
  WriteType(strm, num_tuples);

  typedef typename vector<Tuple*>::const_iterator TupleIter;
  for (TupleIter it = encode_tuples_.begin();
       it != encode_tuples_.end(); ++it) {
    const Tuple* entry = *it;
    WriteType(strm, entry->ilabel);
    WriteType(strm, entry->olabel);
    entry->weight.Write(strm);
  }

  if (flags_ & kEncodeHasISymbols) {
    isymbols_->Write(strm);
  }
  if (flags_ & kEncodeHasOSymbols) {
    osymbols_->Write(strm);
  }

  strm.flush();
  if (strm) {
    return true;
  }
  LOG(ERROR) << "EncodeTable::Write: write failed: " << source;
  return false;
}
// Deserializes a table previously produced by Write().
//
// Returns a newly allocated table, or 0 after logging an error that names
// source when the magic number does not match, the tuple count is negative,
// or the stream fails while the header or a tuple is being read. Nothing
// allocated here outlives a failed call.
template <class A> inline
EncodeTable<A> *EncodeTable<A>::Read(istream &strm, const string &source) {
  int32 magic_number = 0;
  ReadType(strm, &magic_number);
  if (magic_number != kEncodeMagicNumber) {
    LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
    return 0;
  }

  // Zero-initialize so a failed read never leaves these indeterminate.
  uint32 flags = 0;
  ReadType(strm, &flags);
  int64 size = 0;
  ReadType(strm, &size);
  // A negative count can only come from a corrupt stream; reject it here
  // rather than letting it be compared against an unsigned loop index.
  if (!strm || size < 0) {
    LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
    return 0;
  }

  // Allocate only after the header has been validated, so the early returns
  // above have nothing to clean up.
  EncodeTable<A> *table = new EncodeTable<A>(flags);
  for (int64 i = 0; i < size; ++i) {
    Tuple* tuple = new Tuple();
    ReadType(strm, &tuple->ilabel);
    ReadType(strm, &tuple->olabel);
    tuple->weight.Read(strm);
    if (!strm) {
      LOG(ERROR) << "EncodeTable::Read: read failed: " << source;
      // The tuple is not yet owned by the table, so free both explicitly;
      // the table's destructor releases the tuples read so far.
      delete tuple;
      delete table;
      return 0;
    }
    table->encode_tuples_.push_back(tuple);
    // Encoded labels are 1-based positions in encode_tuples_.
    table->encode_hash_[table->encode_tuples_.back()] =
        table->encode_tuples_.size();
  }

  if (flags & kEncodeHasISymbols)
    table->isymbols_ = SymbolTable::Read(strm, source);
  if (flags & kEncodeHasOSymbols)
    table->osymbols_ = SymbolTable::Read(strm, source);
  return table;
}
// A mapper to encode/decode weighted transducers. Encoding of an
// Fst is useful for performing classical determinization or minimization
// on a weighted transducer by treating it as an unweighted acceptor over
// encoded labels.
//
// The Encode mapper stores the encoding in a local hash table (EncodeTable).
// This table is shared (and reference counted) between the encoder and
// decoder. A decoder has read-only access to the EncodeTable.
//
// The EncodeMapper allows on-the-fly encoding of the machine. As the
// EncodeTable is generated the same table may be used to decode the machine
// on the fly. For example in the following sequence of operations
//
// Encode -> Determinize -> Decode
//
// we will use the encoding table generated during the encode step in the
// decode, even though the encoding is not complete.
//
template <class A> class EncodeMapper {
  typedef typename A::Weight Weight;
  typedef typename A::Label Label;

 public:
  // Creates a mapper with a fresh, empty EncodeTable for the given flags.
  EncodeMapper(uint32 flags, EncodeType type)
      : flags_(flags),
        type_(type),
        table_(new EncodeTable<A>(flags)),
        error_(false) {}

  // Shares the source mapper's table (bumping its reference count). The
  // error state is not inherited: the copy starts with error_ == false.
  EncodeMapper(const EncodeMapper& mapper)
      : flags_(mapper.flags_),
        type_(mapper.type_),
        table_(mapper.table_),
        error_(false) {
    table_->IncrRefCount();
  }

  // Copy constructor but setting the type, typically to DECODE. Shares the
  // table and, unlike the plain copy constructor, carries over error_.
  EncodeMapper(const EncodeMapper& mapper, EncodeType type)
      : flags_(mapper.flags_),
        type_(type),
        table_(mapper.table_),
        error_(mapper.error_) {
    table_->IncrRefCount();
  }

  // Drops this mapper's reference; deletes the table when DecrRefCount()
  // reports zero.
  ~EncodeMapper() {
    if (!table_->DecrRefCount()) delete table_;
  }

  // Maps a single arc; encodes or decodes according to type().
  A operator()(const A &arc);

  // A superfinal state is required only when encoding weights.
  MapFinalAction FinalAction() const {
    return (type_ == ENCODE && (flags_ & kEncodeWeights)) ?
                   MAP_REQUIRE_SUPERFINAL : MAP_NO_SUPERFINAL;
  }

  MapSymbolsAction InputSymbolsAction() const { return MAP_CLEAR_SYMBOLS; }

  MapSymbolsAction OutputSymbolsAction() const { return MAP_CLEAR_SYMBOLS;}

  // Filters the input properties down to those selected by the masks for
  // the active flags and direction. kError is or-ed in before masking when
  // a mapping error has been recorded.
  uint64 Properties(uint64 inprops) {
    uint64 outprops = inprops;
    if (error_) outprops |= kError;

    uint64 mask = kFstProperties;
    if (flags_ & kEncodeLabels)
      mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
    if (flags_ & kEncodeWeights)
      mask &= kILabelInvariantProperties & kWeightInvariantProperties &
          (type_ == ENCODE ? kAddSuperFinalProperties :
           kRmSuperFinalProperties);

    return outprops & mask;
  }

  const uint32 flags() const { return flags_; }
  const EncodeType type() const { return type_; }
  const EncodeTable<A> &table() const { return *table_; }

  // Writes the shared table to strm; source is used in error messages.
  bool Write(ostream &strm, const string& source) {
    return table_->Write(strm, source);
  }

  // Writes the shared table to the named file, opened in binary mode.
  // Returns false (after logging) if the file cannot be opened.
  bool Write(const string& filename) {
    ofstream strm(filename.c_str(), ofstream::out | ofstream::binary);
    if (!strm) {
      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
      return false;
    }
    return Write(strm, filename);
  }

  // Reads a table from strm and wraps it in a new mapper of the given type.
  // Returns 0 if the table cannot be read.
  static EncodeMapper<A> *Read(istream &strm,
                               const string& source,
                               EncodeType type = ENCODE) {
    EncodeTable<A> *table = EncodeTable<A>::Read(strm, source);
    return table ? new EncodeMapper(table->flags(), type, table) : 0;
  }

  // Reads a mapper from the named file, opened in binary mode. Returns NULL
  // (after logging) if the file cannot be opened.
  static EncodeMapper<A> *Read(const string& filename,
                               EncodeType type = ENCODE) {
    ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
    if (!strm) {
      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
      return NULL;
    }
    return Read(strm, filename, type);
  }

  SymbolTable *InputSymbols() const { return table_->InputSymbols(); }

  SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }

  void SetInputSymbols(const SymbolTable* syms) {
    table_->SetInputSymbols(syms);
  }

  void SetOutputSymbols(const SymbolTable* syms) {
    table_->SetOutputSymbols(syms);
  }

 private:
  uint32 flags_;
  EncodeType type_;
  EncodeTable<A>* table_;   // Shared, reference counted
  bool error_;              // Set when decoding meets an inconsistent arc

  // Wraps an existing table (used by Read()). The table's reference count is
  // not incremented here. error_ must be initialized explicitly: Properties()
  // reads it, and leaving it indeterminate could report a spurious kError.
  explicit EncodeMapper(uint32 flags, EncodeType type, EncodeTable<A> *table)
      : flags_(flags), type_(type), table_(table), error_(false) {}

  void operator=(const EncodeMapper &);  // Disallow.
};
// Maps one arc according to the mapper's direction.
//
// ENCODE: the arc's tuple is looked up (or inserted) in the table and the
// resulting label replaces the input label, and also the output label when
// labels are encoded; the weight becomes One() when weights are encoded.
// Superfinal arcs (nextstate == kNoStateId) pass through unchanged unless
// weights are encoded and the arc weight is not Zero().
//
// DECODE: superfinal arcs and arcs with input label 0 pass through unchanged.
// Otherwise the tuple for the input label is fetched and its fields are
// restored. Inconsistent arcs and unknown labels set the error flag; an
// unknown label additionally yields an arc with kNoLabel/NoWeight().
template <class A> inline
A EncodeMapper<A>::operator()(const A &arc) {
  const bool encode_labels = flags_ & kEncodeLabels;
  const bool encode_weights = flags_ & kEncodeWeights;
  const bool superfinal = arc.nextstate == kNoStateId;

  if (type_ == ENCODE) {  // labels and/or weights to single label
    // A superfinal arc is left alone when weights are not encoded, or when
    // they are but the arc carries Zero().
    if (superfinal && (!encode_weights || arc.weight == Weight::Zero())) {
      return arc;
    }
    Label label = table_->Encode(arc);
    return A(label,
             encode_labels ? label : arc.olabel,
             encode_weights ? Weight::One() : arc.weight,
             arc.nextstate);
  }

  // type_ == DECODE
  if (superfinal) {
    return arc;
  }
  if (arc.ilabel == 0) {
    return arc;
  }

  if (encode_labels && arc.ilabel != arc.olabel) {
    FSTERROR() << "EncodeMapper: Label-encoded arc has different "
        "input and output labels";
    error_ = true;
  }
  if (encode_weights && arc.weight != Weight::One()) {
    FSTERROR() <<
        "EncodeMapper: Weight-encoded arc has non-trivial weight";
    error_ = true;
  }

  const typename EncodeTable<A>::Tuple* decoded = table_->Decode(arc.ilabel);
  if (decoded) {
    return A(decoded->ilabel,
             encode_labels ? decoded->olabel : arc.olabel,
             encode_weights ? decoded->weight : arc.weight,
             arc.nextstate);
  }

  FSTERROR() << "EncodeMapper: decode failed";
  error_ = true;
  return A(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
}
// Encodes fst in place using mapper. The mapper first takes copies of the
// Fst's input and output symbol tables, so that a later Decode() with the
// same mapper can put them back.
//
// Complexity: O(nstates + narcs)
template<class A> inline
void Encode(MutableFst<A> *fst, EncodeMapper<A>* mapper) {
  const SymbolTable *isyms = fst->InputSymbols();
  mapper->SetInputSymbols(isyms);

  const SymbolTable *osyms = fst->OutputSymbols();
  mapper->SetOutputSymbols(osyms);

  ArcMap(fst, mapper);
}
// Decodes fst in place with a DECODE-typed copy of mapper, removes final
// epsilons, and then sets the Fst's input and output symbol tables to the
// ones held by mapper.
template<class A> inline
void Decode(MutableFst<A>* fst, const EncodeMapper<A>& mapper) {
  ArcMap(fst, EncodeMapper<A>(mapper, DECODE));
  RmFinalEpsilon(fst);

  const SymbolTable *isyms = mapper.InputSymbols();
  fst->SetInputSymbols(isyms);

  const SymbolTable *osyms = mapper.OutputSymbols();
  fst->SetOutputSymbols(osyms);
}
// On-the-fly label and/or weight encoding of an input Fst.
//
// Complexity:
// - Constructor: O(1)
// - Traversal: O(nstates_visited + narcs_visited), assuming constant
//   time to visit an input state or arc.
template <class A>
class EncodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
 public:
  typedef A Arc;
  typedef EncodeMapper<A> C;
  typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
  using ImplToFst<Impl>::GetImpl;

  // Encodes through the caller's mapper, which also receives copies of the
  // input Fst's symbol tables.
  EncodeFst(const Fst<A> &fst, C *encoder)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, encoder, ArcMapFstOptions()) {
    encoder->SetInputSymbols(fst.InputSymbols());
    encoder->SetOutputSymbols(fst.OutputSymbols());
  }

  // Encodes with a mapper passed by const reference; the mapper's symbol
  // tables are left untouched.
  EncodeFst(const Fst<A> &fst, const C &encoder)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, encoder, ArcMapFstOptions()) {}

  // See Fst<>::Copy() for doc.
  EncodeFst(const EncodeFst<A> &fst, bool safe = false)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, safe) {}

  // Get a copy of this EncodeFst. See Fst<>::Copy() for further doc.
  // Requesting a safe copy is reported as an error and marks this Fst with
  // kError; the copy returned is made the same way in either case.
  virtual EncodeFst<A> *Copy(bool safe = false) const {
    if (!safe) {
      return new EncodeFst(*this);
    }
    FSTERROR() << "EncodeFst::Copy(true): not allowed.";
    GetImpl()->SetProperties(kError, kError);
    return new EncodeFst(*this);
  }
};
// On-the-fly label and/or weight decoding of an input Fst.
//
// Complexity:
// - Constructor: O(1)
// - Traversal: O(nstates_visited + narcs_visited), assuming constant
//   time to visit an input state or arc.
template <class A>
class DecodeFst : public ArcMapFst<A, A, EncodeMapper<A> > {
 public:
  typedef A Arc;
  typedef EncodeMapper<A> C;
  typedef ArcMapFstImpl< A, A, EncodeMapper<A> > Impl;
  using ImplToFst<Impl>::GetImpl;

  // Decodes with a DECODE-typed copy of encoder and sets this Fst's symbol
  // tables to the ones held by encoder.
  DecodeFst(const Fst<A> &fst, const C &encoder)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst,
                                          C(encoder, DECODE),
                                          ArcMapFstOptions()) {
    Impl *impl = GetImpl();
    impl->SetInputSymbols(encoder.InputSymbols());
    impl->SetOutputSymbols(encoder.OutputSymbols());
  }

  // See Fst<>::Copy() for doc.
  DecodeFst(const DecodeFst<A> &fst, bool safe = false)
      : ArcMapFst<A, A, EncodeMapper<A> >(fst, safe) {}

  // Get a copy of this DecodeFst. See Fst<>::Copy() for further doc.
  virtual DecodeFst<A> *Copy(bool safe = false) const {
    DecodeFst<A> *duplicate = new DecodeFst(*this, safe);
    return duplicate;
  }
};
// StateIterator specialization for EncodeFst: derives from the state
// iterator of the ArcMapFst base and forwards the Fst to it.
template <class A>
class StateIterator< EncodeFst<A> >
: public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
public:
explicit StateIterator(const EncodeFst<A> &fst)
: StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
};
// ArcIterator specialization for EncodeFst: derives from the arc iterator
// of the ArcMapFst base and forwards the Fst and state s to it.
template <class A>
class ArcIterator< EncodeFst<A> >
: public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
public:
ArcIterator(const EncodeFst<A> &fst, typename A::StateId s)
: ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
};
// StateIterator specialization for DecodeFst: derives from the state
// iterator of the ArcMapFst base and forwards the Fst to it.
template <class A>
class StateIterator< DecodeFst<A> >
: public StateIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
public:
explicit StateIterator(const DecodeFst<A> &fst)
: StateIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst) {}
};
// ArcIterator specialization for DecodeFst: derives from the arc iterator
// of the ArcMapFst base and forwards the Fst and state s to it.
template <class A>
class ArcIterator< DecodeFst<A> >
: public ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > > {
public:
ArcIterator(const DecodeFst<A> &fst, typename A::StateId s)
: ArcIterator< ArcMapFst<A, A, EncodeMapper<A> > >(fst, s) {}
};
// Useful aliases when using StdArc: EncodeFst and DecodeFst instantiated
// for StdArc.
typedef EncodeFst<StdArc> StdEncodeFst;
typedef DecodeFst<StdArc> StdDecodeFst;
} // namespace fst
#endif // FST_LIB_ENCODE_H__