blob: 06edfefd01b0fb54b7717a0a3e9e7098283afbe8 [file] [log] [blame]
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdexcept>
#include <string>
#include <utility>
#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
#include <c10/Half.h>
namespace c10 {

/**
 * Scalar represents a 0-dimensional tensor which contains a single element.
 * Unlike a tensor, numeric literals (in C++) are implicitly convertible to Scalar
 * (which is why, for example, we provide both add(Tensor) and add(Scalar) overloads
 * for many operations). It may also be used in circumstances where you statically
 * know a tensor is 0-dim and single size, but don't know its type.
 *
 * Representation: a tagged union holding exactly one of
 *   - a double            (Tag::HAS_d, stored in v.d),
 *   - an int64_t          (Tag::HAS_i, stored in v.i),
 *   - a complex number    (Tag::HAS_z, real part in v.z[0], imaginary in v.z[1]).
 * Every constructor converts its argument into one of these representations.
 */
class C10_API Scalar {
 public:
  // A default-constructed Scalar holds the integer 0.
  Scalar() : Scalar(int64_t(0)) {}

  // One implicit constructor per (type, name, member) entry supplied by
  // AT_FORALL_SCALAR_TYPES: the argument is converted with c10::convert into
  // the type of the union member `member`, and the tag is set to HAS_<member>.
#define DEFINE_IMPLICIT_CTOR(type, name, member)      \
  Scalar(type vv) : tag(Tag::HAS_##member) {          \
    v.member = convert<decltype(v.member), type>(vv); \
  }

  // We can't set v in the initializer list using the
  // syntax v{ .member = ... } because it doesn't work on MSVC
  AT_FORALL_SCALAR_TYPES(DEFINE_IMPLICIT_CTOR)
#undef DEFINE_IMPLICIT_CTOR

  // Implicit constructors for complex types: the real and imaginary parts are
  // each converted to double and stored in v.z[0] and v.z[1].
#define DEFINE_IMPLICIT_COMPLEX_CTOR(type, name, member) \
  Scalar(type vv) : tag(Tag::HAS_##member) {             \
    v.member[0] = c10::convert<double>(vv.real());       \
    v.member[1] = c10::convert<double>(vv.imag());       \
  }

  DEFINE_IMPLICIT_COMPLEX_CTOR(at::ComplexHalf, ComplexHalf, z)
  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<float>, ComplexFloat, z)
  DEFINE_IMPLICIT_COMPLEX_CTOR(std::complex<double>, ComplexDouble, z)
#undef DEFINE_IMPLICIT_COMPLEX_CTOR

  // Defines `type to<name>() const` for each listed type. The stored value is
  // read from the union member selected by the tag and passed, together with
  // the stringized type name, to c10::checked_convert (the name is presumably
  // used in its error message — see checked_convert's definition).
#define DEFINE_ACCESSOR(type, name, member)                            \
  type to##name() const {                                              \
    if (Tag::HAS_d == tag) {                                           \
      return checked_convert<type, double>(v.d, #type);                \
    } else if (Tag::HAS_z == tag) {                                    \
      return checked_convert<type, std::complex<double>>(              \
          {v.z[0], v.z[1]}, #type);                                    \
    } else {                                                           \
      return checked_convert<type, int64_t>(v.i, #type);               \
    }                                                                  \
  }

  // TODO: Support ComplexHalf accessor
  AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_ACCESSOR)
#undef DEFINE_ACCESSOR

  // Generic spelling of the accessors above, e.g. scalar.to<int64_t>().
  // Declared const so it is callable on a `const Scalar&`, exactly like the
  // to<name>() accessors it forwards to. Specializations follow the class;
  // a T without a specialization throws std::runtime_error at runtime.
  template <typename T>
  T to() const;

  // True when the stored value is a double.
  bool isFloatingPoint() const {
    return Tag::HAS_d == tag;
  }

  // True when the stored value is an int64_t.
  bool isIntegral() const {
    return Tag::HAS_i == tag;
  }

  // True when the stored value is complex.
  bool isComplex() const {
    return Tag::HAS_z == tag;
  }

  // Negation; defined out of line.
  Scalar operator-() const;

 private:
  enum class Tag { HAS_d, HAS_i, HAS_z };
  // Selects which member of `v` is active.
  Tag tag;
  union {
    double d;
    int64_t i;
    // Can't put std::complex in the union, because it triggers
    // an nvcc bug:
    // error: designator may not specify a non-POD subobject
    double z[2];
  } v;
};

// Fallback for scalar.to<T>() when T has no specialization below.
template <typename T>
inline T Scalar::to() const {
  throw std::runtime_error("to() cast to unexpected type.");
}

// define the scalar.to<int64_t>() specializations: each forwards to the
// matching to<name>() accessor generated inside the class.
#define DEFINE_TO(T, name, _)      \
  template <>                      \
  inline T Scalar::to<T>() const { \
    return to##name();             \
  }

AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_EXCEPT_COMPLEX_HALF(DEFINE_TO)
#undef DEFINE_TO

} // namespace c10