Support operations on c10::complex and integer scalars (#38418)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38418

This is useful in reducing verbosity in c10::complex's general usage, and potentially also offers
performance benefits.

This brings back #34506 (which was made for std::complex).

Differential Revision: D21587012

Test Plan: Imported from OSS

Pulled By: malfet

fbshipit-source-id: 6dd10c2f417d6f6d0935c9e1d8b457fd29c163af
diff --git a/c10/test/util/complex_test_common.h b/c10/test/util/complex_test_common.h
index 73dd3b8..fe61ca0 100644
--- a/c10/test/util/complex_test_common.h
+++ b/c10/test/util/complex_test_common.h
@@ -389,6 +389,32 @@
   test_arithmetic_<double>();
 }
 
+template<typename T, typename int_t>
+void test_binary_ops_for_int_type_(T real, T img, int_t num) {
+  c10::complex<T> c(real, img);
+  ASSERT_EQ(c + num, c10::complex<T>(real + num, img));
+  ASSERT_EQ(num + c, c10::complex<T>(num + real, img));
+  ASSERT_EQ(c - num, c10::complex<T>(real - num, img));
+  ASSERT_EQ(num - c, c10::complex<T>(num - real, -img));
+  ASSERT_EQ(c * num, c10::complex<T>(real * num, img * num));
+  ASSERT_EQ(num * c, c10::complex<T>(num * real, num * img));
+  ASSERT_EQ(c / num, c10::complex<T>(real / num, img / num));
+  ASSERT_EQ(num / c, c10::complex<T>(num * real / std::norm(c), -num * img / std::norm(c)));
+}
+
+template<typename T>
+void test_binary_ops_for_all_int_types_(T real, T img, int8_t i) {
+  test_binary_ops_for_int_type_<T, int8_t>(real, img, i);
+  test_binary_ops_for_int_type_<T, int16_t>(real, img, i);
+  test_binary_ops_for_int_type_<T, int32_t>(real, img, i);
+  test_binary_ops_for_int_type_<T, int64_t>(real, img, i);
+}
+
+TEST(TestArithmeticIntScalar, All) {
+  test_binary_ops_for_all_int_types_<float>(1.0, 0.1, 1);
+  test_binary_ops_for_all_int_types_<double>(-1.3, -0.2, -2);
+}
+
 } // namespace arithmetic
 
 namespace equality {
@@ -407,7 +433,7 @@
   test_equality_<float>();
   test_equality_<double>();
 }
-  
+
 } // namespace equality
 
 namespace io {
diff --git a/c10/util/complex_type.h b/c10/util/complex_type.h
index a12b9ea..7b2093b 100644
--- a/c10/util/complex_type.h
+++ b/c10/util/complex_type.h
@@ -374,6 +374,56 @@
   return result /= rhs;
 }
 
+
+// Define operators between integral scalars and c10::complex. std::complex does not support this when T is a
+// floating-point number. This is useful because it saves a lot of "static_cast" when operate a complex and an integer.
+// This makes the code both less verbose and potentially more efficient.
+#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \
+  typename std::enable_if_t<std::is_floating_point<fT>::value && std::is_integral<iT>::value, int> = 0
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator+(const c10::complex<fT>& a, const iT& b) {
+  return a + static_cast<fT>(b);
+}
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator+(const iT& a, const c10::complex<fT>& b) {
+  return static_cast<fT>(a) + b;
+}
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator-(const c10::complex<fT>& a, const iT& b) {
+  return a - static_cast<fT>(b);
+}
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator-(const iT& a, const c10::complex<fT>& b) {
+  return static_cast<fT>(a) - b;
+}
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator*(const c10::complex<fT>& a, const iT& b) {
+  return a * static_cast<fT>(b);
+}
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator*(const iT& a, const c10::complex<fT>& b) {
+  return static_cast<fT>(a) * b;
+}
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator/(const c10::complex<fT>& a, const iT& b) {
+  return a / static_cast<fT>(b);
+}
+
+template<typename fT, typename iT, COMPLEX_INTEGER_OP_TEMPLATE_CONDITION>
+constexpr c10::complex<fT> operator/(const iT& a, const c10::complex<fT>& b) {
+  return static_cast<fT>(a) / b;
+}
+
+#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION
+
+
 template<typename T>
 constexpr bool operator==(const c10::complex<T>& lhs, const c10::complex<T>& rhs) {
   return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag());