blob: 2189121a225c61cc9547b53be0d149a0a3c2de48 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_
#include <cstdlib>
#include <functional>
#include <string>
#include <utility>

#include "absl/base/attributes.h"
#include "absl/base/macros.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// A utility class for primitive comparisons. A comparison includes three
// components: the type of the elements being compared (F32, S16, etc), whether
// it is a partial or total order comparison, and the actual comparison operator
// (==, <=, >, etc).
//
// Note that integer comparisons are always total order. Float comparisons can
// be either total or partial order.
//
// Some examples:
//
//   Comparison a(
//     Comparison::Direction::kLt,
//     xla::PrimitiveType::BF16,
//     Comparison::Order::kTotal
//   );
//   a.ToString(); /* ".LT.BF16.TOTALORDER" */
//
//   Comparison b(Comparison::Direction::kEq, xla::PrimitiveType::U32);
//   b.IsTotalOrder(); /* true */
class Comparison {
 public:
  // Represents the ordering of the comparison.
  enum class Order : uint8_t {
    // https://en.wikipedia.org/wiki/Total_order
    kTotal,
    // https://en.wikipedia.org/wiki/Partially_ordered_set
    kPartial,
  };

  // Represents different comparison operations.
  enum class Direction : uint8_t {
    kEq,
    kNe,
    kGe,
    kGt,
    kLe,
    kLt,
  };

  // (DEPRECATED) Represents the type of comparison. Prefer xla::PrimitiveType
  // and Comparison::Order, since there are multiple floating point
  // representations that support total ordering.
  enum class [[deprecated("Use PrimitiveType and Order")]] Type : uint8_t {
    kFloat,
    kFloatTotalOrder,
    kSigned,
    kUnsigned,
  };

  Comparison() = delete;

  // This will default to the expected behavior for Comparison::Order: integers
  // will use total ordering, and floats will use partial ordering.
  explicit Comparison(Direction dir, PrimitiveType type);

  // Pass in a Comparison::Order to specify a non-default ordering, e.g., some
  // targets may support total order floating point type comparisons.
  explicit Comparison(Direction dir, PrimitiveType type, Order order);

  // Returns a comparison with a primitive type matching the Comparison::Type
  // and using a default bit width of 32. For example,
  //   Comparison(Direction::kLt, Type::kFloat).GetPrimitiveType() /* F32 */
  [[deprecated(
      "Use Comparison(Comparison::Direction, "
      "PrimitiveType)")]] explicit Comparison(Direction dir, Type type);

  // Accessors for the three components of the comparison.
  inline Direction GetDirection() const { return dir_; }
  inline PrimitiveType GetPrimitiveType() const { return primitive_type_; }
  inline Order GetOrder() const { return order_; }

  [[deprecated("Use GetPrimitiveType() and GetOrder()")]] inline Type GetType()
      const {
    return type_;
  }

  // Direction predicates: one per Direction enumerator.
  inline bool IsEq() const { return dir_ == Direction::kEq; }
  inline bool IsNe() const { return dir_ == Direction::kNe; }
  inline bool IsGe() const { return dir_ == Direction::kGe; }
  inline bool IsGt() const { return dir_ == Direction::kGt; }
  inline bool IsLe() const { return dir_ == Direction::kLe; }
  inline bool IsLt() const { return dir_ == Direction::kLt; }

  // Order predicates.
  inline bool IsTotalOrder() const { return order_ == Order::kTotal; }
  inline bool IsPartialOrder() const { return order_ == Order::kPartial; }

  // Returns whether this is an F32 (respectively BF16) total order comparison.
  inline bool IsF32TotalOrder() const {
    return primitive_type_ == PrimitiveType::F32 && IsTotalOrder();
  }
  inline bool IsBf16TotalOrder() const {
    return primitive_type_ == PrimitiveType::BF16 && IsTotalOrder();
  }

  // Returns whether this is a standard comparison, i.e., what you would expect
  // as the industry standard on most architectures: partial order for floating
  // point types, total order for integer types.
  inline bool IsStandardF32() const {
    return primitive_type_ == PrimitiveType::F32 && IsPartialOrder();
  }
  inline bool IsStandardBf16() const {
    return primitive_type_ == PrimitiveType::BF16 && IsPartialOrder();
  }
  inline bool IsStandardS32() const {
    return primitive_type_ == PrimitiveType::S32 && IsTotalOrder();
  }
  inline bool IsStandardU32() const {
    return primitive_type_ == PrimitiveType::U32 && IsTotalOrder();
  }

  // Classification of the operand primitive type.
  inline bool IsIntegralPrimitiveType() const {
    return primitive_util::IsIntegralType(primitive_type_);
  }
  inline bool IsFloatingPointPrimitiveType() const {
    return primitive_util::IsFloatingPointType(primitive_type_);
  }

  // Returns whether (a dir a) is always true for this comparison.
  bool IsReflexive() const;
  // Returns whether (a dir a) is always false for this comparison.
  bool IsAntireflexive() const;

