|  | #pragma once | 
|  |  | 
|  | #include <assert.h> | 
|  | #include <stdint.h> | 
|  | #include <stdexcept> | 
|  | #include <string> | 
|  | #include <utility> | 
|  |  | 
|  | #include <c10/macros/Macros.h> | 
|  | #include <c10/core/ScalarType.h> | 
|  | #include <c10/Half.h> | 
|  |  | 
|  | namespace c10 { | 
|  |  | 
|  | /** | 
|  | * Scalar represents a 0-dimensional tensor which contains a single element. | 
|  | * Unlike a tensor, numeric literals (in C++) are implicitly convertible to Scalar | 
|  | * (which is why, for example, we provide both add(Tensor) and add(Scalar) overloads | 
|  | * for many operations). It may also be used in circumstances where you statically | 
|  | * know a tensor is 0-dim and single size, but don't know it's type. | 
|  | */ | 
|  | class C10_API Scalar { | 
|  | public: | 
|  | Scalar() : Scalar(int64_t(0)) {} | 
|  |  | 
|  | #define DEFINE_IMPLICIT_CTOR(type,name,member) \ | 
|  | Scalar(type vv) \ | 
|  | : tag(Tag::HAS_##member) { \ | 
|  | v . member = convert<decltype(v.member),type>(vv); \ | 
|  | } | 
|  | // We can't set v in the initializer list using the | 
|  | // syntax v{ .member = ... } because it doesn't work on MSVC | 
|  |  | 
|  | AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR) | 
|  |  | 
|  | #undef DEFINE_IMPLICIT_CTOR | 
|  |  | 
|  | #define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \ | 
|  | Scalar(type vv) : tag(Tag::HAS_##member) {             \ | 
|  | v.member[0] = c10::convert<double>(vv.real());       \ | 
|  | v.member[1] = c10::convert<double>(vv.imag());       \ | 
|  | } | 
|  |  | 
|  | DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf,ComplexHalf,z) | 
|  | DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>,ComplexFloat,z) | 
|  | DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>,ComplexDouble,z) | 
|  |  | 
|  | #undef DEFINE_IMPLICIT_COMPLEX_CTOR | 
|  |  | 
|  | #define DEFINE_ACCESSOR(type,name,member) \ | 
|  | type to##name () const { \ | 
|  | if (Tag::HAS_d == tag) { \ | 
|  | return checked_convert<type, double>(v.d, #type); \ | 
|  | } else if (Tag::HAS_z == tag) { \ | 
|  | return checked_convert<type, std::complex<double>>({v.z[0], v.z[1]}, #type); \ | 
|  | } else { \ | 
|  | return checked_convert<type, int64_t>(v.i, #type); \ | 
|  | } \ | 
|  | } | 
|  |  | 
|  | // TODO: Support ComplexHalf accessor | 
|  | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR) | 
|  |  | 
|  | //also support scalar.to<int64_t>(); | 
|  | template<typename T> | 
|  | T to(); | 
|  |  | 
|  | #undef DEFINE_ACCESSOR | 
|  | bool isFloatingPoint() const { | 
|  | return Tag::HAS_d == tag; | 
|  | } | 
|  | bool isIntegral() const { | 
|  | return Tag::HAS_i == tag; | 
|  | } | 
|  | bool isComplex() const { | 
|  | return Tag::HAS_z == tag; | 
|  | } | 
|  |  | 
|  | Scalar operator-() const; | 
|  |  | 
|  | private: | 
|  | enum class Tag { HAS_d, HAS_i, HAS_z }; | 
|  | Tag tag; | 
|  | union { | 
|  | double d; | 
|  | int64_t i; | 
|  | // Can't do put std::complex in the union, because it triggers | 
|  | // an nvcc bug: | 
|  | //    error: designator may not specify a non-POD subobject | 
|  | double z[2]; | 
|  | } v; | 
|  | }; | 
|  |  | 
|  | // define the scalar.to<int64_t>() specializations | 
|  | template<typename T> | 
|  | inline T Scalar::to() { | 
|  | throw std::runtime_error("to() cast to unexpected type."); | 
|  | } | 
|  |  | 
|  | #define DEFINE_TO(T,name,_) \ | 
|  | template<> \ | 
|  | inline T Scalar::to<T>() { \ | 
|  | return to##name(); \ | 
|  | } | 
|  | AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO) | 
|  | #undef DEFINE_TO | 
|  | } |