| |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| // Copyright 2005-2010 Google, Inc. |
| // Author: krr@google.com (Kasturi Rangan Raghavan) |
| // \file |
| // LogWeight along with sign information that represents the value X in the |
| // linear domain as <sign(X), -ln(|X|)> |
| // The sign is a TropicalWeight: |
| // positive, TropicalWeight.Value() > 0.0, recommended value 1.0 |
| // negative, TropicalWeight.Value() <= 0.0, recommended value -1.0 |
| |
| #ifndef FST_LIB_SIGNED_LOG_WEIGHT_H_ |
| #define FST_LIB_SIGNED_LOG_WEIGHT_H_ |
| |
| #include <fst/float-weight.h> |
| #include <fst/pair-weight.h> |
| |
| |
| namespace fst { |
| template <class T> |
| class SignedLogWeightTpl |
| : public PairWeight<TropicalWeight, LogWeightTpl<T> > { |
| public: |
| typedef TropicalWeight X1; |
| typedef LogWeightTpl<T> X2; |
| using PairWeight<X1, X2>::Value1; |
| using PairWeight<X1, X2>::Value2; |
| |
| using PairWeight<X1, X2>::Reverse; |
| using PairWeight<X1, X2>::Quantize; |
| using PairWeight<X1, X2>::Member; |
| |
| typedef SignedLogWeightTpl<T> ReverseWeight; |
| |
| SignedLogWeightTpl() : PairWeight<X1, X2>() {} |
| |
| SignedLogWeightTpl(const SignedLogWeightTpl<T>& w) |
| : PairWeight<X1, X2> (w) { } |
| |
| SignedLogWeightTpl(const PairWeight<X1, X2>& w) |
| : PairWeight<X1, X2> (w) { } |
| |
| SignedLogWeightTpl(const X1& x1, const X2& x2) |
| : PairWeight<X1, X2>(x1, x2) { } |
| |
| static const SignedLogWeightTpl<T> &Zero() { |
| static const SignedLogWeightTpl<T> zero(X1(1.0), X2::Zero()); |
| return zero; |
| } |
| |
| static const SignedLogWeightTpl<T> &One() { |
| static const SignedLogWeightTpl<T> one(X1(1.0), X2::One()); |
| return one; |
| } |
| |
| static const SignedLogWeightTpl<T> &NoWeight() { |
| static const SignedLogWeightTpl<T> no_weight(X1(1.0), X2::NoWeight()); |
| return no_weight; |
| } |
| |
| static const string &Type() { |
| static const string type = "signed_log_" + X1::Type() + "_" + X2::Type(); |
| return type; |
| } |
| |
| ProductWeight<X1, X2> Quantize(float delta = kDelta) const { |
| return PairWeight<X1, X2>::Quantize(); |
| } |
| |
| ReverseWeight Reverse() const { |
| return PairWeight<X1, X2>::Reverse(); |
| } |
| |
| bool Member() const { |
| return PairWeight<X1, X2>::Member(); |
| } |
| |
| static uint64 Properties() { |
| // not idempotent nor path |
| return kLeftSemiring | kRightSemiring | kCommutative; |
| } |
| |
| size_t Hash() const { |
| size_t h1; |
| if (Value2() == X2::Zero() || Value1().Value() > 0.0) |
| h1 = TropicalWeight(1.0).Hash(); |
| else |
| h1 = TropicalWeight(-1.0).Hash(); |
| size_t h2 = Value2().Hash(); |
| const int lshift = 5; |
| const int rshift = CHAR_BIT * sizeof(size_t) - 5; |
| return h1 << lshift ^ h1 >> rshift ^ h2; |
| } |
| }; |
| |
| template <class T> |
| inline SignedLogWeightTpl<T> Plus(const SignedLogWeightTpl<T> &w1, |
| const SignedLogWeightTpl<T> &w2) { |
| if (!w1.Member() || !w2.Member()) |
| return SignedLogWeightTpl<T>::NoWeight(); |
| bool s1 = w1.Value1().Value() > 0.0; |
| bool s2 = w2.Value1().Value() > 0.0; |
| T f1 = w1.Value2().Value(); |
| T f2 = w2.Value2().Value(); |
| if (f1 == FloatLimits<T>::kPosInfinity) |
| return w2; |
| else if (f2 == FloatLimits<T>::kPosInfinity) |
| return w1; |
| else if (f1 == f2) { |
| if (s1 == s2) |
| return SignedLogWeightTpl<T>(w1.Value1(), (f2 - log(2.0F))); |
| else |
| return SignedLogWeightTpl<T>::Zero(); |
| } else if (f1 > f2) { |
| if (s1 == s2) { |
| return SignedLogWeightTpl<T>( |
| w1.Value1(), (f2 - log(1.0F + exp(f2 - f1)))); |
| } else { |
| return SignedLogWeightTpl<T>( |
| w2.Value1(), (f2 - log(1.0F - exp(f2 - f1)))); |
| } |
| } else { |
| if (s2 == s1) { |
| return SignedLogWeightTpl<T>( |
| w2.Value1(), (f1 - log(1.0F + exp(f1 - f2)))); |
| } else { |
| return SignedLogWeightTpl<T>( |
| w1.Value1(), (f1 - log(1.0F - exp(f1 - f2)))); |
| } |
| } |
| } |
| |
| template <class T> |
| inline SignedLogWeightTpl<T> Minus(const SignedLogWeightTpl<T> &w1, |
| const SignedLogWeightTpl<T> &w2) { |
| SignedLogWeightTpl<T> minus_w2(-w2.Value1().Value(), w2.Value2()); |
| return Plus(w1, minus_w2); |
| } |
| |
| template <class T> |
| inline SignedLogWeightTpl<T> Times(const SignedLogWeightTpl<T> &w1, |
| const SignedLogWeightTpl<T> &w2) { |
| if (!w1.Member() || !w2.Member()) |
| return SignedLogWeightTpl<T>::NoWeight(); |
| bool s1 = w1.Value1().Value() > 0.0; |
| bool s2 = w2.Value1().Value() > 0.0; |
| T f1 = w1.Value2().Value(); |
| T f2 = w2.Value2().Value(); |
| if (s1 == s2) |
| return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 + f2)); |
| else |
| return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 + f2)); |
| } |
| |
| template <class T> |
| inline SignedLogWeightTpl<T> Divide(const SignedLogWeightTpl<T> &w1, |
| const SignedLogWeightTpl<T> &w2, |
| DivideType typ = DIVIDE_ANY) { |
| if (!w1.Member() || !w2.Member()) |
| return SignedLogWeightTpl<T>::NoWeight(); |
| bool s1 = w1.Value1().Value() > 0.0; |
| bool s2 = w2.Value1().Value() > 0.0; |
| T f1 = w1.Value2().Value(); |
| T f2 = w2.Value2().Value(); |
| if (f2 == FloatLimits<T>::kPosInfinity) |
| return SignedLogWeightTpl<T>(TropicalWeight(1.0), |
| FloatLimits<T>::kNumberBad); |
| else if (f1 == FloatLimits<T>::kPosInfinity) |
| return SignedLogWeightTpl<T>(TropicalWeight(1.0), |
| FloatLimits<T>::kPosInfinity); |
| else if (s1 == s2) |
| return SignedLogWeightTpl<T>(TropicalWeight(1.0), (f1 - f2)); |
| else |
| return SignedLogWeightTpl<T>(TropicalWeight(-1.0), (f1 - f2)); |
| } |
| |
| template <class T> |
| inline bool ApproxEqual(const SignedLogWeightTpl<T> &w1, |
| const SignedLogWeightTpl<T> &w2, |
| float delta = kDelta) { |
| bool s1 = w1.Value1().Value() > 0.0; |
| bool s2 = w2.Value1().Value() > 0.0; |
| if (s1 == s2) { |
| return ApproxEqual(w1.Value2(), w2.Value2(), delta); |
| } else { |
| return w1.Value2() == LogWeightTpl<T>::Zero() |
| && w2.Value2() == LogWeightTpl<T>::Zero(); |
| } |
| } |
| |
| template <class T> |
| inline bool operator==(const SignedLogWeightTpl<T> &w1, |
| const SignedLogWeightTpl<T> &w2) { |
| bool s1 = w1.Value1().Value() > 0.0; |
| bool s2 = w2.Value1().Value() > 0.0; |
| if (s1 == s2) |
| return w1.Value2() == w2.Value2(); |
| else |
| return (w1.Value2() == LogWeightTpl<T>::Zero()) && |
| (w2.Value2() == LogWeightTpl<T>::Zero()); |
| } |
| |
| |
| // Single-precision signed-log weight (SignedLogWeightTpl with T = float). |
| typedef SignedLogWeightTpl<float> SignedLogWeight; |
| // Double-precision signed-log weight (SignedLogWeightTpl with T = double). |
| typedef SignedLogWeightTpl<double> SignedLog64Weight; |
| |
| // |
| // WEIGHT CONVERTER SPECIALIZATIONS. |
| // |
| |
| template <class W1, class W2> |
| bool SignedLogConvertCheck(W1 w) { |
| if (w.Value1().Value() < 0.0) { |
| FSTERROR() << "WeightConvert: can't convert weight from \"" |
| << W1::Type() << "\" to \"" << W2::Type(); |
| return false; |
| } |
| return true; |
| } |
| |
| // Convert to tropical |
| template <> |
| struct WeightConvert<SignedLogWeight, TropicalWeight> { |
| TropicalWeight operator()(SignedLogWeight w) const { |
| if (!SignedLogConvertCheck<SignedLogWeight, TropicalWeight>(w)) |
| return TropicalWeight::NoWeight(); |
| return w.Value2().Value(); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<SignedLog64Weight, TropicalWeight> { |
| TropicalWeight operator()(SignedLog64Weight w) const { |
| if (!SignedLogConvertCheck<SignedLog64Weight, TropicalWeight>(w)) |
| return TropicalWeight::NoWeight(); |
| return w.Value2().Value(); |
| } |
| }; |
| |
| // Convert to log |
| template <> |
| struct WeightConvert<SignedLogWeight, LogWeight> { |
| LogWeight operator()(SignedLogWeight w) const { |
| if (!SignedLogConvertCheck<SignedLogWeight, LogWeight>(w)) |
| return LogWeight::NoWeight(); |
| return w.Value2().Value(); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<SignedLog64Weight, LogWeight> { |
| LogWeight operator()(SignedLog64Weight w) const { |
| if (!SignedLogConvertCheck<SignedLog64Weight, LogWeight>(w)) |
| return LogWeight::NoWeight(); |
| return w.Value2().Value(); |
| } |
| }; |
| |
| // Convert to log64 |
| template <> |
| struct WeightConvert<SignedLogWeight, Log64Weight> { |
| Log64Weight operator()(SignedLogWeight w) const { |
| if (!SignedLogConvertCheck<SignedLogWeight, Log64Weight>(w)) |
| return Log64Weight::NoWeight(); |
| return w.Value2().Value(); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<SignedLog64Weight, Log64Weight> { |
| Log64Weight operator()(SignedLog64Weight w) const { |
| if (!SignedLogConvertCheck<SignedLog64Weight, Log64Weight>(w)) |
| return Log64Weight::NoWeight(); |
| return w.Value2().Value(); |
| } |
| }; |
| |
| // Convert to signed log |
| template <> |
| struct WeightConvert<TropicalWeight, SignedLogWeight> { |
| SignedLogWeight operator()(TropicalWeight w) const { |
| TropicalWeight x1 = 1.0; |
| LogWeight x2 = w.Value(); |
| return SignedLogWeight(x1, x2); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<LogWeight, SignedLogWeight> { |
| SignedLogWeight operator()(LogWeight w) const { |
| TropicalWeight x1 = 1.0; |
| LogWeight x2 = w.Value(); |
| return SignedLogWeight(x1, x2); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<Log64Weight, SignedLogWeight> { |
| SignedLogWeight operator()(Log64Weight w) const { |
| TropicalWeight x1 = 1.0; |
| LogWeight x2 = w.Value(); |
| return SignedLogWeight(x1, x2); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<SignedLog64Weight, SignedLogWeight> { |
| SignedLogWeight operator()(SignedLog64Weight w) const { |
| TropicalWeight x1 = w.Value1(); |
| LogWeight x2 = w.Value2().Value(); |
| return SignedLogWeight(x1, x2); |
| } |
| }; |
| |
| // Convert to signed log64 |
| template <> |
| struct WeightConvert<TropicalWeight, SignedLog64Weight> { |
| SignedLog64Weight operator()(TropicalWeight w) const { |
| TropicalWeight x1 = 1.0; |
| Log64Weight x2 = w.Value(); |
| return SignedLog64Weight(x1, x2); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<LogWeight, SignedLog64Weight> { |
| SignedLog64Weight operator()(LogWeight w) const { |
| TropicalWeight x1 = 1.0; |
| Log64Weight x2 = w.Value(); |
| return SignedLog64Weight(x1, x2); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<Log64Weight, SignedLog64Weight> { |
| SignedLog64Weight operator()(Log64Weight w) const { |
| TropicalWeight x1 = 1.0; |
| Log64Weight x2 = w.Value(); |
| return SignedLog64Weight(x1, x2); |
| } |
| }; |
| |
| template <> |
| struct WeightConvert<SignedLogWeight, SignedLog64Weight> { |
| SignedLog64Weight operator()(SignedLogWeight w) const { |
| TropicalWeight x1 = w.Value1(); |
| Log64Weight x2 = w.Value2().Value(); |
| return SignedLog64Weight(x1, x2); |
| } |
| }; |
| |
| } // namespace fst |
| |
| #endif // FST_LIB_SIGNED_LOG_WEIGHT_H_ |