| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ |
| #define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ |
| |
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <ostream>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
| |
| namespace xla { |
| |
| // A utility class for primitive comparisons. A comparison includes three |
| // components: the type of the elements being compared (F32, S16, etc), whether |
| // it is a partial or total order comparison, and the actual comparison operator |
| // (==, <=, >, etc). |
| // |
| // Note that integer comparisons are always total order. Float comparisons can |
| // be either total or partial order. |
| // |
| // Some examples: |
| // |
| // Comparison a( |
| // Comparison::Direction::kLt, |
| // xla::PrimitiveType::BF16, |
| // Comparison::Order::kTotal |
| // ); |
| // a.ToString(); /* ".LT.BF16.TOTALORDER" */ |
| // |
| // Comparison b(Comparison::Direction::kEq, xla::PrimitiveType::U32); |
| // b.IsTotalOrder(); /* true */ |
| class Comparison { |
| public: |
| // Represents the ordering of the comparison. |
| enum class Order : uint8_t { |
| // https://en.wikipedia.org/wiki/Total_order |
| kTotal, |
| // https://en.wikipedia.org/wiki/Partially_ordered_set |
| kPartial, |
| }; |
| |
| // Represents different comparison operations. |
| enum class Direction : uint8_t { |
| kEq, |
| kNe, |
| kGe, |
| kGt, |
| kLe, |
| kLt, |
| }; |
| |
| // (DEPRECATED) Represents the type of comparison. Prefer xla::PrimitiveType |
| // and Comparison::Order, since there are multiple floating point |
| // representations that support total ordering. |
| enum class [[deprecated("Use PrimitiveType and Order")]] Type : uint8_t{ |
| kFloat, |
| kFloatTotalOrder, |
| kSigned, |
| kUnsigned, |
| }; |
| |
| Comparison() = delete; |
| |
| // This will default to the expected behavior for Comparison::Order: integers |
| // will use total ordering, and floats will use partial ordering. |
| explicit Comparison(Direction dir, PrimitiveType type); |
| |
| // Pass in a Comparison::Order to specify a non-default ordering, e.g., some |
| // targets may support total order floating point type comparisons. |
| explicit Comparison(Direction dir, PrimitiveType type, Order order); |
| |
| // Returns a comparison with a primitive type matching the Comparison::Type |
| // and using a default bit width of 32. For example, |
| // Comparison(Direction::kLt, Type::kFloat).PrimitiveType() /* F32 */ |
| [[deprecated( |
| "Use Comparison(Comparison::Direction, " |
| "PrimitiveType)")]] explicit Comparison(Direction dir, Type type); |
| |
| inline Direction GetDirection() const { return dir_; } |
| inline PrimitiveType GetPrimitiveType() const { return primitive_type_; } |
| inline Order GetOrder() const { return order_; } |
| |
| [[deprecated("Use GetPrimitiveType() and GetOrder()")]] inline Type GetType() |
| const { |
| return type_; |
| } |
| |
| inline bool IsEq() const { return dir_ == Direction::kEq; } |
| inline bool IsNe() const { return dir_ == Direction::kNe; } |
| inline bool IsGe() const { return dir_ == Direction::kGe; } |
| inline bool IsGt() const { return dir_ == Direction::kGt; } |
| inline bool IsLt() const { return dir_ == Direction::kLt; } |
| inline bool IsTotalOrder() const { return order_ == Order::kTotal; } |
| inline bool IsPartialOrder() const { return order_ == Order::kPartial; } |
| |
| // Returns whether this is a floating point total order comparison. |
| inline bool IsF32TotalOrder() const { |
| return primitive_type_ == PrimitiveType::F32 && IsTotalOrder(); |
| } |
| inline bool IsBf16TotalOrder() const { |
| return primitive_type_ == PrimitiveType::BF16 && IsTotalOrder(); |
| } |
| |
| // Returns whether this is a standard comparison, i.e., what you would expect |
| // as the industry standard on most architectures. |
| inline bool IsStandardF32() const { |
| return primitive_type_ == PrimitiveType::F32 && IsPartialOrder(); |
| } |
| inline bool IsStandardBf16() const { |
| return primitive_type_ == PrimitiveType::BF16 && IsPartialOrder(); |
| } |
| inline bool IsStandardS32() const { |
| return primitive_type_ == PrimitiveType::S32 && IsTotalOrder(); |
| } |
| inline bool IsStandardU32() const { |
| return primitive_type_ == PrimitiveType::U32 && IsTotalOrder(); |
| } |
| |
| inline bool IsIntegralPrimitiveType() const { |
| return primitive_util::IsIntegralType(primitive_type_); |
| } |
| inline bool IsFloatingPointPrimitiveType() const { |
| return primitive_util::IsFloatingPointType(primitive_type_); |
| } |
| |
| // Returns whether (a dir a) is always true for this comparison. |
| bool IsReflexive() const; |
| |
| // Returns whether (a dir a) is always false for this comparison. |
| bool IsAntireflexive() const; |
| |
| // Gets the converse of the given comparison direction (e.g. >= turns to <=). |
| // Useful when commuting operands to get constants into immediate-accepting |
| // positions in the ISA. |
| Comparison Converse() const; |
| |
| // Gets the inverse of the given comparison if it exists (e.g. >= turns to <). |
| // Returns optional value because not all inversions may be supported. |
| absl::optional<Comparison> Inverse() const; |
| |
| // Returns a string version of this comparison, e.g., ".GT.F32.TOTALORDER" |
| std::string ToString(std::string prefix1 = ".", std::string prefix2 = ".", |
| std::string prefix3 = ".") const; |
| |
| // Returns a comparison operator: (T, T) -> bool for this Comparison's |
| // Direction. |
| template <typename T> |
| std::function<bool(T, T)> GetComparator() const { |
| switch (GetDirection()) { |
| case Direction::kEq: |
| return std::equal_to<T>(); |
| case Direction::kNe: |
| return std::not_equal_to<T>(); |
| case Direction::kGe: |
| return std::greater_equal<T>(); |
| case Direction::kGt: |
| return std::greater<T>(); |
| case Direction::kLe: |
| return std::less_equal<T>(); |
| case Direction::kLt: |
| return std::less<T>(); |
| } |
| } |
| |
| // Applies the comparison from this Comparison's direction. Note that this |
| // does not account for the PrimitiveType and Order associated with this |
| // comparison, and instead uses the type T. For example, |
| // |
| // float operand = absl::bit_cast<float>(0x7fc00000); /* NaN */ |
| // Comparison(Direction::kEq, xla::F32, Order::kTotal) |
| // .Compare<float>(operand, operand) // false, since it's using IEEE-754. |
| template <typename T> |
| bool Compare(const T a, const T b) const { |
| return GetComparator<T>()(a, b); |
| } |
| |
| // Returns the Comparison::Type for the given primitive type. This assumes |
| // that each numerical representation follows the standard behavior, e.g., |
| // integers are total order and floats are partial order. |
| [[deprecated("Use PrimitiveType and Order")]] static Comparison::Type |
| DefaultComparisonType(PrimitiveType type); |
| |
| private: |
| // The direction of the Comparison, e.g., GT. |
| const Direction dir_; |
| // The primitive type of the Comparison operands, e.g., F32. |
| const PrimitiveType primitive_type_; |
| // The ordering of the Comparison, e.g., kPartial. |
| const Order order_; |
| // The Type of the Comparison. This tries to mesh together the ordering and |
| // the numerical data classification. |
| [[deprecated]] const Type type_; |
| }; |
| |
| using ComparisonDirection = Comparison::Direction; |
| using ComparisonOrder = Comparison::Order; |
| |
| inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) { |
| return os << cmp.ToString(); |
| } |
| |
| std::string ComparisonDirectionToString(Comparison::Direction direction); |
| std::string ComparisonTypeToString(Comparison::Type type); |
| std::string ComparisonPrimitiveTypeToString(PrimitiveType type); |
| std::string ComparisonOrderToString(Comparison::Order order); |
| |
| StatusOr<Comparison::Direction> StringToComparisonDirection( |
| absl::string_view direction); |
| StatusOr<Comparison::Type> StringToComparisonType(absl::string_view comparison); |
| StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order); |
| |
// Returns a comparison function using the provided key function on each value,
// i.e. `key_fn(a) < key_fn(b)`.
//
// `key_fn` is stored inside the returned closure: an lvalue argument is copied
// and an rvalue argument is moved, so move-only key functions are accepted
// (the returned comparator is then move-only as well). The stored key function
// is invoked through a const closure, so it must be const-callable.
template <typename KeyFn>
auto LessThanByKey(KeyFn&& key_fn) {
  return [key_fn = std::forward<KeyFn>(key_fn)](const auto& a, const auto& b) {
    return key_fn(a) < key_fn(b);
  };
}
| |
| } // namespace xla |
| #endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_ |