blob: 530cbdd01c86e2a2a0ef78757ec415c8928a78ce [file] [log] [blame]
// float-weight.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Float weight set and associated semiring operation definitions.
//
#ifndef FST_LIB_FLOAT_WEIGHT_H__
#define FST_LIB_FLOAT_WEIGHT_H__
#include <climits>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <sstream>
#include <string>
#include <fst/util.h>
#include <fst/weight.h>
namespace fst {
// Numeric limits for the floating-point type T used by the weight classes:
// positive infinity, negative infinity, and a quiet NaN used as the
// "bad number" marker.
template <class T>
class FloatLimits {
 public:
  static const T kPosInfinity;
  static const T kNegInfinity;
  static const T kNumberBad;
};

template <class T>
const T FloatLimits<T>::kPosInfinity = std::numeric_limits<T>::infinity();

// Computed directly from numeric_limits rather than from kPosInfinity: the
// initialization order of static data members of a class template is not
// specified, so depending on a sibling member could observe it before it
// has been set.
template <class T>
const T FloatLimits<T>::kNegInfinity = -std::numeric_limits<T>::infinity();

template <class T>
const T FloatLimits<T>::kNumberBad = std::numeric_limits<T>::quiet_NaN();
// Weight class to be templated on floating-point types. It stores a single
// value of type T and provides the I/O, hashing and accessor plumbing shared
// by the semiring classes that derive from it.
template <class T = float>
class FloatWeightTpl {
 public:
  // Note: the default constructor does not initialize the stored value.
  FloatWeightTpl() {}

  // Intentionally non-explicit so a raw T converts to a weight.
  FloatWeightTpl(T f) : value_(f) {}

  FloatWeightTpl(const FloatWeightTpl<T> &w) : value_(w.value_) {}

  FloatWeightTpl<T> &operator=(const FloatWeightTpl<T> &w) {
    value_ = w.value_;
    return *this;
  }

  // Reads the stored value from strm via ReadType; returns strm.
  istream &Read(istream &strm) {
    return ReadType(strm, &value_);
  }

  // Writes the stored value to strm via WriteType; returns strm.
  ostream &Write(ostream &strm) const {
    return WriteType(strm, value_);
  }

  // Hashes the object representation of the value: the leading
  // min(sizeof(T), sizeof(size_t)) bytes are copied into a zeroed size_t.
  // memcpy is used instead of writing one union member and reading another,
  // which is undefined behavior in C++; the resulting bytes are the same.
  size_t Hash() const {
    size_t hash = 0;
    const size_t nbytes =
        sizeof(value_) < sizeof(hash) ? sizeof(value_) : sizeof(hash);
    std::memcpy(&hash, &value_, nbytes);
    return hash;
  }

  const T &Value() const { return value_; }

 protected:
  void SetValue(const T &f) { value_ = f; }

  // Suffix appended to semiring type names: empty when T has the size of
  // float, otherwise the size of T in bits as rendered by Int64ToStr.
  inline static string GetPrecisionString() {
    int64 size = sizeof(T);
    if (size == sizeof(float)) return "";
    size *= CHAR_BIT;
    string result;
    Int64ToStr(size, &result);
    return result;
  }

 private:
  T value_;
};

// Single-precision float weight
typedef FloatWeightTpl<float> FloatWeight;
// Equality of two float weights: true iff the stored values compare equal
// under the built-in ==. As a consequence a weight holding an IEEE NaN does
// not compare equal to anything, including itself.
template <class T>
inline bool operator==(const FloatWeightTpl<T> &w1,
const FloatWeightTpl<T> &w2) {
// Volatile qualifier thwarts over-aggressive compiler optimizations
// that lead to problems esp. with NaturalLess().
volatile T v1 = w1.Value();
volatile T v2 = w2.Value();
return v1 == v2;
}
// Non-template overload for double-precision weights; forwards to the
// template above.
// NOTE(review): presumably provided so that operands needing an implicit
// conversion to the weight type (which template deduction would reject)
// can still be compared -- confirm against callers.
inline bool operator==(const FloatWeightTpl<double> &w1,
const FloatWeightTpl<double> &w2) {
return operator==<double>(w1, w2);
}
// Non-template overload for single-precision weights; forwards to the
// template above.
inline bool operator==(const FloatWeightTpl<float> &w1,
const FloatWeightTpl<float> &w2) {
return operator==<float>(w1, w2);
}
// Inequality of two float weights, defined as the negation of operator==.
template <class T>
inline bool operator!=(const FloatWeightTpl<T> &w1,
                       const FloatWeightTpl<T> &w2) {
  const bool same = (w1 == w2);
  return !same;
}

// Double-precision overload; delegates to the template version.
inline bool operator!=(const FloatWeightTpl<double> &w1,
                       const FloatWeightTpl<double> &w2) {
  const bool differ = operator!=<double>(w1, w2);
  return differ;
}

// Single-precision overload; delegates to the template version.
inline bool operator!=(const FloatWeightTpl<float> &w1,
                       const FloatWeightTpl<float> &w2) {
  const bool differ = operator!=<float>(w1, w2);
  return differ;
}
// Returns true iff each value is no more than delta above the other.
// Any comparison involving NaN fails, so a NaN weight is never approximately
// equal to anything.
template <class T>
inline bool ApproxEqual(const FloatWeightTpl<T> &w1,
                        const FloatWeightTpl<T> &w2,
                        float delta = kDelta) {
  const T lhs = w1.Value();
  const T rhs = w2.Value();
  // Written as a negated <= (not as >) so NaN operands still yield false.
  if (!(lhs <= rhs + delta)) return false;
  return rhs <= lhs + delta;
}
// Prints a float weight. The two infinities and NaN are written as the
// tokens "Infinity", "-Infinity" and "BadNumber"; any other value is
// written with the stream's normal formatting for T.
template <class T>
inline ostream &operator<<(ostream &strm, const FloatWeightTpl<T> &w) {
  const T v = w.Value();
  if (v == FloatLimits<T>::kPosInfinity) return strm << "Infinity";
  if (v == FloatLimits<T>::kNegInfinity) return strm << "-Infinity";
  if (v != v) return strm << "BadNumber";  // Only NaN differs from itself.
  return strm << v;
}
// Reads one whitespace-delimited token from strm and parses it as a weight.
// "Infinity" and "-Infinity" map to the corresponding infinities; any other
// token is parsed with strtod and must be consumed in full.
//
// On failure w is left unchanged:
//  * if no token could be extracted, the stream is already in a failed state
//    and is returned as is (an empty token is no longer parsed as 0);
//  * if the token has trailing characters strtod did not consume, the stream
//    state is set to badbit.
template <class T>
inline istream &operator>>(istream &strm, FloatWeightTpl<T> &w) {
  string s;
  strm >> s;
  if (strm.fail()) return strm;  // Nothing was read; do not touch w.
  if (s == "Infinity") {
    w = FloatWeightTpl<T>(FloatLimits<T>::kPosInfinity);
  } else if (s == "-Infinity") {
    w = FloatWeightTpl<T>(FloatLimits<T>::kNegInfinity);
  } else {
    char *p = 0;
    T f = std::strtod(s.c_str(), &p);
    if (p < s.c_str() + s.size())
      strm.clear(std::ios::badbit);  // Token was not entirely numeric.
    else
      w = FloatWeightTpl<T>(f);
  }
  return strm;
}
// Tropical semiring: (min, +, inf, 0).
// Plus is min, Times is +, Zero is +infinity and One is 0.
template <class T>
class TropicalWeightTpl : public FloatWeightTpl<T> {
 public:
  using FloatWeightTpl<T>::Value;

  typedef TropicalWeightTpl<T> ReverseWeight;

  TropicalWeightTpl() : FloatWeightTpl<T>() {}

  TropicalWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  TropicalWeightTpl(const TropicalWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}

  // Additive identity: +infinity.
  static const TropicalWeightTpl<T> Zero() {
    const TropicalWeightTpl<T> zero(FloatLimits<T>::kPosInfinity);
    return zero;
  }

  // Multiplicative identity: 0.
  static const TropicalWeightTpl<T> One() {
    const TropicalWeightTpl<T> one(0.0F);
    return one;
  }

