Conversions to and from complex numbers. (#11420)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11420
Surprisingly tricky! Here are the major pieces:
- We grow a even yet more ludicrous macro
AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF
which does what it says on the tin. This is because I was
too lazy to figure out how to define the necessary conversions
in and out of ComplexHalf without triggering ambiguity problems.
It doesn't seem to be as simple as just Half. Leave it for
when someone actually wants this.
- Scalar now can hold std::complex<double>. Internally, it is
stored as double[2] because nvcc chokes on a non-POD type
inside a union.
- overflow() checking is generalized to work with complex.
When converting *to* std::complex<T>, all we need to do is check
for overflow against T. When converting *from* complex, we
must check (1) if To is not complex, that imag() == 0
and (2) for overflow componentwise.
- convert() is generalized to work with complex<->real conversions.
Complex to real drops the imaginary component; we rely on
overflow checking to tell if this actually loses fidelity. To get
the specializations and overloads to work out, we introduce
a new Converter class that actually is specializable.
- Complex scalars convert into Python complex numbers
- This probably fixes complex tensor printing, but there is no way
to test this right now.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Reviewed By: cpuhrsch
Differential Revision: D9697878
Pulled By: ezyang
fbshipit-source-id: 181519e56bbab67ed1e5b49c691b873e124d7946
diff --git a/aten/src/ATen/core/Half.h b/aten/src/ATen/core/Half.h
index d381d1a..c306fcd 100644
--- a/aten/src/ATen/core/Half.h
+++ b/aten/src/ATen/core/Half.h
@@ -10,6 +10,7 @@
/// intrinsics directly on the Half type from device code.
#include <ATen/core/Macros.h>
+#include <ATen/core/C++17.h>
#include <cmath>
#include <cstdint>
@@ -19,6 +20,7 @@
#include <string>
#include <utility>
#include <sstream>
+#include <complex>
#ifdef __CUDACC__
#include <cuda_fp16.h>
@@ -74,18 +76,66 @@
Half real_;
Half imag_;
ComplexHalf() = default;
+ Half real() const { return real_; }
+ Half imag() const { return imag_; }
+ inline ComplexHalf(std::complex<float> value)
+ : real_(value.real()), imag_(value.imag()) {}
+ inline operator std::complex<float>() const {
+ return {real_, imag_};
+ }
+};
+
+template <typename T>
+struct is_complex_t : public std::false_type {};
+
+template <typename T>
+struct is_complex_t<std::complex<T>> : public std::true_type {};
+
+template <>
+struct is_complex_t<ComplexHalf> : public std::true_type {};
+
+// Extract double from std::complex<double>; is identity otherwise
+// TODO: Write in more idiomatic C++17
+template <typename T> struct scalar_value_type { using type = T; };
+template <typename T> struct scalar_value_type<std::complex<T>> { using type = T; };
+template <> struct scalar_value_type<ComplexHalf> { using type = Half; };
+
+// The old implementation of Converter as a function made nvcc's head explode
+// when we added std::complex on top of the specializations for CUDA-only types
+// like __half, so I rewrote it as a templated class (so, no more overloads,
+// just (partial) specialization).
+
+template <typename To, typename From, typename Enable = void>
+struct Converter {
+ To operator()(From f) {
+ return static_cast<To>(f);
+ }
};
template <typename To, typename From>
-To convert(From f) {
- return static_cast<To>(f);
+To convert(From from) {
+ return Converter<To, From>()(from);
}
+template <typename To, typename FromV>
+struct Converter<
+ To, std::complex<FromV>,
+ typename std::enable_if<
+ c10::guts::negation<
+ is_complex_t<To>
+ >::value
+ >::type
+> {
+ To operator()(std::complex<FromV> f) {
+ return static_cast<To>(f.real());
+ }
+};
+
// skip isnan and isinf check for integral types
template <typename To, typename From>
typename std::enable_if<std::is_integral<From>::value, bool>::type overflows(
From f) {
- using limit = std::numeric_limits<To>;
+ using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
if (!limit::is_signed && std::numeric_limits<From>::is_signed) {
// allow for negative numbers to wrap using two's complement arithmetic.
// For example, with uint8, this allows for `a - b` to be treated as
@@ -97,9 +147,9 @@
}
template <typename To, typename From>
-typename std::enable_if<!std::is_integral<From>::value, bool>::type overflows(
+typename std::enable_if<std::is_floating_point<From>::value, bool>::type overflows(
From f) {
- using limit = std::numeric_limits<To>;
+ using limit = std::numeric_limits<typename scalar_value_type<To>::type>;
if (limit::has_infinity && std::isinf(static_cast<double>(f))) {
return false;
}
@@ -109,6 +159,23 @@
return f < limit::lowest() || f > limit::max();
}
+
+template <typename To, typename From>
+typename std::enable_if<is_complex_t<From>::value, bool>::type overflows(
+ From f) {
+ // casts from complex to real are considered to overflow if the
+ // imaginary component is non-zero
+ if (!is_complex_t<To>::value && f.imag() != 0) {
+ return true;
+ }
+ // Check for overflow componentwise
+ // (Technically, the imag overflow check is guaranteed to be false
+ // when !is_complex_t<To>, but any optimizer worth its salt will be
+ // able to figure it out.)
+ return overflows<typename scalar_value_type<To>::type, typename From::value_type>(f.real()) ||
+ overflows<typename scalar_value_type<To>::type, typename From::value_type>(f.imag());
+}
+
template <typename To, typename From>
To checked_convert(From f, const char* name) {
if (overflows<To, From>(f)) {
diff --git a/aten/src/ATen/core/Scalar.cpp b/aten/src/ATen/core/Scalar.cpp
index 4916e39..7bdc770 100644
--- a/aten/src/ATen/core/Scalar.cpp
+++ b/aten/src/ATen/core/Scalar.cpp
@@ -3,11 +3,13 @@
namespace at {
Scalar Scalar::operator-() const {
- if (isFloatingPoint()) {
- return Scalar(-v.d);
- } else {
- return Scalar(-v.i);
- }
+ if (isFloatingPoint()) {
+ return Scalar(-v.d);
+ } else if (isComplex()) {
+ return Scalar(std::complex<double>(-v.z[0], -v.z[1]));
+ } else {
+ return Scalar(-v.i);
+ }
}
} // namespace at
diff --git a/aten/src/ATen/core/Scalar.h b/aten/src/ATen/core/Scalar.h
index 0e55688..35c4b53 100644
--- a/aten/src/ATen/core/Scalar.h
+++ b/aten/src/ATen/core/Scalar.h
@@ -23,21 +23,39 @@
: tag(Tag::HAS_##member) { \
v . member = convert<decltype(v.member),type>(vv); \
}
+ // We can't set v in the initializer list using the
+ // syntax v{ .member = ... } because it doesn't work on MSVC
AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR)
#undef DEFINE_IMPLICIT_CTOR
+#define DEFINE_IMPLICIT_COMPLEX_CTOR(type,name,member) \
+ Scalar(type vv) \
+ : tag(Tag::HAS_##member) { \
+ v . member[0] = convert<double>(vv.real()); \
+ v . member[1] = convert<double>(vv.imag()); \
+ }
+
+ DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf,ComplexHalf,z)
+ DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>,ComplexFloat,z)
+ DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>,ComplexDouble,z)
+
+#undef DEFINE_IMPLICIT_COMPLEX_CTOR
+
#define DEFINE_ACCESSOR(type,name,member) \
type to##name () const { \
if (Tag::HAS_d == tag) { \
return checked_convert<type, double>(v.d, #type); \
+ } else if (Tag::HAS_z == tag) { \
+ return checked_convert<type, std::complex<double>>({v.z[0], v.z[1]}, #type); \
} else { \
return checked_convert<type, int64_t>(v.i, #type); \
} \
}
- AT_FORALL_SCALAR_TYPES(DEFINE_ACCESSOR)
+ // TODO: Support ComplexHalf accessor
+ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
//also support scalar.to<int64_t>();
template<typename T>
@@ -50,15 +68,22 @@
bool isIntegral() const {
return Tag::HAS_i == tag;
}
+ bool isComplex() const {
+ return Tag::HAS_z == tag;
+ }
Scalar operator-() const;
private:
- enum class Tag { HAS_d, HAS_i };
+ enum class Tag { HAS_d, HAS_i, HAS_z };
Tag tag;
union {
double d;
- int64_t i = 0;
+ int64_t i;
+ // Can't do put std::complex in the union, because it triggers
+ // an nvcc bug:
+ // error: designator may not specify a non-POD subobject
+ double z[2];
} v;
friend struct Type;
};
diff --git a/aten/src/ATen/core/ScalarType.h b/aten/src/ATen/core/ScalarType.h
index 7c8f124..b5e1a47 100644
--- a/aten/src/ATen/core/ScalarType.h
+++ b/aten/src/ATen/core/ScalarType.h
@@ -25,6 +25,21 @@
_(std::complex<float>,ComplexFloat,z) /* 9 */ \
_(std::complex<double>,ComplexDouble,z) /* 10 */
+// If you want to support ComplexHalf for real, replace occurrences
+// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
+// beware: convert() doesn't work for all the conversions you need...
+#define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(_) \
+_(uint8_t,Byte,i) \
+_(int8_t,Char,i) \
+_(int16_t,Short,i) \
+_(int,Int,i) \
+_(int64_t,Long,i) \
+_(at::Half,Half,d) \
+_(float,Float,d) \
+_(double,Double,d) \
+_(std::complex<float>,ComplexFloat,z) \
+_(std::complex<double>,ComplexDouble,z)
+
#define AT_FORALL_SCALAR_TYPES(_) \
_(uint8_t,Byte,i) \
_(int8_t,Char,i) \
diff --git a/aten/src/ATen/cuda/CUDAHalf.cu b/aten/src/ATen/cuda/CUDAHalf.cu
index a552b44..bd12125 100644
--- a/aten/src/ATen/cuda/CUDAHalf.cu
+++ b/aten/src/ATen/cuda/CUDAHalf.cu
@@ -7,36 +7,31 @@
namespace at {
#if CUDA_VERSION < 9000 && !defined(__HIP_PLATFORM_HCC__)
-template <> AT_CUDA_API
-half convert(Half aten_half) {
+
+half Converter<half, Half>::operator()(Half aten_half) {
return half{aten_half.x};
}
-template <> AT_CUDA_API
-half convert(double value) {
+half Converter<half, double>::operator()(double value) {
return half{Half(value).x};
}
-template <> AT_CUDA_API
-Half convert(half cuda_half) {
+Half Converter<Half, half>::operator()(half cuda_half) {
return Half(cuda_half.x, Half::from_bits);
}
#else
-template <> AT_CUDA_API
-half convert(Half aten_half) {
+half Converter<half, Half>::operator()(Half aten_half) {
__half_raw x_raw;
x_raw.x = aten_half.x;
return half(x_raw);
}
-template <> AT_CUDA_API
-Half convert(half cuda_half) {
+Half Converter<Half, half>::operator()(half cuda_half) {
__half_raw raw(cuda_half);
return Half(raw.x, Half::from_bits);
}
-template <> AT_CUDA_API
-half convert(double value) {
+half Converter<half, double>::operator()(double value) {
__half_raw raw;
raw.x = Half(value).x;
return half {raw};
diff --git a/aten/src/ATen/cuda/CUDAHalf.cuh b/aten/src/ATen/cuda/CUDAHalf.cuh
index 034ce27..6558ed5 100644
--- a/aten/src/ATen/cuda/CUDAHalf.cuh
+++ b/aten/src/ATen/cuda/CUDAHalf.cuh
@@ -8,9 +8,22 @@
#include <cuda_fp16.h>
namespace at {
-template <> AT_CUDA_API half convert(Half aten_half);
-template <> AT_CUDA_API Half convert(half cuda_half);
-template <> AT_CUDA_API half convert(double value);
+
+template <>
+struct AT_CUDA_API Converter<half, Half> {
+ half operator()(Half);
+};
+
+template <>
+struct AT_CUDA_API Converter<Half, half> {
+ Half operator()(half);
+};
+
+template <>
+struct AT_CUDA_API Converter<half, double> {
+ half operator()(double);
+};
+
#if CUDA_VERSION >= 9000 || defined(__HIP_PLATFORM_HCC__)
template <> __half HalfFix(Half h);
template <> Half HalfFix(__half h);
diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp
index a3c5f68..af81c02 100644
--- a/aten/src/ATen/native/TypeProperties.cpp
+++ b/aten/src/ATen/native/TypeProperties.cpp
@@ -13,6 +13,10 @@
return self.type().is_distributed();
}
+bool is_complex(const Tensor& self) {
+ return at::isComplexType(self.type().scalarType());
+}
+
bool is_floating_point(const Tensor& self) {
return at::isFloatingType(self.type().scalarType());
}
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 4e14e8f..2b71982 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -860,6 +860,10 @@
variants: function, method
device_guard: false
+- func: is_complex(Tensor self) -> bool
+ variants: function, method
+ device_guard: false
+
- func: is_nonzero(Tensor self) -> bool
variants: function, method
device_guard: false
diff --git a/aten/src/ATen/templates/Tensor.h b/aten/src/ATen/templates/Tensor.h
index f6c357c..45471dc 100644
--- a/aten/src/ATen/templates/Tensor.h
+++ b/aten/src/ATen/templates/Tensor.h
@@ -194,12 +194,12 @@
//toLongData(), toFloatData() etc.
#define TO_TYPE_DATA(T,name,_) \
T * to##name##Data() const;
- AT_FORALL_SCALAR_TYPES(TO_TYPE_DATA)
+ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(TO_TYPE_DATA)
#undef TO_TYPE_DATA
#define TO_C_TYPE(T,name,_) \
T toC##name () const;
- AT_FORALL_SCALAR_TYPES(TO_C_TYPE)
+ AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(TO_C_TYPE)
#undef TO_C_TYPE
template<typename T, size_t N>
diff --git a/aten/src/ATen/templates/TensorMethods.h b/aten/src/ATen/templates/TensorMethods.h
index 4aaaf82..4520a54 100644
--- a/aten/src/ATen/templates/TensorMethods.h
+++ b/aten/src/ATen/templates/TensorMethods.h
@@ -101,13 +101,13 @@
return data<T>(); \
}
-AT_FORALL_SCALAR_TYPES(DEFINE_CAST)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_CAST)
#undef DEFINE_CAST
#define DEFINE_TO_C_TYPE(T,name,_) \
inline T Tensor::toC##name () const { return _local_scalar().to##name (); }
-AT_FORALL_SCALAR_TYPES(DEFINE_TO_C_TYPE)
+AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO_C_TYPE)
#undef DEFINE_TO_C_TYPE
} //namespace at
diff --git a/test/test_torch.py b/test/test_torch.py
index e0af307..ea0d69d 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -189,6 +189,7 @@
'is_coalesced',
'is_distributed',
'is_floating_point',
+ 'is_complex',
'is_nonzero',
'is_same_size',
'is_signed',
diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp
index ea8869e..a863a18 100644
--- a/tools/autograd/templates/python_variable_methods.cpp
+++ b/tools/autograd/templates/python_variable_methods.cpp
@@ -159,6 +159,15 @@
return self.toCDouble();
}
+static std::complex<double> dispatch_to_CComplexDouble(const Tensor & self) {
+ AutoNoGIL no_gil;
+ DeviceGuard device_guard(self);
+ if (self.numel() != 1) {
+ throw ValueError("only one element tensors can be converted to Python scalars");
+ }
+ return self.toCComplexDouble();
+}
+
static int64_t dispatch_to_CLong(const Tensor & self) {
AutoNoGIL no_gil;
DeviceGuard device_guard(self);
@@ -365,6 +374,8 @@
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
if (self_.is_floating_point()) {
return wrap(dispatch_to_CDouble(self_));
+ } else if (self_.is_complex()) {
+ return wrap(dispatch_to_CComplexDouble(self_));
} else {
return wrap(dispatch_to_CLong(self_));
}
diff --git a/torch/csrc/autograd/utils/wrap_outputs.h b/torch/csrc/autograd/utils/wrap_outputs.h
index 1417956..40ef7ce 100644
--- a/torch/csrc/autograd/utils/wrap_outputs.h
+++ b/torch/csrc/autograd/utils/wrap_outputs.h
@@ -81,6 +81,12 @@
return PyFloat_FromDouble(value);
}
+inline PyObject* wrap(std::complex<double> value) {
+ // I could probably also use FromComplex with a reinterpret cast,
+ // but... eh.
+ return PyComplex_FromDoubles(value.real(), value.imag());
+}
+
inline PyObject* wrap(void* value) {
return THPUtils_packInt64(reinterpret_cast<intptr_t>(value));
}