blob: 7d8aa118ac2dfb85da0597ef158c678487622d7f [file] [log] [blame]
// pair-weight.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: shumash@google.com (Masha Maria Shugrina)
//
// \file
// Pair weight templated base class for weight classes that
// contain two weights (e.g. Product, Lexicographic)
#ifndef FST_LIB_PAIR_WEIGHT_H_
#define FST_LIB_PAIR_WEIGHT_H_
#include <climits>
#include <stack>
#include <string>
#include <fst/weight.h>
DECLARE_string(fst_weight_parentheses);
DECLARE_string(fst_weight_separator);
namespace fst {
template<class W1, class W2> class PairWeight;
template <class W1, class W2>
istream &operator>>(istream &strm, PairWeight<W1, W2> &w);
// PairWeight<W1, W2>: base class for weights made of two component weights
// (e.g. Product, Lexicographic). It stores the two components and provides
// the operations that can be defined component-wise: constants, binary and
// text I/O, membership, hashing, quantization and reversal. Semiring
// operations (Plus/Times/...) are left to the derived weight classes.
template<class W1, class W2>
class PairWeight {
 public:
  // The text reader needs access to the protected Read*Paren helpers.
  friend istream &operator>><W1, W2>(istream&, PairWeight<W1, W2>&);

  typedef PairWeight<typename W1::ReverseWeight,
                     typename W2::ReverseWeight>
  ReverseWeight;

  // Components are default-constructed; their value is whatever W1()/W2()
  // produce.
  PairWeight() {}

  PairWeight(const PairWeight& w) : value1_(w.value1_), value2_(w.value2_) {}

  PairWeight(const W1 &w1, const W2 &w2) : value1_(w1), value2_(w2) {}

  // The constants below are built component-wise from the component
  // constants and cached in function-local statics.
  static const PairWeight<W1, W2> &Zero() {
    static const PairWeight<W1, W2> zero(W1::Zero(), W2::Zero());
    return zero;
  }

  static const PairWeight<W1, W2> &One() {
    static const PairWeight<W1, W2> one(W1::One(), W2::One());
    return one;
  }

  static const PairWeight<W1, W2> &NoWeight() {
    static const PairWeight<W1, W2> no_weight(W1::NoWeight(), W2::NoWeight());
    return no_weight;
  }

  // Binary I/O: the first component followed by the second.
  istream &Read(istream &strm) {
    value1_.Read(strm);
    return value2_.Read(strm);
  }

  ostream &Write(ostream &strm) const {
    value1_.Write(strm);
    return value2_.Write(strm);
  }

  PairWeight<W1, W2> &operator=(const PairWeight<W1, W2> &w) {
    value1_ = w.Value1();
    value2_ = w.Value2();
    return *this;
  }

  // A pair is a member iff both components are members.
  bool Member() const { return value1_.Member() && value2_.Member(); }

  // Mixes the component hashes: rotates h1 left by lshift bits, then XORs
  // in h2.
  size_t Hash() const {
    const size_t h1 = value1_.Hash();
    const size_t h2 = value2_.Hash();
    const int lshift = 5;
    const int rshift = CHAR_BIT * sizeof(size_t) - lshift;
    return (h1 << lshift) ^ (h1 >> rshift) ^ h2;
  }

  PairWeight<W1, W2> Quantize(float delta = kDelta) const {
    return PairWeight<W1, W2>(value1_.Quantize(delta),
                              value2_.Quantize(delta));
  }

  ReverseWeight Reverse() const {
    return ReverseWeight(value1_.Reverse(), value2_.Reverse());
  }

  const W1& Value1() const { return value1_; }

  const W2& Value2() const { return value2_; }

 protected:
  void SetValue1(const W1 &w) { value1_ = w; }
  void SetValue2(const W2 &w) { value2_ = w; }

  // Reads a PairWeight written as "<w1><separator><w2>" (no parentheses).
  // Leading whitespace is skipped. The first element is everything up to
  // the first separator, which is consumed. The second element is read
  // directly from `strm` by W2's operator>>.
  // Fails with badbit if EOF is hit before the separator, or with failbit
  // if the first element cannot be parsed. `w` is left unchanged on these
  // failures.
  inline static istream &ReadNoParen(
      istream &strm, PairWeight<W1, W2>& w, char separator) {
    int c;
    do {
      c = strm.get();
    } while (isspace(c));

    string s1;
    while (c != separator) {
      if (c == EOF) {
        strm.clear(std::ios::badbit);
        return strm;
      }
      s1 += c;
      c = strm.get();
    }
    istringstream strm1(s1);
    W1 w1 = W1::Zero();
    strm1 >> w1;
    if (!strm1) {
      // Do not silently turn a malformed first element into a default value.
      strm.clear(std::ios::failbit);
      return strm;
    }

    // Read the second element straight from the input stream; its success
    // or failure is reported through `strm`'s state.
    W2 w2 = W2::Zero();
    strm >> w2;
    w = PairWeight<W1, W2>(w1, w2);
    return strm;
  }

  // Reads a PairWeight written as
  // "<open_paren><w1><separator><w2><close_paren>".
  // Parentheses nested inside either element must be balanced. That lets a
  // component itself be a parenthesized pair. Reading stops right after the
  // matching close paren, so input that follows the weight is left in the
  // stream. A successful read does not leave failbit set from hitting EOF.
  // Fails with:
  //   - failbit if the first non-space character is not `open_paren`;
  //   - badbit on EOF before the separator;
  //   - failbit on an unmatched close paren in the first element;
  //   - failbit if the closing paren is missing;
  //   - failbit if either element fails to parse.
  // `w` is left unchanged on failure.
  inline static istream &ReadWithParen(
      istream &strm, PairWeight<W1, W2>& w,
      char separator, char open_paren, char close_paren) {
    int c;
    do {
      c = strm.get();
    } while (isspace(c));
    if (c != open_paren) {
      FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
      strm.clear(std::ios::failbit);
      return strm;
    }
    c = strm.get();

    // First element: everything up to the first separator at nesting
    // depth 0.
    int depth = 0;
    string s1;
    while (c != separator || depth > 0) {
      if (c == EOF) {
        strm.clear(std::ios::badbit);
        return strm;
      }
      s1 += c;
      if (c == open_paren) {
        ++depth;
      } else if (c == close_paren) {
        if (depth == 0) {  // Unmatched close paren before the separator.
          strm.clear(std::ios::failbit);
          return strm;
        }
        --depth;
      }
      c = strm.get();
    }
    istringstream strm1(s1);
    W1 w1 = W1::Zero();
    strm1 >> w1;
    if (!strm1) {
      strm.clear(std::ios::failbit);
      return strm;
    }

    // Second element: everything up to the close paren that matches the
    // opening one (depth is 0 here). The matching paren is consumed but not
    // stored.
    string s2;
    c = strm.get();
    while (c != close_paren || depth > 0) {
      if (c == EOF) {
        FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
        strm.clear(std::ios::failbit);
        return strm;
      }
      s2 += c;
      if (c == open_paren) {
        ++depth;
      } else if (c == close_paren) {
        --depth;
      }
      c = strm.get();
    }
    istringstream strm2(s2);
    W2 w2 = W2::Zero();
    strm2 >> w2;
    if (!strm2) {
      strm.clear(std::ios::failbit);
      return strm;
    }

    w = PairWeight<W1, W2>(w1, w2);
    return strm;
  }

 private:
  W1 value1_;
  W2 value2_;
};
// Component-wise equality: two pairs are equal exactly when their first
// components compare equal and their second components compare equal.
// The second components are only compared if the first ones match.
template <class W1, class W2>
inline bool operator==(const PairWeight<W1, W2> &w,
                       const PairWeight<W1, W2> &v) {
  if (!(w.Value1() == v.Value1()))
    return false;
  return w.Value2() == v.Value2();
}
// Inequality is defined as the negation of PairWeight's operator==.
// This keeps the two operators consistent by construction, and the
// component types only need to provide operator==.
template <class W1, class W2>
inline bool operator!=(const PairWeight<W1, W2> &w1,
                       const PairWeight<W1, W2> &w2) {
  return !(w1 == w2);
}
// Approximate equality applied component-wise with the same tolerance
// `delta`. The second components are only compared when the first
// components are approximately equal.
template <class W1, class W2>
inline bool ApproxEqual(const PairWeight<W1, W2> &w1,
                        const PairWeight<W1, W2> &w2,
                        float delta = kDelta) {
  if (!ApproxEqual(w1.Value1(), w2.Value1(), delta))
    return false;
  return ApproxEqual(w1.Value2(), w2.Value2(), delta);
}
// Text output. Writes "<v1><sep><v2>", where <sep> is the single character
// in --fst_weight_separator. If --fst_weight_parentheses is non-empty, it
// must hold exactly two characters (open, close), and they wrap the output.
// An invalid flag value reports an error and sets badbit on `strm`.
template <class W1, class W2>
inline ostream &operator<<(ostream &strm, const PairWeight<W1, W2> &w) {
  if (FLAGS_fst_weight_separator.size() != 1) {
    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char separator = FLAGS_fst_weight_separator[0];
  const bool with_parens = !FLAGS_fst_weight_parentheses.empty();
  if (with_parens && FLAGS_fst_weight_parentheses.size() != 2) {
    FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    strm.clear(std::ios::badbit);
    return strm;
  }
  if (with_parens)
    strm << FLAGS_fst_weight_parentheses[0];
  strm << w.Value1() << separator << w.Value2();
  if (with_parens)
    strm << FLAGS_fst_weight_parentheses[1];
  return strm;
}
// Text input, the counterpart of operator<<. The flags are validated first:
// the separator must be one character, and the parentheses flag must be
// either empty or exactly two characters. Invalid flags set badbit. Parsing
// is delegated to PairWeight::ReadNoParen or PairWeight::ReadWithParen,
// depending on whether parentheses are configured.
template <class W1, class W2>
inline istream &operator>>(istream &strm, PairWeight<W1, W2> &w) {
  if (FLAGS_fst_weight_separator.size() != 1) {
    FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char separator = FLAGS_fst_weight_separator[0];

  // No parentheses configured: bare "<w1><sep><w2>" syntax.
  if (FLAGS_fst_weight_parentheses.empty())
    return PairWeight<W1, W2>::ReadNoParen(strm, w, separator);

  if (FLAGS_fst_weight_parentheses.size() != 2) {
    FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
    strm.clear(std::ios::badbit);
    return strm;
  }
  const char open_paren = FLAGS_fst_weight_parentheses[0];
  const char close_paren = FLAGS_fst_weight_parentheses[1];
  return PairWeight<W1, W2>::ReadWithParen(strm, w, separator,
                                           open_paren, close_paren);
}
} // namespace fst
#endif // FST_LIB_PAIR_WEIGHT_H_