blob: da964791826aaeca5919818c7af0e1da9efbdf21 [file] [log] [blame]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: krr@google.com (Kasturi Rangan Raghavan)
// \file
// LogWeight along with sign information that represents the value X in the
// linear domain as <sign(X), -ln(|X|)>
// The sign is a TropicalWeight:
// positive, TropicalWeight.Value() > 0.0, recommended value 1.0
// negative, TropicalWeight.Value() <= 0.0, recommended value -1.0
#ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_
#define FST_LIB_SIGNED_LOG_WEIGHT_H_
#include <fst/float-weight.h>
#include <fst/pair-weight.h>
namespace fst {
// Weight representing the linear-domain value X as the pair
// <sign(X), -ln(|X|)>. The sign (X1) is a TropicalWeight read as positive
// when its value is > 0.0 and negative otherwise; the magnitude (X2) is a
// LogWeightTpl<T>.
template <class T>
class SignedLogWeightTpl
    : public PairWeight<TropicalWeight, LogWeightTpl<T> > {
 public:
  typedef TropicalWeight X1;   // Sign component.
  typedef LogWeightTpl<T> X2;  // Magnitude component, -ln(|X|).
  using PairWeight<X1, X2>::Value1;
  using PairWeight<X1, X2>::Value2;
  using PairWeight<X1, X2>::Reverse;
  using PairWeight<X1, X2>::Quantize;
  using PairWeight<X1, X2>::Member;

  typedef SignedLogWeightTpl<T> ReverseWeight;

  SignedLogWeightTpl() : PairWeight<X1, X2>() {}

  SignedLogWeightTpl(const SignedLogWeightTpl<T> &w)
      : PairWeight<X1, X2>(w) {}

  SignedLogWeightTpl(const PairWeight<X1, X2> &w)
      : PairWeight<X1, X2>(w) {}

  SignedLogWeightTpl(const X1 &x1, const X2 &x2)
      : PairWeight<X1, X2>(x1, x2) {}

  // Semiring zero: positive sign, magnitude X2::Zero().
  static const SignedLogWeightTpl<T> &Zero() {
    static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero());
    return zero;
  }

  // Semiring one: positive sign, magnitude X2::One().
  static const SignedLogWeightTpl<T> &One() {
    static const SignedLogWeightTpl<T> one(X1(1.0), X2::One());
    return one;
  }

  // Invalid-weight sentinel: positive sign, magnitude X2::NoWeight().
  static const SignedLogWeightTpl<T> &NoWeight() {
    static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight());
    return no_weight;
  }

  static const string &Type() {
    static const string type = "signed_log_" + X1::Type() + "_" + X2::Type();
    return type;
  }

  // Quantizes the magnitude using 'delta' (previously 'delta' was ignored).
  // The sign is not a magnitude, so it is normalized to +/-1.0 instead of
  // being quantized. Depending on 'delta', quantizing the sign could round a
  // positive value such as 1.0 down to 0.0, which this weight treats as
  // negative. Non-members map to NoWeight().
  SignedLogWeightTpl<T> Quantize(float delta = kDelta) const {
    if (!Member()) return NoWeight();
    const X1 sign(Value1().Value() > 0.0 ? 1.0 : -1.0);
    return SignedLogWeightTpl<T>(sign, Value2().Quantize(delta));
  }

  ReverseWeight Reverse() const {
    return PairWeight<X1, X2>::Reverse();
  }

  bool Member() const {
    return PairWeight<X1, X2>::Member();
  }

  static uint64 Properties() {
    // Neither idempotent nor path.
    return kLeftSemiring | kRightSemiring | kCommutative;
  }

  // Hashes the normalized sign (+1 for zero-magnitude or positive weights,
  // -1 otherwise) together with the magnitude, so a Zero() with either sign
  // hashes identically.
  size_t Hash() const {
    size_t h1;
    if (Value2() == X2::Zero() || Value1().Value() > 0.0)
      h1 = TropicalWeight(1.0).Hash();
    else
      h1 = TropicalWeight(-1.0).Hash();
    size_t h2 = Value2().Hash();
    const int lshift = 5;
    const int rshift = CHAR_BIT * sizeof(size_t) - 5;
    return (h1 << lshift) ^ (h1 >> rshift) ^ h2;
  }
};
// Signed-log addition: returns the weight of the linear-domain sum of w1 and
// w2, where each operand encodes sign(X) * exp(-magnitude). Returns
// NoWeight() if either operand is not a member.
template <class T>
inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1,
                                  const SignedLogWeightTpl<T> &w2) {
  if (!w1.Member() || !w2.Member())
    return SignedLogWeightTpl<T>::NoWeight();
  bool s1 = w1.Value1().Value() > 0.0;
  bool s2 = w2.Value1().Value() > 0.0;
  T f1 = w1.Value2().Value();
  T f2 = w2.Value2().Value();
  if (f1 == FloatLimits<T>::kPosInfinity) {
    // w1 has infinite -ln magnitude, i.e. it is zero: the sum is w2.
    return w2;
  } else if (f2 == FloatLimits<T>::kPosInfinity) {
    return w1;
  } else if (f1 == f2) {
    if (s1 == s2) {
      // |X| + |X| = 2|X|, so the magnitude is f - ln(2). ln(2) is evaluated
      // in type T so that double-precision weights do not pick up a float
      // approximation of ln(2).
      return SignedLogWeightTpl<T>(w1.Value1(),
                                   (f2 - log(static_cast<T>(2))));
    } else {
      // Equal magnitudes with opposite signs cancel exactly.
      return SignedLogWeightTpl<T>::Zero();
    }
  } else if (f1 > f2) {
    // |w2| > |w1|: factor out w2's magnitude; the result has w2's sign.
    if (s1 == s2) {
      return SignedLogWeightTpl<T>(
          w1.Value1(), (f2 - log(1.0F + exp(f2 - f1))));
    } else {
      return SignedLogWeightTpl<T>(
          w2.Value1(), (f2 - log(1.0F - exp(f2 - f1))));
    }
  } else {
    // |w1| > |w2|: factor out w1's magnitude; the result has w1's sign.
    if (s2 == s1) {
      return SignedLogWeightTpl<T>(
          w2.Value1(), (f1 - log(1.0F + exp(f1 - f2))));
    } else {
      return SignedLogWeightTpl<T>(
          w1.Value1(), (f1 - log(1.0F - exp(f1 - f2))));
    }
  }
}
// Signed-log subtraction: returns w1 - w2, computed as w1 + (-w2).
// Returns NoWeight() if either operand is not a member.
template <class T>
inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1,
                                   const SignedLogWeightTpl<T> &w2) {
  if (!w1.Member() || !w2.Member())
    return SignedLogWeightTpl<T>::NoWeight();
  // Flip the sign explicitly instead of negating the stored value. A sign
  // value of 0.0 counts as negative (<= 0.0), and its negation -0.0 would
  // still be negative; negating a large positive sign such as +inf would also
  // produce a non-member sign.
  const bool w2_positive = w2.Value1().Value() > 0.0;
  SignedLogWeightTpl<T> minus_w2(w2_positive ? -1.0 : 1.0, w2.Value2());
  return Plus(w1, minus_w2);
}
// Signed-log multiplication: magnitudes (-ln|X|) add, and the result is
// positive (+1.0) exactly when both operands share a sign, otherwise negative
// (-1.0). Non-member operands yield NoWeight().
template <class T>
inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1,
                                   const SignedLogWeightTpl<T> &w2) {
  if (!w1.Member() || !w2.Member())
    return SignedLogWeightTpl<T>::NoWeight();
  const bool same_sign =
      (w1.Value1().Value() > 0.0) == (w2.Value1().Value() > 0.0);
  const T magnitude = w1.Value2().Value() + w2.Value2().Value();
  return SignedLogWeightTpl<T>(TropicalWeight(same_sign ? 1.0 : -1.0),
                               magnitude);
}
// Signed-log division: magnitudes (-ln|X|) subtract and signs combine as in
// Times. Dividing by a zero-magnitude weight (+inf) gives a kNumberBad
// magnitude; a zero numerator gives a positive +inf magnitude. The operation
// is commutative, so 'typ' is accepted but not consulted.
template <class T>
inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1,
                                    const SignedLogWeightTpl<T> &w2,
                                    DivideType typ = DIVIDE_ANY) {
  if (!w1.Member() || !w2.Member())
    return SignedLogWeightTpl<T>::NoWeight();
  const T numer = w1.Value2().Value();
  const T denom = w2.Value2().Value();
  if (denom == FloatLimits<T>::kPosInfinity) {
    return SignedLogWeightTpl<T>(TropicalWeight(1.0),
                                 FloatLimits<T>::kNumberBad);
  }
  if (numer == FloatLimits<T>::kPosInfinity) {
    return SignedLogWeightTpl<T>(TropicalWeight(1.0),
                                 FloatLimits<T>::kPosInfinity);
  }
  const bool positive =
      (w1.Value1().Value() > 0.0) == (w2.Value1().Value() > 0.0);
  return SignedLogWeightTpl<T>(TropicalWeight(positive ? 1.0 : -1.0),
                               (numer - denom));
}
template <class T>
inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1,
const SignedLogWeightTpl<T> &w2,
float delta = kDelta) {
bool s1 = w1.Value1().Value() > 0.0;
bool s2 = w2.Value1().Value() > 0.0;
if (s1 == s2) {
return ApproxEqual(w1.Value2(), w2.Value2(), delta);
} else {
return w1.Value2() == LogWeightTpl<T>::Zero()
&& w2.Value2() == LogWeightTpl<T>::Zero();
}
}
template <class T>
inline bool operator==(const SignedLogWeightTpl<T> &w1,
const SignedLogWeightTpl<T> &w2) {
bool s1 = w1.Value1().Value() > 0.0;
bool s2 = w2.Value1().Value() > 0.0;
if (s1 == s2)
return w1.Value2() == w2.Value2();
else
return (w1.Value2() == LogWeightTpl<T>::Zero()) &&
(w2.Value2() == LogWeightTpl<T>::Zero());
}
// Single-precision signed-log weight: magnitude stored as
// LogWeightTpl<float>; the sign component is a TropicalWeight.
typedef SignedLogWeightTpl<float> SignedLogWeight;
// Double-precision signed-log weight: magnitude stored as
// LogWeightTpl<double>; the sign component is still a TropicalWeight.
typedef SignedLogWeightTpl<double> SignedLog64Weight;
//
// WEIGHT CONVERTER SPECIALIZATIONS.
//
// Returns true if signed-log weight 'w' (of type W1) can be converted to the
// unsigned weight type W2, i.e. if it does not represent a negative value.
// Consistent with the rest of this file, a sign value <= 0.0 is negative.
// A negative weight whose magnitude is W1::X2::Zero() is still zero (it
// compares equal to Zero()), so it is accepted. On failure an error is
// reported via FSTERROR() and false is returned.
template <class W1, class W2>
bool SignedLogConvertCheck(W1 w) {
  const bool negative = !(w.Value1().Value() > 0.0);
  const bool is_zero = w.Value2() == W1::X2::Zero();
  if (negative && !is_zero) {
    FSTERROR() << "WeightConvert: can't convert weight from \""
               << W1::Type() << "\" to \"" << W2::Type() << "\"";
    return false;
  }
  return true;
}
// Convert to tropical
template <>
struct WeightConvert<SignedLogWeight, TropicalWeight> {
  // Keeps only the magnitude of a weight accepted by SignedLogConvertCheck;
  // rejected weights become TropicalWeight::NoWeight().
  TropicalWeight operator()(SignedLogWeight w) const {
    if (SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w)) {
      return TropicalWeight(w.Value2().Value());
    }
    return TropicalWeight::NoWeight();
  }
};
template <>
struct WeightConvert<SignedLog64Weight, TropicalWeight> {
  // Keeps only the magnitude of a weight accepted by SignedLogConvertCheck;
  // rejected weights become TropicalWeight::NoWeight().
  TropicalWeight operator()(SignedLog64Weight w) const {
    if (SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w)) {
      return TropicalWeight(w.Value2().Value());
    }
    return TropicalWeight::NoWeight();
  }
};
// Convert to log
template <>
struct WeightConvert<SignedLogWeight, LogWeight> {
  // Keeps only the magnitude of a weight accepted by SignedLogConvertCheck;
  // rejected weights become LogWeight::NoWeight().
  LogWeight operator()(SignedLogWeight w) const {
    if (SignedLogConvertCheck<SignedLogWeight, LogWeight>(w)) {
      return LogWeight(w.Value2().Value());
    }
    return LogWeight::NoWeight();
  }
};
template <>
struct WeightConvert<SignedLog64Weight, LogWeight> {
  // Keeps only the magnitude of a weight accepted by SignedLogConvertCheck;
  // rejected weights become LogWeight::NoWeight().
  LogWeight operator()(SignedLog64Weight w) const {
    if (SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w)) {
      return LogWeight(w.Value2().Value());
    }
    return LogWeight::NoWeight();
  }
};
// Convert to log64
template <>
struct WeightConvert<SignedLogWeight, Log64Weight> {
  // Keeps only the magnitude of a weight accepted by SignedLogConvertCheck;
  // rejected weights become Log64Weight::NoWeight().
  Log64Weight operator()(SignedLogWeight w) const {
    if (SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w)) {
      return Log64Weight(w.Value2().Value());
    }
    return Log64Weight::NoWeight();
  }
};
template <>
struct WeightConvert<SignedLog64Weight, Log64Weight> {
  // Keeps only the magnitude of a weight accepted by SignedLogConvertCheck;
  // rejected weights become Log64Weight::NoWeight().
  Log64Weight operator()(SignedLog64Weight w) const {
    if (SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w)) {
      return Log64Weight(w.Value2().Value());
    }
    return Log64Weight::NoWeight();
  }
};
// Convert to signed log
template <>
struct WeightConvert<TropicalWeight, SignedLogWeight> {
  // Embeds w as a positive (+1.0) signed-log weight whose magnitude is w's
  // value.
  SignedLogWeight operator()(TropicalWeight w) const {
    const LogWeight magnitude(w.Value());
    return SignedLogWeight(TropicalWeight(1.0), magnitude);
  }
};
template <>
struct WeightConvert<LogWeight, SignedLogWeight> {
  // Embeds w as a positive (+1.0) signed-log weight whose magnitude is w's
  // value.
  SignedLogWeight operator()(LogWeight w) const {
    const LogWeight magnitude(w.Value());
    return SignedLogWeight(TropicalWeight(1.0), magnitude);
  }
};
template <>
struct WeightConvert<Log64Weight, SignedLogWeight> {
  // Embeds w as a positive (+1.0) signed-log weight; the double magnitude is
  // converted to LogWeight's value type.
  SignedLogWeight operator()(Log64Weight w) const {
    const LogWeight magnitude(w.Value());
    return SignedLogWeight(TropicalWeight(1.0), magnitude);
  }
};
template <>
struct WeightConvert<SignedLog64Weight, SignedLogWeight> {
  // Copies the sign unchanged and converts the magnitude to LogWeight.
  SignedLogWeight operator()(SignedLog64Weight w) const {
    const LogWeight magnitude(w.Value2().Value());
    return SignedLogWeight(w.Value1(), magnitude);
  }
};
// Convert to signed log64
template <>
struct WeightConvert<TropicalWeight, SignedLog64Weight> {
  // Embeds w as a positive (+1.0) signed-log64 weight whose magnitude is w's
  // value.
  SignedLog64Weight operator()(TropicalWeight w) const {
    const Log64Weight magnitude(w.Value());
    return SignedLog64Weight(TropicalWeight(1.0), magnitude);
  }
};
template <>
struct WeightConvert<LogWeight, SignedLog64Weight> {
  // Embeds w as a positive (+1.0) signed-log64 weight whose magnitude is w's
  // value.
  SignedLog64Weight operator()(LogWeight w) const {
    const Log64Weight magnitude(w.Value());
    return SignedLog64Weight(TropicalWeight(1.0), magnitude);
  }
};
template <>
struct WeightConvert<Log64Weight, SignedLog64Weight> {
  // Embeds w as a positive (+1.0) signed-log64 weight whose magnitude is w's
  // value.
  SignedLog64Weight operator()(Log64Weight w) const {
    const Log64Weight magnitude(w.Value());
    return SignedLog64Weight(TropicalWeight(1.0), magnitude);
  }
};
template <>
struct WeightConvert<SignedLogWeight, SignedLog64Weight> {
  // Copies the sign unchanged and widens the magnitude to Log64Weight.
  SignedLog64Weight operator()(SignedLogWeight w) const {
    const Log64Weight magnitude(w.Value2().Value());
    return SignedLog64Weight(w.Value1(), magnitude);
  }
};
} // namespace fst
#endif // FST_LIB_SIGNED_LOG_WEIGHT_H_