Remove ScalarConvert and cast_wrapper in favor of static_cast (#9401)
Summary:
While talking to mruberry, I noticed a few places that still use
special cast wrappers (cast_wrapper, ScalarConvert) that are no longer
necessary: they only existed to work around CUDA 8's inability to
static_cast with half, so a plain static_cast (or an implicit
conversion) does the job now.
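For context, a minimal sketch of the pattern being removed, abridged
from the deleted code below (illustrative only, not part of the diff):

    // Removed workaround: CUDA 8 could not static_cast with half, so call
    // sites went through a wrapper (or THC's ScalarConvert<From, To>::to(v)).
    template <typename R, typename T>
    deviceforcuda R cast_wrapper(T v) { return static_cast<R>(v); }

    // Call sites now use plain static_cast, or rely on implicit conversion:
    //   before: alpha += cast_wrapper<scalar_t, float>(1.0f);
    //   after:  alpha += 1.0f;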
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9401
Differential Revision: D8828874
Pulled By: colesbury
fbshipit-source-id: 2b7fe7ac3af3b71be26b43a9ad3949f8065a7bc9
diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h
index 8758206..7a6e078 100644
--- a/aten/src/ATen/native/Distributions.h
+++ b/aten/src/ATen/native/Distributions.h
@@ -16,15 +16,6 @@
#define isnan std::isnan
#endif
-// workaround: cuda 8 cannot do static_cast with half
-#ifdef __CUDACC__
-template<typename R, typename T>
-deviceforcuda R cast_wrapper(T v) { return scalar_cast<R>(v); }
-#else
-template<typename R, typename T>
-deviceforcuda R cast_wrapper(T v) { return static_cast<R>(v); }
-#endif
-
template<typename scalar_t>
struct BaseSampler {
nvfunction_or_function<scalar_t(void)> sampler;
@@ -65,9 +56,9 @@
accscalar_t scale = 1.0f;
// Boost alpha for higher acceptance probability.
- if (alpha < cast_wrapper<scalar_t,float>(1.0f)) {
+ if (alpha < 1.0f) {
scale *= std::pow(1 - standard_uniform.sample(), 1.0f / alpha);
- alpha += cast_wrapper<scalar_t,float>(1.0f);
+ alpha += 1.0f;
}
// This implements the acceptance-rejection method of Marsaglia and Tsang (2000)
@@ -84,9 +75,9 @@
const accscalar_t u = 1 - standard_uniform.sample();
const accscalar_t xx = x * x;
if (u < 1.0f - 0.0331f * xx * xx)
- return cast_wrapper<scalar_t,accscalar_t>(scale * d * v);
+ return static_cast<scalar_t>(scale * d * v);
if (std::log(u) < 0.5f * xx + d * (1.0f - v + std::log(v)))
- return cast_wrapper<scalar_t,accscalar_t>(scale * d * v);
+ return static_cast<scalar_t>(scale * d * v);
}
}
@@ -159,8 +150,8 @@
template <typename scalar_t, typename accscalar_t>
deviceforcuda scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) {
// Use a Taylor series expansion for small x.
- accscalar_t x = cast_wrapper<accscalar_t,scalar_t>(x_);
- accscalar_t alpha = cast_wrapper<accscalar_t,scalar_t>(alpha_);
+ accscalar_t x = static_cast<accscalar_t>(x_);
+ accscalar_t alpha = static_cast<accscalar_t>(alpha_);
if (x < 0.8f) {
accscalar_t numer = 1;
accscalar_t denom = alpha;
@@ -178,7 +169,7 @@
const auto gamma_cdf_alpha = (std::log(x) - digamma_one<accscalar_t,accscalar_t>(alpha)) * gamma_cdf
- pow_x_alpha * series2;
const auto result = -gamma_cdf_alpha / gamma_pdf;
- return isnan(result) ? cast_wrapper<scalar_t,float>( 0.f ) : cast_wrapper<scalar_t,accscalar_t>(result);
+ return isnan(result) ? static_cast<scalar_t>( 0.f ) : static_cast<scalar_t>(result);
}
// Use a Rice saddle point expansion for large alpha.
@@ -188,7 +179,7 @@
const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
- 65 * x * x / alpha + alpha * (107 + 3600 * x);
const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha);
- return cast_wrapper<scalar_t,accscalar_t>(numer_1 * numer_2 / denom);
+ return static_cast<scalar_t>(numer_1 * numer_2 / denom);
}
const auto denom = std::sqrt(8 * alpha);
const auto term2 = denom / (alpha - x);
@@ -198,7 +189,7 @@
- std::sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
const auto numer = x * term1;
- return cast_wrapper<scalar_t,accscalar_t>(-stirling * numer / denom);
+ return static_cast<scalar_t>(-stirling * numer / denom);
}
// Use a bivariate rational approximation to the reparameterized gradient.
@@ -218,7 +209,7 @@
}
const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
- return cast_wrapper<scalar_t,accscalar_t>(std::exp(p / q));
+ return static_cast<scalar_t>(std::exp(p / q));
}
} // namespace
diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu
index c591a30..4b346ca 100644
--- a/aten/src/ATen/native/cuda/Distributions.cu
+++ b/aten/src/ATen/native/cuda/Distributions.cu
@@ -46,7 +46,7 @@
blockIdx.x * blockDim.x + threadIdx.x,
seeds.second,
&state);
- ret_val = scalar_cast<scalar_t>(curand_poisson(&state, scalar_cast<float>(lambda)));
+ ret_val = static_cast<scalar_t>(curand_poisson(&state, lambda));
});
}
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh
index 43f1b4d..44bd3ab 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDAApplyUtils.cuh
@@ -1,7 +1,6 @@
#pragma once
#include <ATen/cuda/detail/TensorInfo.cuh>
-#include <THC/THCNumerics.cuh>
namespace at { namespace native {
@@ -262,12 +261,12 @@
// long seg = chunk / chunksPerSeg;
// auto begin = segment_offsets[seg];
// auto end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
-// Acctype valSum = ScalarConvert<float, Acctype>::to(0);
+// Acctype valSum = static_cast<Acctype>(0);
// for (long valIdx = begin; valIdx < end; valIdx++) {
// const long valRow = value_indices[valIdx] * stride;
-// valSum += ScalarConvert<Dtype, Acctype>::to(valFeat[valRow]);
+// valSum += static_cast<Acctype>(valFeat[valRow]);
// }
-// newValues[seg * stride + featureDim] = ScalarConvert<Acctype, Dtype>::to(valSum);
+// newValues[seg * stride + featureDim] = static_cast<Dtype>(valSum);
// }
// }
// }
@@ -291,7 +290,7 @@
Acctype tmp[SZ];
#pragma unroll
for (int ii = 0; ii < SZ; ii++) {
- tmp[ii] = ScalarConvert<float, Acctype>::to(0);
+ tmp[ii] = 0;
}
for (int row = begin; row < end; row++) {
const int valueRow = ((int) value_indices[row]) * stride;
@@ -303,7 +302,7 @@
int featureDim = startFeature + ii * WARP_SIZE;
if (featureDim < stride)
{
- tmp[ii] += ScalarConvert<Dtype, Acctype>::to(values[valueRow + featureDim]);
+ tmp[ii] += static_cast<Acctype>(values[valueRow + featureDim]);
}
}
}
@@ -313,7 +312,7 @@
int featureDim = startFeature + ii * WARP_SIZE;
if (featureDim < stride)
{
- newValues[newValueRow + featureDim] = ScalarConvert<Acctype, Dtype>::to(tmp[ii]);
+ newValues[newValueRow + featureDim] = static_cast<Dtype>(tmp[ii]);
}
}
}
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
index eb36911..3521fc3 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
@@ -8,7 +8,6 @@
#include <THC/THCTensorMathPointwise.cuh>
#include <THC/THCThrustAllocator.cuh>
-#include <THC/THCNumerics.cuh>
#include <thrust/device_ptr.h>
#include <thrust/sequence.h>
#include <thrust/system/cuda/execution_policy.h>
@@ -89,7 +88,7 @@
scalar_t cast_alpha = alpha.to<scalar_t>();
if (cast_beta == 0) {
r_.zero_();
- } else if (cast_beta == ScalarConvert<int, scalar_t>::to(1)) {
+ } else if (cast_beta == 1) {
if (!isSameTensor(t, r_)) {
r_.copy_(t);
}
@@ -314,7 +313,7 @@
// NB: Purposely not inplace!
AT_DISPATCH_ALL_TYPES_AND_HALF(
values.type(), "add_out_dense_sparse_cuda", [&] {
- if (value.to<scalar_t>() != ScalarConvert<int, scalar_t>::to(1)) {
+ if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
values = values.mul(value);
}
});
@@ -383,7 +382,7 @@
AT_DISPATCH_ALL_TYPES_AND_HALF(
s_values_.type(), "s_add_out_sparse_cuda", [&] {
- if (value.to<scalar_t>() != ScalarConvert<int, scalar_t>::to(1)) {
+ if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
s_values_ = s_values_.mul(value);
}
});
@@ -429,7 +428,7 @@
AT_DISPATCH_ALL_TYPES(
t.type(), "sub_sparse", [&] {
scalar_t cast_value = value.to<scalar_t>();
- s_add_out_sparse_cuda(r, t, src, ScalarNegate<scalar_t>::to(cast_value));
+ s_add_out_sparse_cuda(r, t, src, -cast_value);
}
);
return r;