Enable BFloat16 support for Convolutions on ROCm (#30948)
Summary:
This PR adds bfloat16 support for convolutions on ROCm.
- Integrates MIOpen bfloat16 convolution support into PyTorch
- Enables bfloat16 convolutions on the non-MIOpen paths, i.e. THCUNN and the native HIP kernels
- Enables the bfloat16 type for probability distribution functions (included in this PR because the convolution unit tests use bfloat16 random number generation)
The native CUDA kernels for convolutions and the random-number functions will be compiled for CUDA as well.
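Below is a minimal, hypothetical usage sketch of what this change enables on a ROCm build; the module parameters and shapes are illustrative and not taken from this PR. It exercises the two pieces added here: a bfloat16 convolution (which, per the use_miopen() gating below, should route to MIOpen when no bias is defined) and bfloat16 random-number generation on the GPU.

```python
import torch
import torch.nn as nn

# Assumes a ROCm build of PyTorch with MIOpen bfloat16 support; on other
# builds/platforms this may raise until bfloat16 bringup is complete there.
device = "cuda"  # ROCm GPUs are exposed through the "cuda" device type

# bias=False mirrors the use_miopen() gating in this PR: the MIOpen bfloat16
# path is skipped when a bias tensor is defined.
conv = nn.Conv2d(4, 8, kernel_size=3, bias=False).to(device, torch.bfloat16)

# bfloat16 GPU RNG (normal_, uniform_, etc.) is enabled by the Distributions.cu changes.
x = torch.randn(2, 4, 16, 16, device=device, dtype=torch.bfloat16)

y = conv(x)          # forward convolution in bfloat16
y.sum().backward()   # backward kernels are enabled on the same paths
print(y.dtype, conv.weight.grad.dtype)
```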
iotamudelta bddppq
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30948
Differential Revision: D19274164
Pulled By: ezyang
fbshipit-source-id: c0888a6ac72a2c5749b1ebb2195ac6f2209996be
diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h
index 753e0a3..3ee5339 100644
--- a/aten/src/ATen/AccumulateType.h
+++ b/aten/src/ATen/AccumulateType.h
@@ -22,6 +22,7 @@
#if defined(__CUDACC__) || defined(__HIPCC__)
template <> struct AccumulateType<half, true> { using type = float; };
+template <> struct AccumulateType<BFloat16, true> { using type = float; };
#endif
template <> struct AccumulateType<Half, true> { using type = float; };
template <> struct AccumulateType<float, true> { using type = float; };
diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 52363ad..f377f16 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -138,6 +138,23 @@
} \
}()
+#define AT_DISPATCH_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \
+ [&] { \
+ const auto& the_type = TYPE; \
+ /* don't use TYPE again in case it is an expensive or side-effect op */ \
+ at::ScalarType _st = ::detail::scalar_type(the_type); \
+ switch (_st) { \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE1, \
+ decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE1>::t), __VA_ARGS__) \
+ AT_PRIVATE_CASE_TYPE(SCALARTYPE2, \
+ decltype(c10::impl::ScalarTypeToCPPType<SCALARTYPE2>::t), __VA_ARGS__) \
+ default: \
+ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+ } \
+ }()
+
#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
diff --git a/aten/src/ATen/miopen/Descriptors.cpp b/aten/src/ATen/miopen/Descriptors.cpp
index 30aae58..6a64767 100644
--- a/aten/src/ATen/miopen/Descriptors.cpp
+++ b/aten/src/ATen/miopen/Descriptors.cpp
@@ -11,8 +11,11 @@
return miopenFloat;
} else if (scalar_type == at::kHalf) {
return miopenHalf;
+ } else if (scalar_type == at::kBFloat16) {
+ return miopenBFloat16;
+ } else {
+ throw std::runtime_error("TensorDescriptor only supports float, half and bfloat16 tensors");
}
- throw std::runtime_error("TensorDescriptor only supports float and half tensors");
}
} // anonymous namespace
@@ -51,6 +54,8 @@
return "miopenFloat";
case miopenHalf:
return "miopenHalf";
+ case miopenBFloat16:
+ return "miopenBFloat16";
default:
std::ostringstream oss;
oss << "(unknown data-type " << static_cast<int>(dtype) << ")";
diff --git a/aten/src/ATen/miopen/Descriptors.h b/aten/src/ATen/miopen/Descriptors.h
index a399e27..f5b7d9c 100644
--- a/aten/src/ATen/miopen/Descriptors.h
+++ b/aten/src/ATen/miopen/Descriptors.h
@@ -13,6 +13,7 @@
switch (dataType) {
case miopenHalf: return 2;
case miopenFloat: return 4;
+ case miopenBFloat16: return 2;
default: return 8;
}
}
@@ -145,7 +146,7 @@
float f;
double d;
Constant(miopenDataType_t dataType, double value) {
- if (dataType == miopenHalf || dataType == miopenFloat) {
+ if (dataType == miopenHalf || dataType == miopenFloat || dataType == miopenBFloat16) {
f = static_cast<float>(value);
} else {
d = value;
diff --git a/aten/src/ATen/miopen/Types.cpp b/aten/src/ATen/miopen/Types.cpp
index 890b3bf..7d5559a 100644
--- a/aten/src/ATen/miopen/Types.cpp
+++ b/aten/src/ATen/miopen/Types.cpp
@@ -10,6 +10,8 @@
return miopenFloat;
} else if (tensor.scalar_type() == at::kHalf) {
return miopenHalf;
+ } else if (tensor.scalar_type() == at::kBFloat16) {
+ return miopenBFloat16;
}
std::string msg("getMiopenDataType() not supported for ");
msg += toString(tensor.scalar_type());
diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp
index 65b368b..875552f 100644
--- a/aten/src/ATen/native/Convolution.cpp
+++ b/aten/src/ATen/native/Convolution.cpp
@@ -40,7 +40,7 @@
bool use_cudnn(const at::Tensor& input, const at::Tensor& weight) const;
bool use_cudnn_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
bool cudnn_use_channels_last(const at::Tensor& input, const at::Tensor& weight) const;
- bool use_miopen(const at::Tensor& input) const;
+ bool use_miopen(const at::Tensor& input, bool bias_defined) const;
bool use_mkldnn(const at::Tensor& input) const;
bool use_nnpack(const at::Tensor& input) const;
bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
@@ -198,13 +198,14 @@
return !is_output_padding_big();
}
-auto ConvParams::use_miopen(const at::Tensor& input) const -> bool {
+auto ConvParams::use_miopen(const at::Tensor& input, bool bias_defined) const -> bool {
- return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf))
+ return ((input.scalar_type() == at::kFloat) || (input.scalar_type() == at::kHalf) || (input.scalar_type() == at::kBFloat16))
&& detail::getCUDAHooks().compiledWithMIOpen()
&& input.is_cuda()
&& input.dim() <= MIOPEN_DIM_MAX
&& !(groups > 1 && is_dilated()) // MIOpen currently does not support dilation with groups of size > 1
+ && !(input.scalar_type() == at::kBFloat16 && bias_defined) // MIOpen currently doesn't support bias with bfloat16
;
}
@@ -637,7 +638,7 @@
output = output + reshape_bias(input.dim(), bias);
}
- } else if (params.use_miopen(input)){
+ } else if (params.use_miopen(input, bias.defined())){
output = at::miopen_depthwise_convolution(
input.contiguous(), weight, bias,
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
@@ -667,7 +668,7 @@
output = output + reshape_bias(input.dim(), bias);
}
}
- } else if (params.use_miopen(input)) {
+ } else if (params.use_miopen(input, bias.defined())) {
TORCH_CHECK(input.options().type_equal(weight.options()),
"Input type (", input.toString(), ") and weight type (", weight.toString(),
") should be the same");
diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu
index c6dc334..0033ee4 100644
--- a/aten/src/ATen/native/cuda/Distributions.cu
+++ b/aten/src/ATen/native/cuda/Distributions.cu
@@ -338,7 +338,7 @@
rng_engine_inputs = gen->philox_engine_inputs(20);
}
Tensor ret = at::empty(lambda.sizes(), lambda.options());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "poisson_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "poisson_cuda", [&] {
poisson_cuda_kernel<scalar_t>(ret, lambda, rng_engine_inputs);
});
return ret;
@@ -353,7 +353,7 @@
rng_engine_inputs = gen->philox_engine_inputs(10);
}
Tensor ret = at::empty(alpha.sizes(), alpha.options());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "gamma_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "gamma_cuda", [&] {
gamma_cuda_kernel<scalar_t>(ret, alpha, rng_engine_inputs);
});
return ret;
@@ -368,7 +368,7 @@
rng_engine_inputs = gen->philox_engine_inputs(10);
}
Tensor ret = at::empty(alpha.sizes(), alpha.options());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(ret.scalar_type(), "dirichlet", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, ret.scalar_type(), "dirichlet", [&] {
Tensor gamma = at::empty(alpha.sizes(), alpha.options());
gamma_cuda_kernel<scalar_t>(gamma, alpha, rng_engine_inputs);
dirichlet_scalar_cuda_kernel<scalar_t>(ret, gamma);
@@ -378,7 +378,7 @@
Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
Tensor ret = at::empty(self.sizes(), self.options());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(self.scalar_type(), "_standard_gamma_grad_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "_standard_gamma_grad_cuda", [&] {
gamma_grad_cuda_kernel<scalar_t>(ret, self, output);
});
return ret;
@@ -402,10 +402,10 @@
rng_engine_inputs = gen->philox_engine_inputs(10);
}
auto p = std::get<0>(expand_inplace(self, p_.to(kCUDA)));
- AT_DISPATCH_ALL_TYPES_AND2(
- at::ScalarType::Half, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
+ AT_DISPATCH_ALL_TYPES_AND3(
+ at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, self.scalar_type(), "bernoulli_tensor_cuda_self_", [&] {
using self_t = scalar_t;
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(p.scalar_type(), "bernoulli_tensor_cuda_p_", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, p.scalar_type(), "bernoulli_tensor_cuda_p_", [&] {
using p_t = scalar_t;
return bernoulli_tensor_cuda_kernel<self_t, p_t>(self, p, rng_engine_inputs);
});
@@ -415,7 +415,7 @@
void uniform_kernel_cuda(TensorIterator& iter, double from_, double to_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "uniform_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "uniform_cuda", [&] {
auto from = static_cast<scalar_t>(from_);
auto to = static_cast<scalar_t>(to_);
TORCH_CHECK(from <= to,
@@ -454,7 +454,7 @@
void random_kernel_cuda(TensorIterator& iter, uint64_t range, int64_t base, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
- AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half, iter.dtype(), "random_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "random_cuda", [&] {
if (std::is_same<scalar_t, double>::value || std::is_same<scalar_t, int64_t>::value) {
// define lambda to mod with range and add base
auto random_func = [range, base] __device__ (uint64_t rand) {
@@ -486,7 +486,7 @@
void normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "normal_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "normal_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto mean = static_cast<accscalar_t>(mean_);
auto std = static_cast<accscalar_t>(std_);
@@ -510,7 +510,7 @@
void cauchy_kernel_cuda(TensorIterator& iter, double median_, double sigma_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "cauchy_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "cauchy_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto median = static_cast<accscalar_t>(median_);
auto sigma = static_cast<accscalar_t>(sigma_);
@@ -543,7 +543,7 @@
// Note that HIP doesn't support std::nextafter in device code.
auto nextafter_1_0_float = std::nextafter(1.0f, 0.0f);
auto nextafter_1_0_double = std::nextafter(1.0, 0.0);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exponential_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exponential_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto lambda = static_cast<accscalar_t>(lambda_);
if (std::is_same<scalar_t, double>::value) {
@@ -584,7 +584,7 @@
void geometric_kernel_cuda(TensorIterator& iter, double p_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
- AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "geometric_cuda", [&] {
+ AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "geometric_cuda", [&] {
if (std::is_same<scalar_t, double>::value) {
// define lambda for geometric transformation
auto geometric_func = [p_] __device__ (double rand) {
@@ -610,7 +610,7 @@
void log_normal_kernel_cuda(TensorIterator& iter, double mean_, double std_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "log_normal_cuda", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "log_normal_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
auto mean = static_cast<accscalar_t>(mean_);
auto std = static_cast<accscalar_t>(std_);
@@ -638,8 +638,8 @@
void bernoulli_scalar_cuda_kernel(TensorIterator& iter, double p_, Generator* gen_) {
auto gen = get_generator_or_default<CUDAGenerator>(gen_, cuda::detail::getDefaultCUDAGenerator());
- AT_DISPATCH_ALL_TYPES_AND2(
- at::ScalarType::Half, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
+ AT_DISPATCH_ALL_TYPES_AND3(
+ at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, iter.dtype(), "bernoulli_scalar_cuda_", [&] {
if (std::is_same<scalar_t, double>::value) {
// define lambda for bernoulli transformation
auto bernoulli_func = [p_] __device__ (double rand) {
@@ -673,7 +673,7 @@
uint64_t range;
auto iter_scalar_type = iter.dtype();
if (isFloatingType(iter_scalar_type)) {
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter_scalar_type, "random_cuda_range_calc", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter_scalar_type, "random_cuda_range_calc", [&] {
range = static_cast<uint64_t>((1ULL << std::numeric_limits<scalar_t>::digits) + 1);
});
} else {
diff --git a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu
index 1e72fa3..29d63e4 100644
--- a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu
+++ b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu
@@ -256,7 +256,7 @@
ones.fill_(1);
}
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose2d_out_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
@@ -460,7 +460,7 @@
grad_columns.resize_({n_output_plane * kernel_width * kernel_height,
input_height * input_width});
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
grad_output.scalar_type(), "slow_conv_transpose2d_backward_out_cuda", [&] {
// Helpers
Tensor grad_input_n = Tensor();
@@ -663,7 +663,7 @@
columns.resize_({n_output_plane * kernel_width * kernel_height,
input_height * input_width});
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose2d_acc_grad_parameters_cuda", [&] {
// Helpers
Tensor input_n = Tensor();
diff --git a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu
index dad5508..717143a 100644
--- a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu
+++ b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu
@@ -297,7 +297,7 @@
ones.fill_(1);
}
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose3d_out_cuda", [&] {
using accscalar_t = at::acc_type<scalar_t, true>;
@@ -531,7 +531,7 @@
{n_output_plane * kernel_width * kernel_height * kernel_depth,
input_depth * input_height * input_width});
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_transpose3d_backward_out_cuda", [&] {
// Helpers
Tensor grad_input_n;
@@ -761,7 +761,7 @@
columns.resize_({n_output_plane * kernel_width * kernel_height * kernel_depth,
input_depth * input_height * input_width});
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(),
"slow_conv_transpose3d_acc_grad_parameters_cuda",
[&] {
diff --git a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
index 7fcb16e..0355820 100644
--- a/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
+++ b/aten/src/ATen/native/cuda/NaiveDilatedConvolution.cu
@@ -244,7 +244,7 @@
std::vector<int64_t> dims(dim);
std::iota(dims.begin(), dims.end(), 1);
- AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
input.scalar_type(), "slow_conv_dilated<>", [&] {
// For each elt in batch, do:
for (int elt = 0; elt < batchSize; elt++) {
diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml
index 0e03e28..e394549 100644
--- a/aten/src/ATen/nn.yaml
+++ b/aten/src/ATen/nn.yaml
@@ -58,7 +58,13 @@
CPU:
forward_scalar_types: ['Float', 'Double', 'Long', 'BFloat16']
backward_scalar_types: ['Float', 'Double', 'BFloat16']
+ CUDA:
+ forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
+ backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
- name: _thnn_conv_depthwise2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding, IntArrayRef[2] dilation)
cname: SpatialDepthwiseConvolution
buffers: []
+ CUDA:
+ forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
+ backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
\ No newline at end of file
diff --git a/aten/src/ATen/nn_parse.py b/aten/src/ATen/nn_parse.py
index a70a249..bb1340e 100644
--- a/aten/src/ATen/nn_parse.py
+++ b/aten/src/ATen/nn_parse.py
@@ -233,6 +233,8 @@
'name': name,
'cpu_bfloat16': True if backend_types is not None and 'CPU' in backend_types and
'BFloat16' in backend_types['CPU'] else False,
+ 'cuda_bfloat16': True if backend_types is not None and 'CUDA' in backend_types and
+ 'BFloat16' in backend_types['CUDA'] else False,
'backend_types': backend_types,
'arguments': arguments,
'return': 'argument 0' if inplace else get_return(arguments),
diff --git a/aten/src/THCUNN/SpatialConvolutionMM.cu b/aten/src/THCUNN/SpatialConvolutionMM.cu
index 488f00e..020bfa1 100644
--- a/aten/src/THCUNN/SpatialConvolutionMM.cu
+++ b/aten/src/THCUNN/SpatialConvolutionMM.cu
@@ -8,3 +8,6 @@
#include <THCUNN/generic/SpatialConvolutionMM.cu>
#include <THC/THCGenerateFloatTypes.h>
+
+#include <THCUNN/generic/SpatialConvolutionMM.cu>
+#include <THC/THCGenerateBFloat16Type.h>
diff --git a/aten/src/THCUNN/SpatialDepthwiseConvolution.cu b/aten/src/THCUNN/SpatialDepthwiseConvolution.cu
index 2874f7d..889830f 100644
--- a/aten/src/THCUNN/SpatialDepthwiseConvolution.cu
+++ b/aten/src/THCUNN/SpatialDepthwiseConvolution.cu
@@ -266,3 +266,6 @@
#include <THCUNN/generic/SpatialDepthwiseConvolution.cu>
#include <THC/THCGenerateFloatTypes.h>
+
+#include <THCUNN/generic/SpatialDepthwiseConvolution.cu>
+#include <THC/THCGenerateBFloat16Type.h>
diff --git a/aten/src/THCUNN/THCUNN.h b/aten/src/THCUNN/THCUNN.h
index 3752b66..a4392dd 100644
--- a/aten/src/THCUNN/THCUNN.h
+++ b/aten/src/THCUNN/THCUNN.h
@@ -8,3 +8,6 @@
#include <THCUNN/generic/THCUNN.h>
#include <THC/THCGenerateFloatTypes.h>
+
+#include <THCUNN/generic/THCUNN.h>
+#include <THC/THCGenerateBFloat16Type.h>
diff --git a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu
index b774dfa..9f70006 100644
--- a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu
+++ b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu
@@ -114,6 +114,9 @@
int kW, int kH,
int dW, int dH,
int padW, int padH) {
+ #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__)
+ TORCH_CHECK(false, "SpatialConvolutionMM_updateOutput not suppported with BFloat16");
+ #else
THCUNN_assertSameGPU(state, 5, input, output, weight, columns, ones);
if (bias) {
THCUNN_assertSameGPU(state, 2, weight, bias);
@@ -194,6 +197,8 @@
THCudaBlas_Hgemm(
#elif defined(THC_REAL_IS_DOUBLE)
THCudaBlas_Dgemm(
+ #elif defined(THC_REAL_IS_BFLOAT16)
+ THCudaBlas_Bgemm(
#endif
state,
't', 'n',
@@ -232,6 +237,8 @@
THCudaBlas_Hgemm(
#elif defined(THC_REAL_IS_DOUBLE)
THCudaBlas_Dgemm(
+ #elif defined(THC_REAL_IS_BFLOAT16)
+ THCudaBlas_Bgemm(
#endif
state,
'n', 'n',
@@ -256,6 +263,7 @@
THCTensor_(free)(state, input);
THCTensor_(free)(state, weight);
+ #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__
}
void THNN_(SpatialConvolutionMM_updateGradInput)(
@@ -270,6 +278,9 @@
int dW, int dH,
int padW, int padH) {
+ #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__)
+ TORCH_CHECK(false, "SpatialConvolutionMM_updateGradInput not suppported with BFloat16");
+ #else
THCUNN_assertSameGPU(state, 5, input, gradOutput, weight,
gradColumns, gradInput);
weight = THNN_(newViewWeightMM2d)(state, weight);
@@ -329,6 +340,8 @@
THCudaBlas_Hgemm(
#elif defined(THC_REAL_IS_DOUBLE)
THCudaBlas_Dgemm(
+ #elif defined(THC_REAL_IS_BFLOAT16)
+ THCudaBlas_Bgemm(
#endif
state,
'n', 't',
@@ -363,6 +376,7 @@
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__
}
void THNN_(SpatialConvolutionMM_accGradParameters)(
@@ -378,6 +392,9 @@
int padW, int padH,
accreal scale_) {
+ #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__)
+ TORCH_CHECK(false, "SpatialConvolutionMM_updateGradParameters not suppported with BFloat16");
+ #else
scalar_t scale = ScalarConvert<accreal, scalar_t>::to(scale_);
THCUNN_assertSameGPU(state, 5, input, gradOutput, gradWeight, gradBias, columns, ones);
if (gradWeight) {
@@ -463,6 +480,8 @@
THCudaBlas_Hgemm(
#elif defined(THC_REAL_IS_DOUBLE)
THCudaBlas_Dgemm(
+ #elif defined(THC_REAL_IS_BFLOAT16)
+ THCudaBlas_Bgemm(
#endif
state,
't', 'n',
@@ -499,8 +518,12 @@
THCTensor_(data)(state, gradBias), 1
);
#endif
+ #if defined(THC_REAL_IS_HALF) || defined(THC_REAL_IS_BFLOAT16)
#ifdef THC_REAL_IS_HALF
THCudaBlas_Hgemm(
+ #elif defined(THC_REAL_IS_BFLOAT16)
+ THCudaBlas_Bgemm(
+ #endif
state,
't', 'n',
m_, 1, k_,
@@ -528,6 +551,7 @@
THCTensor_(free)(state, input);
THCTensor_(free)(state, gradOutput);
+ #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__
}
#endif
diff --git a/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu b/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu
index 79d4849..4bf80ab 100644
--- a/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu
+++ b/aten/src/THCUNN/generic/SpatialDepthwiseConvolution.cu
@@ -13,6 +13,9 @@
int padW, int padH,
int dilationW, int dilationH)
{
+ #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__)
+ TORCH_CHECK(false, "SpatialDepthwiseConvolution_updateOutput not suppported with BFloat16");
+ #else
THCUNN_assertSameGPU(state, 3, input, output, weight);
// Only handle 4D Input Tensors for now
@@ -91,6 +94,7 @@
THCTensor_(free)(state, input);
THCTensor_(free)(state, weight);
if (bias) THCTensor_(free)(state, bias);
+ #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__
}
void THNN_(SpatialDepthwiseConvolution_updateGradInput)(
@@ -104,6 +108,9 @@
int padW, int padH,
int dilationW, int dilationH)
{
+ #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__)
+ TORCH_CHECK(false, "SpatialDepthwiseConvolution_updateGradInput not suppported with BFloat16");
+ #else
THCUNN_assertSameGPU(state, 3, gradOutput, gradInput, weight);
// Only handle 4D Input Tensors for now
@@ -196,6 +203,7 @@
THCTensor_(free)(state, weight);
THCTensor_(free)(state, gradOutput);
+ #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__
}
void THNN_(SpatialDepthwiseConvolution_accGradParameters)(
@@ -208,6 +216,9 @@
int padW, int padH,
int dilationW, int dilationH)
{
+ #if defined(THC_REAL_IS_BFLOAT16) && !defined(__HIP_PLATFORM_HCC__)
+ TORCH_CHECK(false, "SpatialDepthwiseConvolution_accGradParameters not suppported with BFloat16");
+ #else
THCUNN_assertSameGPU(state, 3, input, gradOutput, gradWeight);
// Only handle 4D Input Tensors for now
@@ -260,6 +271,7 @@
THCudaCheck(cudaGetLastError());
THCTensor_(free)(state, gradOutput);
+ #endif // THC_REAL_IS_BFLOAT16 && !__HIP_PLATFORM_HCC__
}
#endif
diff --git a/test/common_utils.py b/test/common_utils.py
index 5bf85df..40ec916 100644
--- a/test/common_utils.py
+++ b/test/common_utils.py
@@ -141,9 +141,6 @@
# Always call p.wait() to ensure exit
p.wait()
-ALL_TENSORTYPES = [torch.float,
- torch.double,
- torch.half]
# Used to run the same test with different tensor types
def repeat_test_for_types(dtypes):
@@ -289,6 +286,20 @@
if TEST_NUMPY:
import numpy
+ALL_TENSORTYPES = [torch.float,
+ torch.double,
+ torch.half]
+
+# bfloat16 bringup is currently only available on ROCm
+# ALL_TENSORTYPES2 will eventually be unified with ALL_TENSORTYPES
+# when bfloat16 bringup is complete on all platforms
+if TEST_WITH_ROCM:
+ ALL_TENSORTYPES2 = [torch.float,
+ torch.double,
+ torch.half,
+ torch.bfloat16]
+else:
+ ALL_TENSORTYPES2 = ALL_TENSORTYPES
def skipIfRocm(fn):
@wraps(fn)
diff --git a/test/test_nn.py b/test/test_nn.py
index cd87862..44aa1ee 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -39,9 +39,9 @@
from torch.nn import Parameter
from torch.nn.parallel._functions import Broadcast
from common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
- TEST_NUMPY, TEST_SCIPY, download_file, PY3, to_gpu, \
+ TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, PY3, to_gpu, \
get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
- TemporaryFileName
+ ALL_TENSORTYPES2, TemporaryFileName
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
@@ -77,7 +77,8 @@
dtype2prec = {torch.float: 1e-5,
torch.double: 1e-5,
- torch.half: 1e-2}
+ torch.half: 1e-2,
+ torch.bfloat16: 1e-1}
# WARNING: If you add a new top-level test case to this file, you MUST
@@ -3972,7 +3973,7 @@
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
- @repeat_test_for_types(ALL_TENSORTYPES)
+ @repeat_test_for_types(ALL_TENSORTYPES2)
def test_Conv2d_deterministic_cudnn(self, dtype=torch.float):
inputs = torch.randn(2, 3, 5, 5, device="cuda", dtype=dtype, requires_grad=True)
with cudnn.flags(enabled=True, benchmark=True, deterministic=True):
@@ -4002,7 +4003,7 @@
lambda: o1.sum().backward())
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
- @repeat_test_for_types(ALL_TENSORTYPES)
+ @repeat_test_for_types(ALL_TENSORTYPES2)
def test_Conv2d_large_workspace(self, dtype=torch.float):
# These sizes require huge cuDNN workspaces. Make sure we choose a
# reasonable algorithm that does not run out of memory
@@ -4097,6 +4098,8 @@
dev_dtypes = [("cpu", torch.float)]
if TEST_CUDA:
dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
+ if TEST_WITH_ROCM:
+ dev_dtypes += [("cuda", torch.bfloat16)]
for device, dtype in dev_dtypes:
m = nn.Conv2d(4, 4, kernel_size=3, groups=2, bias=False).to(device, dtype)
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
@@ -4133,6 +4136,8 @@
dev_dtypes = [("cpu", torch.float)]
if TEST_CUDA:
dev_dtypes += [("cuda", torch.float), ("cuda", torch.half)]
+ if TEST_WITH_ROCM:
+ dev_dtypes += [("cuda", torch.bfloat16)]
for device, dtype in dev_dtypes:
m = nn.Conv2d(4, 16, kernel_size=3, groups=2, bias=False).to(device, dtype)
i = torch.randn(2, 4, 6, 6, device=device, dtype=dtype, requires_grad=True)
@@ -5846,7 +5851,7 @@
self.assertEqual(grad_output, grad_output_clone)
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
- @repeat_test_for_types(ALL_TENSORTYPES)
+ @repeat_test_for_types(ALL_TENSORTYPES2)
def test_noncontig_conv_grad_cuda(self, dtype=torch.float):
# FIXME: remove after adding non-contiguous grad tests for all modules
module = nn.Conv2d(3, 5, kernel_size=3, padding=1).to("cuda", dtype)
@@ -9983,7 +9988,7 @@
self.assertEqual(q.size(), out[0].size())
self.assertEqual(dtype, out[0].dtype)
- @dtypesIfCUDA(torch.half, torch.float, torch.double)
+ @dtypesIfCUDA(*ALL_TENSORTYPES2)
@dtypes(torch.float)
def test_Conv2d_naive_groups(self, device, dtype):
# Check that grouped convolutions matches two half convolutions
diff --git a/test/test_torch.py b/test/test_torch.py
index b2e9183..403faba 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10683,11 +10683,8 @@
def test_unfold_all_devices_and_dtypes(self, device):
for dt in torch.testing.get_all_dtypes():
- if dt == torch.bfloat16:
- self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device))
- continue
- if dt == torch.half and device == 'cpu':
+ if dt in {torch.half, torch.bfloat16} and device == 'cpu':
# fix once random is implemented for Half on CPU
self.assertRaises(RuntimeError, lambda: torch.randint(5, (0, 1, 3, 0), dtype=dt, device=device))
else:
@@ -10777,17 +10774,14 @@
self.assertEqual(shape, torch.empty_like(torch.zeros(shape, device=device, dtype=dt)).shape)
self.assertEqual(shape, torch.empty_strided(shape, (0,) * len(shape), device=device, dtype=dt).shape)
- if dt == torch.half and device == "cpu":
+ if dt in {torch.half, torch.bfloat16} and device == "cpu":
# update once random is implemented for half on CPU
self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt).shape)
else:
- if dt == torch.bfloat16:
- self.assertRaises(RuntimeError, lambda: torch.randint(6, shape, device=device, dtype=dt))
- continue # Remove once random is supported for bfloat16 on cuda
self.assertEqual(shape, torch.randint(6, shape, device=device, dtype=dt).shape)
self.assertEqual(shape, torch.randint_like(torch.zeros(shape, device=device, dtype=dt), 6).shape)
- if dt != torch.double and dt != torch.float and dt != torch.half:
+ if dt not in {torch.double, torch.float, torch.half, torch.bfloat16}:
self.assertRaises(RuntimeError, lambda: torch.rand(shape, device=device, dtype=dt).shape)
if dt == torch.double or dt == torch.float: