blob: c12ef4f468eef47285df73338b349d8ef34345b2 [file] [log] [blame]
// sparse-tuple-weight.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: krr@google.com (Kasturi Rangan Raghavan)
// Inspiration: allauzen@google.com (Cyril Allauzen)
// \file
// Sparse version of tuple-weight, based on tuple-weight.h
// Internally stores sparse key, value pairs in linked list
// The default value element is the assumed value of unset keys
// Internal singleton implementation that stores the first key,
// value pair in an initialized member variable to avoid
// unnecessary allocation on the heap.
// Use SparseTupleWeightIterator to iterate through the key,value pairs
// Note: this does NOT iterate through the default value.
//
// Sparse tuple weight set operation definitions.
#ifndef FST_LIB_SPARSE_TUPLE_WEIGHT_H__
#define FST_LIB_SPARSE_TUPLE_WEIGHT_H__
#include<string>
#include<list>
#include<stack>
#include<tr1/unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <fst/weight.h>
DECLARE_string(fst_weight_parentheses);
DECLARE_string(fst_weight_separator);
namespace fst {
// Forward declarations of the weight class and of the iterator over its
// stored (key, value) pairs.
template <class W, class K> class SparseTupleWeight;
template<class W, class K>
class SparseTupleWeightIterator;
// Stream extraction for SparseTupleWeight; declared ahead of the class
// definition, which names this template in a friend declaration.
template <class W, class K>
istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w);
// Arbitrary dimension tuple weight, stored as a sorted linked-list
// W is any weight class,
// K is the key value type. kNoKey(-1) is reserved for internal use
//
// Only explicitly pushed (key, value) pairs are stored; every other key
// implicitly carries DefaultValue(). Push() appends without checking key
// order, so callers are expected to push keys in increasing order.
template <class W, class K = int>
class SparseTupleWeight {
 public:
  typedef pair<K, W> Pair;
  typedef SparseTupleWeight<typename W::ReverseWeight, K> ReverseWeight;

  // Reserved key: first_.first == kNoKey means no pair is stored.
  const static K kNoKey = -1;

  // Creates an empty tuple whose default value is W::Zero().
  SparseTupleWeight() {
    Init();
  }

  // Creates a tuple from the (key, value) pairs in [begin, end); the default
  // value is W::Zero().
  template <class Iterator>
  SparseTupleWeight(Iterator begin, Iterator end) {
    Init();
    // Assumes input iterator is sorted
    for (Iterator it = begin; it != end; ++it)
      Push(*it);
  }

  // Creates a tuple holding the single pair (key, w), unless w equals the
  // default value W::Zero(), in which case nothing is stored.
  SparseTupleWeight(const K& key, const W &w) {
    Init();
    Push(key, w);
  }

  // Creates an empty tuple whose default value is w.
  SparseTupleWeight(const W &w) {
    Init(w);
  }