  // Marker for an invalid weight: NaN.
  static const TropicalWeightTpl<T> NoWeight() {
    const TropicalWeightTpl<T> bad(FloatLimits<T>::kNumberBad);
    return bad;
  }

  // Name of this weight type: "tropical" followed by the precision suffix.
  static const string &Type() {
    static const string type =
        string("tropical") + FloatWeightTpl<T>::GetPrecisionString();
    return type;
  }

  // A value is a member unless it is NaN or -infinity.
  bool Member() const {
    const T v = Value();
    if (v != v) return false;  // IEEE NaN is the only value unequal to itself.
    return v != FloatLimits<T>::kNegInfinity;
  }

  // Rounds the value to the nearest multiple of delta; infinities and NaN
  // are returned unchanged.
  TropicalWeightTpl<T> Quantize(float delta = kDelta) const {
    const T v = Value();
    const bool finite = v == v &&
                        v != FloatLimits<T>::kNegInfinity &&
                        v != FloatLimits<T>::kPosInfinity;
    if (!finite) return *this;
    return TropicalWeightTpl<T>(floor(v / delta + 0.5F) * delta);
  }

  TropicalWeightTpl<T> Reverse() const { return *this; }

  static uint64 Properties() {
    return kPath | kIdempotent | kCommutative |
           kLeftSemiring | kRightSemiring;
  }
};

// Single precision tropical weight
typedef TropicalWeightTpl<float> TropicalWeight;
// Tropical Plus: the smaller of the two weights (w2 on ties). Yields
// NoWeight() if either operand is not a member.
template <class T>
inline TropicalWeightTpl<T> Plus(const TropicalWeightTpl<T> &w1,
                                 const TropicalWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return TropicalWeightTpl<T>::NoWeight();
  if (w1.Value() < w2.Value()) return w1;
  return w2;
}

// Single-precision overload; delegates to the template version.
inline TropicalWeightTpl<float> Plus(const TropicalWeightTpl<float> &w1,
                                     const TropicalWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

// Double-precision overload; delegates to the template version.
inline TropicalWeightTpl<double> Plus(const TropicalWeightTpl<double> &w1,
                                      const TropicalWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
// Tropical Times: the sum of the two values. +infinity (Zero) absorbs, and
// NoWeight() is returned if either operand is not a member.
template <class T>
inline TropicalWeightTpl<T> Times(const TropicalWeightTpl<T> &w1,
                                  const TropicalWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return TropicalWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::kPosInfinity;
  if (w1.Value() == inf) return w1;
  if (w2.Value() == inf) return w2;
  return TropicalWeightTpl<T>(w1.Value() + w2.Value());
}

// Single-precision overload; delegates to the template version.
inline TropicalWeightTpl<float> Times(const TropicalWeightTpl<float> &w1,
                                      const TropicalWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

// Double-precision overload; delegates to the template version.
inline TropicalWeightTpl<double> Times(const TropicalWeightTpl<double> &w1,
                                       const TropicalWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
// Tropical Divide: the difference w1 - w2. Dividing by Zero (+infinity)
// yields NaN, and Zero divided by anything else stays Zero. NoWeight() is
// returned if either operand is not a member. The typ argument is not used.
template <class T>
inline TropicalWeightTpl<T> Divide(const TropicalWeightTpl<T> &w1,
                                   const TropicalWeightTpl<T> &w2,
                                   DivideType typ = DIVIDE_ANY) {
  if (!(w1.Member() && w2.Member()))
    return TropicalWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::kPosInfinity;
  if (w2.Value() == inf)
    return TropicalWeightTpl<T>(FloatLimits<T>::kNumberBad);
  if (w1.Value() == inf)
    return TropicalWeightTpl<T>(inf);
  return TropicalWeightTpl<T>(w1.Value() - w2.Value());
}

// Single-precision overload; delegates to the template version.
inline TropicalWeightTpl<float> Divide(const TropicalWeightTpl<float> &w1,
                                       const TropicalWeightTpl<float> &w2,
                                       DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

// Double-precision overload; delegates to the template version.
inline TropicalWeightTpl<double> Divide(const TropicalWeightTpl<double> &w1,
                                        const TropicalWeightTpl<double> &w2,
                                        DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
// Log semiring: (-log(e^-x + e^-y), +, inf, 0).
// Zero is +infinity and One is 0.
template <class T>
class LogWeightTpl : public FloatWeightTpl<T> {
 public:
  using FloatWeightTpl<T>::Value;

  typedef LogWeightTpl ReverseWeight;

  LogWeightTpl() : FloatWeightTpl<T>() {}

  LogWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  LogWeightTpl(const LogWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}

  // Additive identity: +infinity.
  static const LogWeightTpl<T> Zero() {
    const LogWeightTpl<T> zero(FloatLimits<T>::kPosInfinity);
    return zero;
  }

  // Multiplicative identity: 0.
  static const LogWeightTpl<T> One() {
    const LogWeightTpl<T> one(0.0F);
    return one;
  }

  // Marker for an invalid weight: NaN.
  static const LogWeightTpl<T> NoWeight() {
    const LogWeightTpl<T> bad(FloatLimits<T>::kNumberBad);
    return bad;
  }

  // Name of this weight type: "log" followed by the precision suffix.
  static const string &Type() {
    static const string type =
        string("log") + FloatWeightTpl<T>::GetPrecisionString();
    return type;
  }

  // A value is a member unless it is NaN or -infinity.
  bool Member() const {
    const T v = Value();
    if (v != v) return false;  // IEEE NaN is the only value unequal to itself.
    return v != FloatLimits<T>::kNegInfinity;
  }

  // Rounds the value to the nearest multiple of delta; infinities and NaN
  // are returned unchanged.
  LogWeightTpl<T> Quantize(float delta = kDelta) const {
    const T v = Value();
    const bool finite = v == v &&
                        v != FloatLimits<T>::kNegInfinity &&
                        v != FloatLimits<T>::kPosInfinity;
    if (!finite) return *this;
    return LogWeightTpl<T>(floor(v / delta + 0.5F) * delta);
  }

  LogWeightTpl<T> Reverse() const { return *this; }

  static uint64 Properties() {
    return kCommutative | kLeftSemiring | kRightSemiring;
  }
};

// Single-precision log weight
typedef LogWeightTpl<float> LogWeight;

// Double-precision log weight
typedef LogWeightTpl<double> Log64Weight;
// Returns log(1 + e^-x); this is the correction term subtracted by the
// log-semiring Plus.
// NOTE(review): log and exp are called unqualified, so which overload is
// chosen (and hence the intermediate precision) depends on the declarations
// visible at the point of use -- keep this as a single expression.
template <class T>
inline T LogExp(T x) { return log(1.0F + exp(-x)); }
// Log Plus: -log(e^-w1 + e^-w2), computed as the smaller value minus
// LogExp of the (non-negative) gap between the two. If either operand is
// +infinity (Zero) the other operand is returned unchanged.
// Unlike Times and Divide, no Member() check is made on the operands.
template <class T>
inline LogWeightTpl<T> Plus(const LogWeightTpl<T> &w1,
                            const LogWeightTpl<T> &w2) {
  const T a = w1.Value();
  const T b = w2.Value();
  const T inf = FloatLimits<T>::kPosInfinity;
  if (a == inf) return w2;
  if (b == inf) return w1;
  if (a > b) return LogWeightTpl<T>(b - LogExp(a - b));
  return LogWeightTpl<T>(a - LogExp(b - a));
}

// Single-precision overload; delegates to the template version.
inline LogWeightTpl<float> Plus(const LogWeightTpl<float> &w1,
                                const LogWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

// Double-precision overload; delegates to the template version.
inline LogWeightTpl<double> Plus(const LogWeightTpl<double> &w1,
                                 const LogWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
// Log Times: the sum of the two values. +infinity (Zero) absorbs, and
// NoWeight() is returned if either operand is not a member.
template <class T>
inline LogWeightTpl<T> Times(const LogWeightTpl<T> &w1,
                             const LogWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return LogWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::kPosInfinity;
  if (w1.Value() == inf) return w1;
  if (w2.Value() == inf) return w2;
  return LogWeightTpl<T>(w1.Value() + w2.Value());
}

// Single-precision overload; delegates to the template version.
inline LogWeightTpl<float> Times(const LogWeightTpl<float> &w1,
                                 const LogWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

// Double-precision overload; delegates to the template version.
inline LogWeightTpl<double> Times(const LogWeightTpl<double> &w1,
                                  const LogWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
// Log Divide: the difference w1 - w2. Dividing by Zero (+infinity) yields
// NaN, and Zero divided by anything else stays Zero. NoWeight() is returned
// if either operand is not a member. The typ argument is not used.
template <class T>
inline LogWeightTpl<T> Divide(const LogWeightTpl<T> &w1,
                              const LogWeightTpl<T> &w2,
                              DivideType typ = DIVIDE_ANY) {
  if (!(w1.Member() && w2.Member()))
    return LogWeightTpl<T>::NoWeight();
  const T inf = FloatLimits<T>::kPosInfinity;
  if (w2.Value() == inf)
    return LogWeightTpl<T>(FloatLimits<T>::kNumberBad);
  if (w1.Value() == inf)
    return LogWeightTpl<T>(inf);
  return LogWeightTpl<T>(w1.Value() - w2.Value());
}

// Single-precision overload; delegates to the template version.
inline LogWeightTpl<float> Divide(const LogWeightTpl<float> &w1,
                                  const LogWeightTpl<float> &w2,
                                  DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

// Double-precision overload; delegates to the template version.
inline LogWeightTpl<double> Divide(const LogWeightTpl<double> &w1,
                                   const LogWeightTpl<double> &w2,
                                   DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
// MinMax semiring: (min, max, inf, -inf).
// Plus is min, Times is max, Zero is +infinity and One is -infinity.
template <class T>
class MinMaxWeightTpl : public FloatWeightTpl<T> {
 public:
  using FloatWeightTpl<T>::Value;

  typedef MinMaxWeightTpl<T> ReverseWeight;

  MinMaxWeightTpl() : FloatWeightTpl<T>() {}

  MinMaxWeightTpl(T f) : FloatWeightTpl<T>(f) {}

  MinMaxWeightTpl(const MinMaxWeightTpl<T> &w) : FloatWeightTpl<T>(w) {}

  // Additive identity: +infinity.
  static const MinMaxWeightTpl<T> Zero() {
    const MinMaxWeightTpl<T> zero(FloatLimits<T>::kPosInfinity);
    return zero;
  }

  // Multiplicative identity: -infinity.
  static const MinMaxWeightTpl<T> One() {
    const MinMaxWeightTpl<T> one(FloatLimits<T>::kNegInfinity);
    return one;
  }

  // Marker for an invalid weight: NaN.
  static const MinMaxWeightTpl<T> NoWeight() {
    const MinMaxWeightTpl<T> bad(FloatLimits<T>::kNumberBad);
    return bad;
  }

  // Name of this weight type: "minmax" followed by the precision suffix.
  static const string &Type() {
    static const string type =
        string("minmax") + FloatWeightTpl<T>::GetPrecisionString();
    return type;
  }

  // Every value except IEEE NaN is a member (both infinities included).
  bool Member() const {
    const T v = Value();
    return !(v != v);
  }

  // Rounds the value to the nearest multiple of delta; infinities and NaN
  // are returned unchanged.
  MinMaxWeightTpl<T> Quantize(float delta = kDelta) const {
    const T v = Value();
    const bool finite = v == v &&
                        v != FloatLimits<T>::kNegInfinity &&
                        v != FloatLimits<T>::kPosInfinity;
    if (!finite) return *this;
    return MinMaxWeightTpl<T>(floor(v / delta + 0.5F) * delta);
  }

  MinMaxWeightTpl<T> Reverse() const { return *this; }

  static uint64 Properties() {
    return kPath | kIdempotent | kCommutative |
           kLeftSemiring | kRightSemiring;
  }
};

// Single-precision min-max weight
typedef MinMaxWeightTpl<float> MinMaxWeight;
// MinMax Plus is min: returns the smaller weight (w2 on ties), or NoWeight()
// if either operand is not a member.
template <class T>
inline MinMaxWeightTpl<T> Plus(
    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return MinMaxWeightTpl<T>::NoWeight();
  if (w1.Value() < w2.Value()) return w1;
  return w2;
}

// Single-precision overload; delegates to the template version.
inline MinMaxWeightTpl<float> Plus(
    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
  return Plus<float>(w1, w2);
}

// Double-precision overload; delegates to the template version.
inline MinMaxWeightTpl<double> Plus(
    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
  return Plus<double>(w1, w2);
}
// MinMax Times is max: returns the larger weight (w1 on ties), or NoWeight()
// if either operand is not a member.
template <class T>
inline MinMaxWeightTpl<T> Times(
    const MinMaxWeightTpl<T> &w1, const MinMaxWeightTpl<T> &w2) {
  if (!(w1.Member() && w2.Member()))
    return MinMaxWeightTpl<T>::NoWeight();
  if (w1.Value() >= w2.Value()) return w1;
  return w2;
}

// Single-precision overload; delegates to the template version.
inline MinMaxWeightTpl<float> Times(
    const MinMaxWeightTpl<float> &w1, const MinMaxWeightTpl<float> &w2) {
  return Times<float>(w1, w2);
}

// Double-precision overload; delegates to the template version.
inline MinMaxWeightTpl<double> Times(
    const MinMaxWeightTpl<double> &w1, const MinMaxWeightTpl<double> &w2) {
  return Times<double>(w1, w2);
}
// MinMax Divide is defined only for special cases. Since Times is max, the
// quotient x must satisfy max(w2, x) = w1: when w1 >= w2, x = w1 works;
// otherwise no such x exists and NaN is returned. NoWeight() is returned if
// either operand is not a member. The typ argument is not used.
template <class T>
inline MinMaxWeightTpl<T> Divide(const MinMaxWeightTpl<T> &w1,
                                 const MinMaxWeightTpl<T> &w2,
                                 DivideType typ = DIVIDE_ANY) {
  if (!(w1.Member() && w2.Member()))
    return MinMaxWeightTpl<T>::NoWeight();
  if (w1.Value() >= w2.Value()) return w1;
  return MinMaxWeightTpl<T>(FloatLimits<T>::kNumberBad);
}

// Single-precision overload; delegates to the template version.
inline MinMaxWeightTpl<float> Divide(const MinMaxWeightTpl<float> &w1,
                                     const MinMaxWeightTpl<float> &w2,
                                     DivideType typ = DIVIDE_ANY) {
  return Divide<float>(w1, w2, typ);
}

// Double-precision overload; delegates to the template version.
inline MinMaxWeightTpl<double> Divide(const MinMaxWeightTpl<double> &w1,
                                      const MinMaxWeightTpl<double> &w2,
                                      DivideType typ = DIVIDE_ANY) {
  return Divide<double>(w1, w2, typ);
}
//
// WEIGHT CONVERTER SPECIALIZATIONS.
//
// Each converter carries the raw numeric value over unchanged and re-wraps
// it in the target weight type; conversions from Log64Weight to a
// single-precision weight narrow the value from double to float.

// Convert to tropical
template <>
struct WeightConvert<LogWeight, TropicalWeight> {
  TropicalWeight operator()(LogWeight w) const {
    return TropicalWeight(w.Value());
  }
};

template <>
struct WeightConvert<Log64Weight, TropicalWeight> {
  TropicalWeight operator()(Log64Weight w) const {
    return TropicalWeight(w.Value());
  }
};

// Convert to log
template <>
struct WeightConvert<TropicalWeight, LogWeight> {
  LogWeight operator()(TropicalWeight w) const {
    return LogWeight(w.Value());
  }
};

template <>
struct WeightConvert<Log64Weight, LogWeight> {
  LogWeight operator()(Log64Weight w) const {
    return LogWeight(w.Value());
  }
};

// Convert to log64
template <>
struct WeightConvert<TropicalWeight, Log64Weight> {
  Log64Weight operator()(TropicalWeight w) const {
    return Log64Weight(w.Value());
  }
};

template <>
struct WeightConvert<LogWeight, Log64Weight> {
  Log64Weight operator()(LogWeight w) const {
    return Log64Weight(w.Value());
  }
};
} // namespace fst
#endif // FST_LIB_FLOAT_WEIGHT_H__