support optional comparisons with different but comparable types (#62890)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/62565
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62890
Reviewed By: ejguan
Differential Revision: D30396008
Pulled By: dagitses
fbshipit-source-id: fca02207509f882973d54484f89c4d116505fc66
diff --git a/c10/test/util/optional_test.cpp b/c10/test/util/optional_test.cpp
index cac325f..ac976b4 100644
--- a/c10/test/util/optional_test.cpp
+++ b/c10/test/util/optional_test.cpp
@@ -146,10 +146,16 @@
using CmpTestTypes = testing::Types<
// between two optionals
std::pair<c10::optional<int>, c10::optional<int>>,
+
// between an optional and a value
std::pair<c10::optional<int>, int>,
// between a value and an optional
- std::pair<int, c10::optional<int>>>;
+ std::pair<int, c10::optional<int>>,
+
+ // between an optional and a differently typed value
+ std::pair<c10::optional<int>, long>,
+ // between a differently typed value and an optional
+ std::pair<long, c10::optional<int>>>;
template <typename T>
class CmpTest : public testing::Test {};
TYPED_TEST_CASE(CmpTest, CmpTestTypes);
diff --git a/c10/util/Optional.h b/c10/util/Optional.h
index 5e0684b..7044c79 100644
--- a/c10/util/Optional.h
+++ b/c10/util/Optional.h
@@ -1049,63 +1049,63 @@
}
// 20.5.10, Comparison with T
-template <class T>
-constexpr bool operator==(const optional<T>& x, const T& v) {
+template <class T, class U>
+constexpr bool operator==(const optional<T>& x, const U& v) {
return bool(x) ? *x == v : false;
}
-template <class T>
-constexpr bool operator==(const T& v, const optional<T>& x) {
+template <class T, class U>
+constexpr bool operator==(const U& v, const optional<T>& x) {
return bool(x) ? v == *x : false;
}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const T& v) {
+template <class T, class U>
+constexpr bool operator!=(const optional<T>& x, const U& v) {
return bool(x) ? *x != v : true;
}
-template <class T>
-constexpr bool operator!=(const T& v, const optional<T>& x) {
+template <class T, class U>
+constexpr bool operator!=(const U& v, const optional<T>& x) {
return bool(x) ? v != *x : true;
}
-template <class T>
-constexpr bool operator<(const optional<T>& x, const T& v) {
+template <class T, class U>
+constexpr bool operator<(const optional<T>& x, const U& v) {
return bool(x) ? *x < v : true;
}
-template <class T>
-constexpr bool operator>(const T& v, const optional<T>& x) {
+template <class T, class U>
+constexpr bool operator>(const U& v, const optional<T>& x) {
return bool(x) ? v > *x : true;
}
-template <class T>
-constexpr bool operator>(const optional<T>& x, const T& v) {
+template <class T, class U>
+constexpr bool operator>(const optional<T>& x, const U& v) {
return bool(x) ? *x > v : false;
}
-template <class T>
-constexpr bool operator<(const T& v, const optional<T>& x) {
+template <class T, class U>
+constexpr bool operator<(const U& v, const optional<T>& x) {
return bool(x) ? v < *x : false;
}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const T& v) {
+template <class T, class U>
+constexpr bool operator>=(const optional<T>& x, const U& v) {
return bool(x) ? *x >= v : false;
}
-template <class T>
-constexpr bool operator<=(const T& v, const optional<T>& x) {
+template <class T, class U>
+constexpr bool operator<=(const U& v, const optional<T>& x) {
return bool(x) ? v <= *x : false;
}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const T& v) {
+template <class T, class U>
+constexpr bool operator<=(const optional<T>& x, const U& v) {
return bool(x) ? *x <= v : true;
}
-template <class T>
-constexpr bool operator>=(const T& v, const optional<T>& x) {
+template <class T, class U>
+constexpr bool operator>=(const U& v, const optional<T>& x) {
return bool(x) ? v >= *x : true;
}