  // Gets the converse of the given comparison direction (e.g. >= turns to <=).
  // Useful when commuting operands to get constants into immediate-accepting
  // positions in the ISA.
  Comparison Converse() const;

  // Gets the inverse of the given comparison if it exists (e.g. >= turns to <).
  // Returns optional value because not all inversions may be supported.
  absl::optional<Comparison> Inverse() const;

  // Returns a string version of this comparison, e.g., ".GT.F32.TOTALORDER"
  std::string ToString(std::string prefix1 = ".", std::string prefix2 = ".",
                       std::string prefix3 = ".") const;

  // Returns a comparison operator: (T, T) -> bool for this Comparison's
  // Direction.
  template <typename T>
  std::function<bool(T, T)> GetComparator() const {
    // Deliberately no `default:` label so that -Wswitch flags any Direction
    // enumerator added in the future without a matching case here.
    switch (GetDirection()) {
      case Direction::kEq:
        return std::equal_to<T>();
      case Direction::kNe:
        return std::not_equal_to<T>();
      case Direction::kGe:
        return std::greater_equal<T>();
      case Direction::kGt:
        return std::greater<T>();
      case Direction::kLe:
        return std::less_equal<T>();
      case Direction::kLt:
        return std::less<T>();
    }
    // Only reachable if dir_ holds a value outside the Direction enumerators
    // (e.g. produced by a bad static_cast). Fail loudly rather than falling
    // off the end of a non-void function, which is undefined behavior.
    std::abort();
  }

  // Applies the comparison from this Comparison's direction. Note that this
  // does not account for the PrimitiveType and Order associated with this
  // comparison, and instead uses the type T. For example,
  //
  //   float operand = absl::bit_cast<float>(0x7fc00000); /* NaN */
  //   Comparison(Direction::kEq, xla::F32, Order::kTotal)
  //       .Compare<float>(operand, operand)  // false, since it's using IEEE-754.
  template <typename T>
  bool Compare(const T a, const T b) const {
    return GetComparator<T>()(a, b);
  }

  // Returns the Comparison::Type for the given primitive type. This assumes
  // that each numerical representation follows the standard behavior, e.g.,
  // integers are total order and floats are partial order.
  [[deprecated("Use PrimitiveType and Order")]] static Comparison::Type
  DefaultComparisonType(PrimitiveType type);

 private:
  // The direction of the Comparison, e.g., GT.
  const Direction dir_;
  // The primitive type of the Comparison operands, e.g., F32.
  const PrimitiveType primitive_type_;
  // The ordering of the Comparison, e.g., kPartial.
  const Order order_;
  // The Type of the Comparison. This tries to mesh together the ordering and
  // the numerical data classification.
  [[deprecated]] const Type type_;
};
// Namespace-level shorthands for the enums nested in Comparison, so callers
// can write ComparisonOrder::kTotal instead of Comparison::Order::kTotal.
using ComparisonOrder = Comparison::Order;
using ComparisonDirection = Comparison::Direction;
// Writes cmp.ToString() (i.e. rendered with the default "." prefixes) to `os`
// and returns the same stream so insertions can be chained.
inline std::ostream& operator<<(std::ostream& os, const Comparison& cmp) {
  const std::string rendered = cmp.ToString();
  os << rendered;
  return os;
}
// String conversions for the pieces of a Comparison. Each StringTo* parser is
// declared next to the *ToString printer it pairs with. The parsers return
// StatusOr, so a parse can fail; presumably this happens for unrecognized
// input (see the definitions to confirm).

// Comparison::Direction <-> string.
std::string ComparisonDirectionToString(Comparison::Direction direction);
StatusOr<Comparison::Direction> StringToComparisonDirection(
    absl::string_view direction);

// (Deprecated) Comparison::Type <-> string.
std::string ComparisonTypeToString(Comparison::Type type);
StatusOr<Comparison::Type> StringToComparisonType(absl::string_view comparison);

// Comparison::Order <-> string.
std::string ComparisonOrderToString(Comparison::Order order);
StatusOr<Comparison::Order> StringToComparisonOrder(absl::string_view order);

// PrimitiveType -> string. This is presumably the spelling used by
// Comparison::ToString(); no matching parser is declared here.
std::string ComparisonPrimitiveTypeToString(PrimitiveType type);
// Returns a comparison function using the provided key function on each value,
// i.e. `key_fn(a) < key_fn(b)`.
//
// The key function is stored inside the returned closure. An lvalue argument
// is copied and an rvalue argument is moved, so move-only key functions (e.g.
// lambdas owning a std::unique_ptr) are accepted and expensive temporaries are
// not copied. The closure's call operator is const, so key_fn must be
// invocable as const.
template <typename KeyFn>
auto LessThanByKey(KeyFn&& key_fn) {
  return [key_fn = std::forward<KeyFn>(key_fn)](const auto& a, const auto& b) {
    return key_fn(a) < key_fn(b);
  };
}
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_COMPARISON_UTIL_H_