Relax Scalar::toXXX conversions to only check for overflow
Currently, the toXXX functions on Scalar check that the conversions are
exact. This will cause an exception in code like:
auto t = CPU(kFloat).ones({1});
t *= M_PI;
Or the equivalent in Python:
t = torch.ones(1)
t *= math.pi
This changes the checks to only throw an exception in the case of
overflow (positive or negative).
diff --git a/Half.h b/Half.h
index bfb1cc6..f090f35 100644
--- a/Half.h
+++ b/Half.h
@@ -1,5 +1,6 @@
#pragma once
+#include <limits>
#include <stdint.h>
#ifdef AT_CUDA_ENABLED
#include <cuda.h>
@@ -13,6 +14,22 @@
return static_cast<To>(f);
}
+template<typename To, typename From> bool overflows(From f) {
+ using limit = std::numeric_limits<To>;
+ return f < limit::lowest() || f > limit::max();
+}
+
+template<typename To, typename From> To checked_convert(From f, const char* name) {
+ if (overflows<To, From>(f)) {
+ std::string msg = "value cannot be converted to type ";
+ msg += name;
+ msg += " without overflow: ";
+ msg += std::to_string(f);
+ throw std::domain_error(std::move(msg));
+ }
+ return convert<To, From>(f);
+}
+
#if defined(__GNUC__)
#define AT_ALIGN(n) __attribute__((aligned(n)))
#elif defined(_WIN32)
@@ -43,6 +60,9 @@
template<> Half convert(int64_t f);
template<> int64_t convert(Half f);
+template<> bool overflows<Half, double>(double f);
+template<> bool overflows<Half, int64_t>(int64_t f);
+
inline Half::operator double() {
return convert<double,Half>(*this);
}
diff --git a/Scalar.cpp b/Scalar.cpp
index 0dc8ef3..dbeb02c 100644
--- a/Scalar.cpp
+++ b/Scalar.cpp
@@ -36,6 +36,13 @@
return static_cast<int64_t>(convert<double,Half>(f));
}
+template<> bool overflows<Half, double>(double f) {
+ return f > 65504 || f < -65504;
+}
+template<> bool overflows<Half, int64_t>(int64_t f) {
+ return f > 65504 || f < -65504;
+}
+
#ifdef AT_CUDA_ENABLED
template<> half convert(double d) {
diff --git a/Scalar.h b/Scalar.h
index f6b48cf..cfb3fc7 100644
--- a/Scalar.h
+++ b/Scalar.h
@@ -62,18 +62,9 @@
if (Tag::HAS_t == tag) { \
return local().to##name(); \
} else if (Tag::HAS_d == tag) { \
- auto casted = convert<type,double>(v.d); \
- if(convert<double,type>(casted) != v.d) { \
- throw std::domain_error(std::string("value cannot be losslessly represented in type " #name ": ") + std::to_string(v.d) ); \
- } \
- return casted; \
+ return checked_convert<type, double>(v.d, #type); \
} else { \
- assert(Tag::HAS_i == tag); \
- auto casted = convert<type,int64_t>(v.i); \
- if(convert<int64_t,type>(casted) != v.i) { \
- throw std::domain_error(std::string("value cannot be losslessly represented in type " #name ": ") + std::to_string(v.i)); \
- } \
- return casted; \
+ return checked_convert<type, int64_t>(v.i, #type); \
} \
}
diff --git a/test/scalar_test.cpp b/test/scalar_test.cpp
index 29d1128..b3d5071 100644
--- a/test/scalar_test.cpp
+++ b/test/scalar_test.cpp
@@ -1,4 +1,5 @@
#include <iostream>
+#include <math.h>
#include "ATen/ATen.h"
#include "ATen/Dispatch.h"
#include "test_assert.h"
@@ -40,6 +41,24 @@
ASSERT(s3.isBackedByTensor() && s3.toFloat() == 1.0);
}
+void test_overflow() {
+ auto s1 = Scalar(M_PI);
+ ASSERT(s1.toFloat() == static_cast<float>(M_PI));
+ s1.toHalf();
+
+ s1 = Scalar(100000);
+ ASSERT(s1.toFloat() == 100000.0);
+ ASSERT(s1.toInt() == 100000);
+
+ bool threw = false;
+ try {
+ s1.toHalf();
+ } catch (std::domain_error& e) {
+ threw = true;
+ }
+ ASSERT(threw);
+}
+
int main() {
Scalar what = 257;
Scalar bar = 3.0;
@@ -85,6 +104,7 @@
ASSERT(threw);
test_ctors();
+ test_overflow();
if(at::hasCUDA()) {
auto r = CUDA(Float).copy(next_h);