  // Copies the default value and the stored pairs of w.
  SparseTupleWeight(const SparseTupleWeight<W, K> &w) {
    Init(w.DefaultValue());
    for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
      Push(it.Value());
    }
  }

  // Empty tuple with default value W::Zero().
  static const SparseTupleWeight<W, K> &Zero() {
    static SparseTupleWeight<W, K> zero;
    return zero;
  }

  // Empty tuple with default value W::One().
  static const SparseTupleWeight<W, K> &One() {
    static SparseTupleWeight<W, K> one(W::One());
    return one;
  }

  // Empty tuple with default value W::NoWeight().
  static const SparseTupleWeight<W, K> &NoWeight() {
    static SparseTupleWeight<W, K> no_weight(W::NoWeight());
    return no_weight;
  }

  // Reads the default value, the first pair and the remaining pairs, in that
  // order, using ReadType.
  istream &Read(istream &strm) {
    ReadType(strm, &default_);
    ReadType(strm, &first_);
    return ReadType(strm, &rest_);
  }

  // Writes the default value, the first pair and the remaining pairs, in
  // that order, using WriteType.
  ostream &Write(ostream &strm) const {
    WriteType(strm, default_);
    WriteType(strm, first_);
    return WriteType(strm, rest_);
  }

  // Replaces the default value and the stored pairs with those of w.
  SparseTupleWeight<W, K> &operator=(const SparseTupleWeight<W, K> &w) {
    if (this == &w) return *this;  // check for w = w
    Init(w.DefaultValue());
    for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
      Push(it.Value());
    }
    return *this;
  }

  // True iff the default value and every stored value are members of W.
  bool Member() const {
    if (!DefaultValue().Member()) return false;
    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
      if (!it.Value().second.Member()) return false;
    }
    return true;
  }

  // Assumes H() function exists for the hash of the key value
  // Only the stored pairs contribute; the default value is not hashed.
  size_t Hash() const {
    uint64 h = 0;
    std::tr1::hash<K> H;
    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
      h = 5 * h + H(it.Value().first);
      h = 13 * h + it.Value().second.Hash();
    }
    return size_t(h);
  }

  // Returns a copy in which the default value and every stored value have
  // been quantized with delta.
  SparseTupleWeight<W, K> Quantize(float delta = kDelta) const {
    SparseTupleWeight<W, K> w;
    // The default value stands for every unset key, so it has to be carried
    // over (quantized) as well; otherwise the result falls back to W::Zero().
    w.SetDefaultValue(default_.Quantize(delta));
    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
      w.Push(it.Value().first, it.Value().second.Quantize(delta));
    }
    return w;
  }

  // Returns the tuple over W::ReverseWeight in which the default value and
  // every stored value have been reversed.
  ReverseWeight Reverse() const {
    // Build the result directly in the reverse weight type so that this also
    // works when W::ReverseWeight differs from W.
    ReverseWeight w;
    w.SetDefaultValue(default_.Reverse());
    for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
      w.Push(it.Value().first, it.Value().second.Reverse());
    }
    return w;
  }

  // Common initializer among constructors.
  void Init() {
    Init(W::Zero());
  }

  // Drops all stored pairs and sets the default value to default_value.
  void Init(const W& default_value) {
    first_.first = kNoKey;
    /* initialized to the reserved key value */
    default_ = default_value;
    rest_.clear();
  }

  // Number of stored pairs (the default value is not counted).
  size_t Size() const {
    if (first_.first == kNoKey)
      return 0;
    else
      return rest_.size() + 1;
  }

  // Appends the pair (k, w); see Push(const Pair &, bool).
  inline void Push(const K &k, const W &w, bool default_value_check = true) {
    Push(make_pair(k, w), default_value_check);
  }

  // Appends p after the pairs stored so far. When default_value_check is
  // true, a pair whose value equals the default value is silently skipped.
  inline void Push(const Pair &p, bool default_value_check = true) {
    if (default_value_check && p.second == default_) return;
    if (first_.first == kNoKey) {
      first_ = p;
    } else {
      rest_.push_back(p);
    }
  }

  // Changes the default value; stored pairs are left untouched.
  void SetDefaultValue(const W& val) { default_ = val; }

  const W& DefaultValue() const { return default_; }

 protected:
  // Text readers used by operator>>, without and with enclosing parentheses.
  static istream& ReadNoParen(
      istream&, SparseTupleWeight<W, K>&, char separator);

  static istream& ReadWithParen(
      istream&, SparseTupleWeight<W, K>&,
      char separator, char open_paren, char close_paren);

 private:
  // Assumed default value of uninitialized keys, by default W::Zero()
  W default_;

  // Key values pairs are first stored in first_, then fill rest_
  // this way we can avoid dynamic allocation in the common case
  // where the weight is a single key,val pair.
  Pair first_;
  list<Pair> rest_;

  friend istream &operator>><W, K>(istream&, SparseTupleWeight<W, K>&);
  friend class SparseTupleWeightIterator<W, K>;
};
template<class W, class K>
class SparseTupleWeightIterator {
public:
typedef typename SparseTupleWeight<W, K>::Pair Pair;
typedef typename list<Pair>::const_iterator const_iterator;
typedef typename list<Pair>::iterator iterator;
explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K>& w)
: first_(w.first_), rest_(w.rest_), init_(true),
iter_(rest_.begin()) {}
bool Done() const {
if (init_)
return first_.first == SparseTupleWeight<W, K>::kNoKey;
else
return iter_ == rest_.end();
}
const Pair& Value() const { return init_ ? first_ : *iter_; }
void Next() {
if (init_)
init_ = false;
else
++iter_;
}
void Reset() {
init_ = true;
iter_ = rest_.begin();
}
private:
const Pair &first_;
const list<Pair> & rest_;
bool init_; // in the initialized state?
typename list<Pair>::const_iterator iter_;
DISALLOW_COPY_AND_ASSIGN(SparseTupleWeightIterator);
};
// Combines w1 and w2 key by key into *ret using operator_mapper.Map(key, a, b).
// A key stored in only one operand is paired with the other operand's
// default value. The default value of *ret is set to the mapping of the two
// default values (using key 0). Results are appended with Push(); pairs
// already stored in *ret are not removed.
template<class W, class K, class M>
inline void SparseTupleWeightMap(
    SparseTupleWeight<W, K>* ret,
    const SparseTupleWeight<W, K>& w1,
    const SparseTupleWeight<W, K>& w2,
    const M& operator_mapper) {
  typedef typename SparseTupleWeight<W, K>::Pair Entry;

  SparseTupleWeightIterator<W, K> lhs(w1);
  SparseTupleWeightIterator<W, K> rhs(w2);
  const W& lhs_default = w1.DefaultValue();
  const W& rhs_default = w2.DefaultValue();
  ret->SetDefaultValue(operator_mapper.Map(0, lhs_default, rhs_default));

  // Merge phase: both operands still have stored pairs.
  while (!lhs.Done() && !rhs.Done()) {
    const Entry& a = lhs.Value();
    const Entry& b = rhs.Value();
    if (a.first == b.first) {
      ret->Push(a.first, operator_mapper.Map(a.first, a.second, b.second));
      lhs.Next();
      rhs.Next();
    } else if (a.first < b.first) {
      ret->Push(a.first, operator_mapper.Map(a.first, a.second, rhs_default));
      lhs.Next();
    } else {
      ret->Push(b.first, operator_mapper.Map(b.first, lhs_default, b.second));
      rhs.Next();
    }
  }

  // At most one of the two tails below is non-empty.
  for (; !lhs.Done(); lhs.Next()) {
    const Entry& a = lhs.Value();
    ret->Push(a.first, operator_mapper.Map(a.first, a.second, rhs_default));
  }
  for (; !rhs.Done(); rhs.Next()) {
    const Entry& b = rhs.Value();
    ret->Push(b.first, operator_mapper.Map(b.first, lhs_default, b.second));
  }
}
// Two sparse tuples are equal when their default values are equal and every
// key maps to equal values, where a key stored in only one operand is
// compared against the other operand's default value.
template <class W, class K>
inline bool operator==(const SparseTupleWeight<W, K> &w1,
                       const SparseTupleWeight<W, K> &w2) {
  typedef typename SparseTupleWeight<W, K>::Pair Entry;

  const W& lhs_default = w1.DefaultValue();
  const W& rhs_default = w2.DefaultValue();
  if (lhs_default != rhs_default) return false;

  SparseTupleWeightIterator<W, K> lhs(w1);
  SparseTupleWeightIterator<W, K> rhs(w2);

  // Merge phase: both operands still have stored pairs.
  while (!lhs.Done() && !rhs.Done()) {
    const Entry& a = lhs.Value();
    const Entry& b = rhs.Value();
    if (a.first == b.first) {
      if (a.second != b.second) return false;
      lhs.Next();
      rhs.Next();
    } else if (a.first < b.first) {
      if (a.second != rhs_default) return false;
      lhs.Next();
    } else {
      if (lhs_default != b.second) return false;
      rhs.Next();
    }
  }

  // Leftover pairs of one operand must match the other operand's default.
  for (; !lhs.Done(); lhs.Next()) {
    if (lhs.Value().second != rhs_default) return false;
  }
  for (; !rhs.Done(); rhs.Next()) {
    if (lhs_default != rhs.Value().second) return false;
  }
  return true;
}
// Negation of operator== for sparse tuple weights.
template <class W, class K>
inline bool operator!=(const SparseTupleWeight<W, K> &w1,
                       const SparseTupleWeight<W, K> &w2) {
  const bool same = (w1 == w2);
  return !same;
}
// Writes w as text:
//   [open] default SEP size SEP (key SEP value SEP)* [close]
// SEP is the single character of FLAGS_fst_weight_separator. The enclosing
// pair of characters is taken from FLAGS_fst_weight_parentheses and is only
// written when that flag is non-empty. A malformed flag is reported through
// FSTERROR and leaves strm with badbit set and nothing written.
template <class W, class K>
inline ostream &operator<<(ostream &strm, const SparseTupleWeight<W, K> &w) {
  if (FLAGS_fst_weight_separator.size() != 1) {
    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const bool use_parens = !FLAGS_fst_weight_parentheses.empty();
  if (use_parens && FLAGS_fst_weight_parentheses.size() != 2) {
    FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char sep = FLAGS_fst_weight_separator[0];

  if (use_parens) strm << FLAGS_fst_weight_parentheses[0];

  // Header: default value followed by the number of stored pairs.
  strm << w.DefaultValue() << sep << w.Size() << sep;

  // Every pair, including the last one, is followed by a separator.
  for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
    const typename SparseTupleWeight<W, K>::Pair &entry = it.Value();
    strm << entry.first << sep << entry.second << sep;
  }

  if (use_parens) strm << FLAGS_fst_weight_parentheses[1];
  return strm;
}
// Reads w from text. FLAGS_fst_weight_separator must hold exactly one
// character. When FLAGS_fst_weight_parentheses is empty the input is parsed
// by ReadNoParen; otherwise the flag must hold exactly two characters (open,
// close) and ReadWithParen is used. A malformed flag is reported through
// FSTERROR and leaves strm with badbit set.
template <class W, class K>
inline istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w) {
  if (FLAGS_fst_weight_separator.size() != 1) {
    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char sep = FLAGS_fst_weight_separator[0];

  if (FLAGS_fst_weight_parentheses.empty())
    return SparseTupleWeight<W, K>::ReadNoParen(strm, w, sep);

  if (FLAGS_fst_weight_parentheses.size() != 2) {
    FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char open_paren = FLAGS_fst_weight_parentheses[0];
  const char close_paren = FLAGS_fst_weight_parentheses[1];
  return SparseTupleWeight<W, K>::ReadWithParen(
      strm, w, sep, open_paren, close_paren);
}
// Helper for ReadNoParen: collects one separator-terminated field.
// c is the first character of the field (already taken from strm). Characters
// are appended to *field until separator has been consumed. Returns false if
// the input ends before a separator is found.
inline bool ReadSparseTupleField(std::istream &strm, int c, char separator,
                                 std::string *field) {
  for (; c != separator; c = strm.get()) {
    if (c == EOF) return false;
    *field += static_cast<char>(c);
  }
  return true;
}

// Reads SparseTupleWeight when there are no parentheses around tuple terms
//
// Expected input, after optional leading whitespace:
//   default SEP n SEP (key SEP value SEP){n}
// which is what operator<< produces when no parentheses are configured.
// Any pairs previously stored in w are discarded. On truncated input, or
// when a field cannot be parsed, badbit is set on strm and w may be left
// partially filled.
template <class W, class K>
inline istream& SparseTupleWeight<W, K>::ReadNoParen(
    istream &strm,
    SparseTupleWeight<W, K> &w,
    char separator) {
  int c;
  do {
    c = strm.get();
  } while (isspace(c));

  {  // Read default weight
    string s;
    if (!ReadSparseTupleField(strm, c, separator, &s)) {
      strm.clear(std::ios::badbit);
      return strm;
    }
    W default_value;
    istringstream sstrm(s);
    sstrm >> default_value;
    if (sstrm.fail()) {
      strm.clear(std::ios::badbit);
      return strm;
    }
    // Init (rather than SetDefaultValue) so stale pairs in w do not survive.
    w.Init(default_value);
  }

  size_t n = 0;
  {  // Read n
    string s;
    if (!ReadSparseTupleField(strm, strm.get(), separator, &s)) {
      strm.clear(std::ios::badbit);
      return strm;
    }
    istringstream sstrm(s);
    sstrm >> n;
    if (sstrm.fail()) {
      strm.clear(std::ios::badbit);
      return strm;
    }
  }

  // Read n elements
  for (size_t i = 0; i < n; ++i) {
    K p;
    W r;
    {  // read key
      string s;
      if (!ReadSparseTupleField(strm, strm.get(), separator, &s)) {
        strm.clear(std::ios::badbit);
        return strm;
      }
      istringstream sstrm(s);
      sstrm >> p;
      if (sstrm.fail()) {
        strm.clear(std::ios::badbit);
        return strm;
      }
    }
    {  // read weight
      string s;
      if (!ReadSparseTupleField(strm, strm.get(), separator, &s)) {
        strm.clear(std::ios::badbit);
        return strm;
      }
      istringstream sstrm(s);
      sstrm >> r;
      if (sstrm.fail()) {
        strm.clear(std::ios::badbit);
        return strm;
      }
    }
    w.Push(p, r);
  }

  // The separator that ends the last field has already been consumed above;
  // nothing beyond it belongs to this weight, so no further character is read.
  return strm;
}
// Helper for ReadWithParen: collects one separator-terminated field that may
// itself contain parenthesized sub-terms.
// c is the first character of the field (already taken from strm). Characters
// are appended to *field until a separator is consumed while no parenthesis
// is open. On failure the stream state is set and false is returned: badbit
// when the input ends early, failbit on a close_paren without a matching
// open_paren.
inline bool ReadSparseTupleParenField(std::istream &strm, int c,
                                      char separator, char open_paren,
                                      char close_paren, std::string *field) {
  size_t depth = 0;  // number of currently unmatched open parentheses
  while (c != separator || depth > 0) {
    if (c == EOF) {
      strm.clear(std::ios::badbit);
      return false;
    }
    *field += static_cast<char>(c);
    // If parens encountered before separator, they must be matched
    if (c == open_paren) {
      ++depth;
    } else if (c == close_paren) {
      // Fail for mismatched parens
      if (depth == 0) {
        strm.clear(std::ios::failbit);
        return false;
      }
      --depth;
    }
    c = strm.get();
  }
  return true;
}

// Reads SparseTupleWeight when there are parentheses around tuple terms
//
// Expected input, after optional leading whitespace:
//   OPEN default SEP n SEP (key SEP value SEP){n} CLOSE
// Separators inside matched parentheses belong to the enclosing field, so
// nested parenthesized weights are kept intact. Any pairs previously stored
// in w are discarded. On malformed input the stream state is set (badbit, or
// failbit for an unmatched closing parenthesis) and w may be left partially
// filled.
template <class W, class K>
inline istream& SparseTupleWeight<W, K>::ReadWithParen(
    istream &strm,
    SparseTupleWeight<W, K> &w,
    char separator,
    char open_paren,
    char close_paren) {
  int c;
  do {
    c = strm.get();
  } while (isspace(c));

  if (c != open_paren) {
    FSTERROR() << "is fst_weight_parentheses flag set correcty? ";
    strm.clear(std::ios::badbit);
    return strm;
  }

  {  // Read weight
    string s;
    if (!ReadSparseTupleParenField(strm, strm.get(), separator, open_paren,
                                   close_paren, &s)) {
      return strm;
    }
    W default_value;
    istringstream sstrm(s);
    sstrm >> default_value;
    if (sstrm.fail()) {
      strm.clear(std::ios::badbit);
      return strm;
    }
    // Init (rather than SetDefaultValue) so stale pairs in w do not survive.
    w.Init(default_value);
  }

  size_t n = 0;
  {  // Read n (a plain count, so no parenthesis matching is applied)
    string s;
    for (c = strm.get(); c != separator; c = strm.get()) {
      if (c == EOF) {
        strm.clear(std::ios::badbit);
        return strm;
      }
      s += static_cast<char>(c);
    }
    istringstream sstrm(s);
    sstrm >> n;
    if (sstrm.fail()) {
      strm.clear(std::ios::badbit);
      return strm;
    }
  }

  // Read n elements
  for (size_t i = 0; i < n; ++i) {
    K p;
    W r;
    {  // Read key
      string s;
      if (!ReadSparseTupleParenField(strm, strm.get(), separator, open_paren,
                                     close_paren, &s)) {
        return strm;
      }
      istringstream sstrm(s);
      sstrm >> p;
      if (sstrm.fail()) {
        strm.clear(std::ios::badbit);
        return strm;
      }
    }
    {  // Read weight
      string s;
      if (!ReadSparseTupleParenField(strm, strm.get(), separator, open_paren,
                                     close_paren, &s)) {
        return strm;
      }
      istringstream sstrm(s);
      sstrm >> r;
      if (sstrm.fail()) {
        strm.clear(std::ios::badbit);
        return strm;
      }
    }
    w.Push(p, r);
  }

  // Every field above ends by consuming its separator, so the only thing
  // left to verify is the closing parenthesis.
  c = strm.get();
  if (c != close_paren) {
    FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
    strm.clear(std::ios::badbit);
    return strm;
  }
  return strm;
}
} // namespace fst
#endif // FST_LIB_SPARSE_TUPLE_WEIGHT_H__