add BFloat16 support for MaxPool2d on CPU (#56903)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56903

This PR enables BFloat16 for MaxPool2d on CPU:
- cpu_max_pool gains an accscalar_t template parameter so BFloat16 inputs accumulate the running max in float before the result is cast back to BFloat16.
- cpu_max_pool_channels_last gets a BFloat16 specialization that keeps the channels-last path vectorized: each Vectorized<BFloat16> load is converted to two Vectorized<float> lanes, max values and indices are tracked in temporary float/int32_t buffers, and results are converted back to BFloat16 (and int64_t indices) once per output location.
- The index-overflow check on the contiguous path is simplified to a direct comparison of input_height * input_width against std::numeric_limits<integer_t>::max().
- Forward and backward dispatch move from AT_DISPATCH_FLOATING_TYPES to AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, ...).
- test_max_pool2d_bfloat16 compares BFloat16 outputs, indices, and input gradients against a float32 reference for both contiguous and channels-last inputs, and the max_pool OpInfo entries add torch.bfloat16 to the CPU dtypes.
Test Plan: Imported from OSS
Reviewed By: mikaylagawarecki
Differential Revision: D28836791
Pulled By: VitalyFedyunin
fbshipit-source-id: e03d55cc30dfa3628f096938fbad34b1031948af
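
For reference, a minimal Python sketch of how the new CPU path can be exercised, mirroring the added test_max_pool2d_bfloat16 test (the shapes, kernel size, and the use of torch.testing.assert_close here are illustrative choices, not part of the patch):

    import torch

    # BFloat16 max pooling on CPU, checked against a float32 reference.
    x_bf16 = torch.randn(4, 65, 8, 8).bfloat16()
    x_bf16 = x_bf16.to(memory_format=torch.channels_last).requires_grad_()
    x_fp32 = x_bf16.detach().clone().float().requires_grad_(True)

    pool = torch.nn.MaxPool2d(kernel_size=7, stride=1, return_indices=True)

    out_bf16, ind_bf16 = pool(x_bf16)
    out_bf16.sum().backward()
    out_fp32, ind_fp32 = pool(x_fp32)
    out_fp32.sum().backward()

    assert out_bf16.dtype == torch.bfloat16
    assert out_bf16.is_contiguous(memory_format=torch.channels_last)
    assert torch.equal(ind_bf16, ind_fp32)  # indices must match the fp32 path exactly
    torch.testing.assert_close(out_bf16, out_fp32.bfloat16())
    torch.testing.assert_close(x_bf16.grad, x_fp32.grad.bfloat16())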
diff --git a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
index 77be098..e81601b 100644
--- a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
+++ b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
@@ -11,7 +11,7 @@
namespace {
-template <typename scalar_t>
+template <typename scalar_t, typename accscalar_t>
void cpu_max_pool(
const Tensor& output_,
const Tensor& indices_,
@@ -57,11 +57,11 @@
// compute local max
int64_t maxindex = ih0 * input_width + iw0;
- scalar_t maxval = -std::numeric_limits<scalar_t>::infinity();
+ accscalar_t maxval = -std::numeric_limits<accscalar_t>::infinity();
for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
int64_t index = ih * input_width + iw;
- scalar_t val = input_ptr[index];
+ accscalar_t val = accscalar_t(input_ptr[index]);
if ((val > maxval) || std::isnan(val)) {
maxval = val;
maxindex = index;
@@ -70,7 +70,7 @@
}
// set output to local max and store location of max
- output_data[i] = maxval;
+ output_data[i] = scalar_t(maxval);
indices_data[i] = maxindex;
// move on to next output index
@@ -119,7 +119,7 @@
// for the convenience of vectorization, use an integer type of the same size as scalar_t,
// e.g. int32_t for float, int64_t for double
// need to make sure the flattened index doesn't overflow
- TORCH_CHECK(input_height <= std::ceil((double)std::numeric_limits<integer_t>::max() / (double)input_width));
+ TORCH_CHECK(input_height * input_width <= std::numeric_limits<integer_t>::max());
// parallel on dim N, H, W
at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
@@ -207,6 +207,150 @@
}
}
+template <>
+void cpu_max_pool_channels_last<BFloat16>(
+ const Tensor& output_,
+ const Tensor& indices_,
+ const Tensor& input_,
+ int kW, int kH,
+ int dW, int dH,
+ int padW, int padH,
+ int dilationW, int dilationH) {
+ TORCH_CHECK(input_.ndimension() == 4,
+ "max pooling with channels last format supports tensors with 4 dims");
+ auto memory_format = at::MemoryFormat::ChannelsLast;
+ auto input = input_.contiguous(memory_format);
+ auto output = output_.contiguous(memory_format);
+ auto indices = indices_.contiguous(memory_format);
+
+ auto input_data = input.data_ptr<BFloat16>();
+ auto output_data = output.data_ptr<BFloat16>();
+ auto indices_data = indices.data_ptr<int64_t>();
+
+ int64_t nbatch = input.size(0);
+ int64_t channels = input.size(1);
+ int64_t input_height = input.size(2);
+ int64_t input_width = input.size(3);
+ int64_t output_height = output.size(2);
+ int64_t output_width = output.size(3);
+
+ using bVec = vec::Vectorized<BFloat16>;
+ using fVec = vec::Vectorized<float>;
+ using iVec = vec::Vectorized<int32_t>;
+ // for the convenience of vectorization, use int32_t instead of int64_t
+ TORCH_CHECK(input_height * input_width <= std::numeric_limits<int32_t>::max());
+
+ // parallel on dim N, H, W
+ at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
+ int64_t n = 0;
+ int64_t oh = 0;
+ int64_t ow = 0;
+ data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);
+
+ int64_t size = channels;
+ int64_t len = size - (size % bVec::size());
+ // temp buffer holding indices as int32_t (vectorized lanes only; the scalar tail writes into ind directly)
+ std::unique_ptr<int32_t []> index_buffer(new int32_t[len]);
+ // temp buffer holding max value with float
+ std::unique_ptr<float []> max_arr(new float[size]);
+ float* max = max_arr.get();
+
+ for (const auto i : c10::irange(begin, end)) {
+ int64_t ih0 = oh * dH - padH;
+ int64_t iw0 = ow * dW - padW;
+ int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
+ int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
+ while(ih0 < 0) { ih0 += dilationH; }
+ while(iw0 < 0) { iw0 += dilationW; }
+
+ BFloat16* out = output_data + i * channels;
+ int64_t* ind = indices_data + i * channels;
+
+ // Pass I: initialize the running max values and indices
+ iVec index0_ivec = iVec(ih0 * input_width + iw0);
+ fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
+ int64_t d1 = 0;
+ for (; d1 < len; d1 += fVec::size()) {
+ index0_ivec.store(index_buffer.get() + d1);
+ max_fvec.store(max + d1);
+ }
+ for (; d1 < size; d1++) {
+ ind[d1] = ih0 * input_width + iw0;
+ max[d1] = -std::numeric_limits<float>::infinity();
+ }
+ // Pass II: compute local max
+ for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
+ for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
+ BFloat16* in = input_data + n * input_height * input_width * channels +
+ ih * input_width * channels + iw * channels;
+
+ int64_t d2 = 0;
+ for (; d2 < len; d2 += bVec::size()) {
+ iVec index_ivec = iVec(ih * input_width + iw);
+ bVec val_bvec = bVec::loadu(in + d2);
+ fVec val_fvec0, val_fvec1;
+ std::tie(val_fvec0, val_fvec1) = convert_bfloat16_float(val_bvec);
+
+ iVec maxindex_ivec0 = iVec::loadu(index_buffer.get() + d2);
+ iVec maxindex_ivec1 = iVec::loadu(index_buffer.get() + d2 + iVec::size());
+ fVec maxval_fvec0 = fVec::loadu(max + d2);
+ fVec maxval_fvec1 = fVec::loadu(max + d2 + fVec::size());
+
+ // true = all ones, false = all zeros
+ fVec mask0 = (val_fvec0 > maxval_fvec0) | val_fvec0.isnan();
+ fVec mask1 = (val_fvec1 > maxval_fvec1) | val_fvec1.isnan();
+ iVec imask0 = vec::cast<int32_t>(mask0);
+ iVec imask1 = vec::cast<int32_t>(mask1);
+
+ fVec max_fvec0 = fVec::blendv(maxval_fvec0, val_fvec0, mask0);
+ fVec max_fvec1 = fVec::blendv(maxval_fvec1, val_fvec1, mask1);
+ iVec ind_ivec0 = iVec::blendv(maxindex_ivec0, index_ivec, imask0);
+ iVec ind_ivec1 = iVec::blendv(maxindex_ivec1, index_ivec, imask1);
+
+ max_fvec0.store(max + d2);
+ max_fvec1.store(max + d2 + fVec::size());
+ ind_ivec0.store(index_buffer.get() + d2);
+ ind_ivec1.store(index_buffer.get() + d2 + iVec::size());
+ }
+ for (; d2 < size; d2++) {
+ int64_t index = ih * input_width + iw;
+ float val = float(in[d2]);
+ int64_t maxindex = ind[d2];
+ float maxval = max[d2];
+
+ bool mask = (val > maxval) || std::isnan(val);
+ max[d2] = mask ? val : maxval;
+ ind[d2] = mask ? index : maxindex;
+ }
+ }
+ }
+ // Pass III: convert max values from float to bfloat16
+ int64_t d3 = 0;
+ for (; d3 < len; d3 += bVec::size()) {
+ fVec max_fvec0 = fVec::loadu(max + d3);
+ fVec max_fvec1 = fVec::loadu(max + d3 + fVec::size());
+ bVec max_bvec = convert_float_bfloat16(max_fvec0, max_fvec1);
+ max_bvec.store(out + d3);
+ }
+ for (; d3 < size; d3++) {
+ out[d3] = BFloat16(max[d3]);
+ }
+ // convert the vectorized lanes' indices from int32_t back to int64_t
+ vec::convert<int32_t, int64_t>(index_buffer.get(), ind, len);
+
+ // move on to next output index
+ data_index_step(n, nbatch, oh, output_height, ow, output_width);
+ }
+ });
+
+ if (!output_.is_contiguous(memory_format)) {
+ output_.copy_(output);
+ }
+ if (!indices_.is_contiguous(memory_format)) {
+ indices_.copy_(indices);
+ }
+}
+
template <typename scalar_t>
void cpu_max_pool_backward(
const Tensor& grad_input_,
@@ -315,13 +459,17 @@
int dilationW, int dilationH) {
switch (input.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
- AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool2d", [&] {
- cpu_max_pool<scalar_t>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+ AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d", [&] {
+ if (input.scalar_type() == ScalarType::BFloat16) {
+ cpu_max_pool<BFloat16, /*accscalar_t*/float>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+ } else {
+ cpu_max_pool<scalar_t, scalar_t>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+ }
});
break;
}
case at::MemoryFormat::ChannelsLast: {
- AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool2d_channels_last", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d_channels_last", [&] {
cpu_max_pool_channels_last<scalar_t>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
});
break;
@@ -337,13 +485,13 @@
const Tensor& indices) {
switch (grad_output.suggest_memory_format()) {
case at::MemoryFormat::Contiguous: {
- AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_pool2d_backward", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward", [&] {
cpu_max_pool_backward<scalar_t>(grad_input, grad_output, indices);
});
break;
}
case at::MemoryFormat::ChannelsLast: {
- AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
+ AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
cpu_max_pool_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
});
break;
diff --git a/test/test_nn.py b/test/test_nn.py
index 115a6fb..deead0e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -15030,6 +15030,32 @@
helper(10, 512, 31, 31, 3, stride=2)
helper(1, 129, 8, 8, 3, stride=2)
+ @onlyCPU
+ def test_max_pool2d_bfloat16(self, device):
+ def helper(n, c, h, w, kernel_size, stride, memory_format):
+ input = torch.randn(n, c, h, w, dtype=torch.float32, device=device).bfloat16()
+ input = input.to(memory_format=memory_format).requires_grad_()
+ pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(device)
+
+ input2 = input.detach().clone().float().requires_grad_(True)
+
+ out, ind = pool(input)
+ out.sum().backward()
+ out2, ind2 = pool(input2)
+ out2.sum().backward()
+
+ self.assertTrue(out.is_contiguous(memory_format=memory_format))
+ self.assertEqual(out.dtype, torch.bfloat16)
+ self.assertEqual(input.grad.dtype, torch.bfloat16)
+ self.assertEqual(out, out2.bfloat16())
+ self.assertEqual(ind, ind2)
+ self.assertEqual(input.grad, input2.grad.bfloat16())
+
+ helper(4, 30, 8, 8, 7, 1, torch.contiguous_format)
+ helper(4, 65, 8, 8, 7, 1, torch.channels_last)
+ helper(1, 19, 20, 10, 8, 2, torch.contiguous_format)
+ helper(1, 19, 20, 10, 8, 2, torch.channels_last)
+
@onlyCUDA
def test_max_pool2d_indices(self, device):
def helper(n, c, h, w, ks):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b78ac9a..17656b8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11185,6 +11185,7 @@
# TODO: add shape checks
assert_jit_shape_analysis=False,
dtypes=floating_types(),
+ dtypesIfCPU=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
skips=(
# Pre-existing condition; Needs to be fixed
@@ -11203,6 +11204,7 @@
check_batched_forward_grad=False,
assert_jit_shape_analysis=True,
dtypes=floating_types(),
+ dtypesIfCPU=floating_types_and(torch.bfloat16),
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_max_pool),
OpInfo('nn.functional.max_pool3d',