add BFloat16 support for MaxPool2d on CPU (#56903)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56903

Adds BFloat16 support for MaxPool2d forward and backward on CPU. The contiguous kernel now accumulates the running max in float and casts back to BFloat16 on store; the channels-last path gains a dedicated BFloat16 specialization that keeps per-channel max values in a float buffer and indices in an int32_t buffer, converting both back after the reduction.

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D28836791

Pulled By: VitalyFedyunin

fbshipit-source-id: e03d55cc30dfa3628f096938fbad34b1031948af
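
A minimal usage sketch of what this enables (mirroring the new test in test_nn.py; the shapes and pooling parameters below are illustrative):

    import torch

    x = torch.randn(1, 19, 20, 10).bfloat16().requires_grad_()
    pool = torch.nn.MaxPool2d(kernel_size=8, stride=2, return_indices=True)

    out, ind = pool(x)        # out: torch.bfloat16, ind: torch.int64
    out.sum().backward()      # backward also runs in BFloat16 on CPU

    # channels-last input exercises the vectorized kernel
    x_cl = x.detach().to(memory_format=torch.channels_last).requires_grad_()
    out_cl, ind_cl = pool(x_cl)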
diff --git a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
index 77be098..e81601b 100644
--- a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
+++ b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp
@@ -11,7 +11,7 @@
 
 namespace {
 
-template <typename scalar_t>
+template <typename scalar_t, typename accscalar_t>
 void cpu_max_pool(
     const Tensor& output_,
     const Tensor& indices_,
@@ -57,11 +57,11 @@
 
       // compute local max
       int64_t maxindex = ih0 * input_width + iw0;
-      scalar_t maxval = -std::numeric_limits<scalar_t>::infinity();
+      accscalar_t maxval = -std::numeric_limits<accscalar_t>::infinity();
       for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
         for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
           int64_t index = ih * input_width + iw;
-          scalar_t val = input_ptr[index];
+          accscalar_t val = accscalar_t(input_ptr[index]);
           if ((val > maxval) || std::isnan(val)) {
             maxval = val;
             maxindex = index;
@@ -70,7 +70,7 @@
       }
 
       // set output to local max and store location of max
-      output_data[i] = maxval;
+      output_data[i] = scalar_t(maxval);
       indices_data[i] = maxindex;
 
       // move on to next output index
@@ -119,7 +119,7 @@
   // for the convenience of vectorization, use an integer of the same size as scalar_t,
   //   e.g. int32_t for float, int64_t for double
   // need to make sure it doesn't overflow
-  TORCH_CHECK(input_height <= std::ceil((double)std::numeric_limits<integer_t>::max() / (double)input_width));
+  TORCH_CHECK(input_height * input_width <= std::numeric_limits<integer_t>::max());
 
   // parallel on dim N, H, W
   at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
@@ -207,6 +207,150 @@
   }
 }
 
+template <>
+void cpu_max_pool_channels_last<BFloat16>(
+    const Tensor& output_,
+    const Tensor& indices_,
+    const Tensor& input_,
+    int kW, int kH,
+    int dW, int dH,
+    int padW, int padH,
+    int dilationW, int dilationH) {
+  TORCH_CHECK(input_.ndimension() == 4,
+              "max pooling with channels last format supports tensors with 4 dims");
+  auto memory_format = at::MemoryFormat::ChannelsLast;
+  auto input = input_.contiguous(memory_format);
+  auto output = output_.contiguous(memory_format);
+  auto indices = indices_.contiguous(memory_format);
+
+  auto input_data = input.data_ptr<BFloat16>();
+  auto output_data = output.data_ptr<BFloat16>();
+  auto indices_data = indices.data_ptr<int64_t>();
+
+  int64_t nbatch = input.size(0);
+  int64_t channels = input.size(1);
+  int64_t input_height = input.size(2);
+  int64_t input_width = input.size(3);
+  int64_t output_height = output.size(2);
+  int64_t output_width = output.size(3);
+
+  using bVec = vec::Vectorized<BFloat16>;
+  using fVec = vec::Vectorized<float>;
+  using iVec = vec::Vectorized<int32_t>;
+  // for the convenience of vectorization, use int32_t instead of int64_t
+  TORCH_CHECK(input_height * input_width <= std::numeric_limits<int32_t>::max());
+
+  // parallel on dim N, H, W
+  at::parallel_for(0, nbatch * output_height * output_width, 0, [&](int64_t begin, int64_t end) {
+    int64_t n = 0;
+    int64_t oh = 0;
+    int64_t ow = 0;
+    data_index_init(begin, n, nbatch, oh, output_height, ow, output_width);
+
+    int64_t size = channels;
+    int64_t len = size - (size % bVec::size());
+    // temp buffer holding indices as int32_t
+    std::unique_ptr<int32_t []> index_buffer(new int32_t[len]);
+    // temp buffer holding max values as float
+    std::unique_ptr<float []> max_arr(new float[size]);
+    float* max = max_arr.get();
+
+    for (const auto i : c10::irange(begin, end)) {
+      int64_t ih0 = oh * dH - padH;
+      int64_t iw0 = ow * dW - padW;
+      int64_t ih1 = std::min(ih0 + (kH - 1) * dilationH + 1, input_height);
+      int64_t iw1 = std::min(iw0 + (kW - 1) * dilationW + 1, input_width);
+      while(ih0 < 0) { ih0 += dilationH; }
+      while(iw0 < 0) { iw0 += dilationW; }
+
+      BFloat16* out = output_data + i * channels;
+      int64_t* ind = indices_data + i * channels;
+
+      // Pass I: initialize max value and index buffers
+      iVec index0_ivec = iVec(ih0 * input_width + iw0);
+      fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
+      int64_t d1 = 0;
+      for (; d1 < len; d1 += fVec::size()) {
+        index0_ivec.store(index_buffer.get() + d1);
+        max_fvec.store(max + d1);
+      }
+      for (; d1 < size; d1++) {
+        ind[d1] = ih0 * input_width + iw0;
+        max[d1] = -std::numeric_limits<float>::infinity();
+      }
+      // Pass II: compute local max
+      for (int64_t ih = ih0; ih < ih1; ih += dilationH) {
+        for (int64_t iw = iw0; iw < iw1; iw += dilationW) {
+          BFloat16* in = input_data + n * input_height * input_width * channels +
+              ih * input_width * channels + iw * channels;
+
+          int64_t d2 = 0;
+          for (; d2 < len; d2 += bVec::size()) {
+            iVec index_ivec = iVec(ih * input_width + iw);
+            bVec val_bvec = bVec::loadu(in + d2);
+            fVec val_fvec0, val_fvec1;
+            std::tie(val_fvec0, val_fvec1) = convert_bfloat16_float(val_bvec);
+
+            iVec maxindex_ivec0 = iVec::loadu(index_buffer.get() + d2);
+            iVec maxindex_ivec1 = iVec::loadu(index_buffer.get() + d2 + iVec::size());
+            fVec maxval_fvec0 = fVec::loadu(max + d2);
+            fVec maxval_fvec1 = fVec::loadu(max + d2 + fVec::size());
+
+            // true = all ones, false = all zeros
+            fVec mask0 = (val_fvec0 > maxval_fvec0) | val_fvec0.isnan();
+            fVec mask1 = (val_fvec1 > maxval_fvec1) | val_fvec1.isnan();
+            iVec imask0 = vec::cast<int32_t>(mask0);
+            iVec imask1 = vec::cast<int32_t>(mask1);
+
+            fVec max_fvec0 = fVec::blendv(maxval_fvec0, val_fvec0, mask0);
+            fVec max_fvec1 = fVec::blendv(maxval_fvec1, val_fvec1, mask1);
+            iVec ind_ivec0 = iVec::blendv(maxindex_ivec0, index_ivec, imask0);
+            iVec ind_ivec1 = iVec::blendv(maxindex_ivec1, index_ivec, imask1);
+
+            max_fvec0.store(max + d2);
+            max_fvec1.store(max + d2 + fVec::size());
+            ind_ivec0.store(index_buffer.get() + d2);
+            ind_ivec1.store(index_buffer.get() + d2 + iVec::size());
+          }
+          for (; d2 < size; d2++) {
+            int64_t index = ih * input_width + iw;
+            float val = float(in[d2]);
+            int64_t maxindex = ind[d2];
+            float maxval = max[d2];
+
+            bool mask = (val > maxval) || std::isnan(val);
+            max[d2] = mask ? val : maxval;
+            ind[d2] = mask ? index : maxindex;
+          }
+        }
+      }
+      // Pass III: convert max values from float to bfloat16
+      int64_t d3 = 0;
+      for (; d3 < len; d3 += bVec::size()) {
+        fVec max_fvec0 = fVec::loadu(max + d3);
+        fVec max_fvec1 = fVec::loadu(max + d3 + fVec::size());
+        bVec max_bvec = convert_float_bfloat16(max_fvec0, max_fvec1);
+        max_bvec.store(out + d3);
+      }
+      for (; d3 < size; d3++) {
+        out[d3] = BFloat16(max[d3]);
+      }
+      // convert indices from int32_t to int64_t
+      vec::convert<int32_t, int64_t>(index_buffer.get(), ind, len);
+
+      // move on to next output index
+      data_index_step(n, nbatch, oh, output_height, ow, output_width);
+    }
+  });
+
+  if (!output_.is_contiguous(memory_format)) {
+    output_.copy_(output);
+  }
+  if (!indices_.is_contiguous(memory_format)) {
+    indices_.copy_(indices);
+  }
+}
+
 template <typename scalar_t>
 void cpu_max_pool_backward(
     const Tensor& grad_input_,
@@ -315,13 +459,17 @@
     int dilationW, int dilationH) {
   switch (input.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool2d", [&] {
-        cpu_max_pool<scalar_t>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d", [&] {
+        if (input.scalar_type() == ScalarType::BFloat16) {
+          cpu_max_pool<BFloat16, /*accscalar_t*/float>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        } else {
+          cpu_max_pool<scalar_t, scalar_t>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
+        }
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "max_pool2d_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, input.scalar_type(), "max_pool2d_channels_last", [&] {
         cpu_max_pool_channels_last<scalar_t>(output, indices, input, kW, kH, dW, dH, padW, padH, dilationW, dilationH);
       });
       break;
@@ -337,13 +485,13 @@
     const Tensor& indices) {
   switch (grad_output.suggest_memory_format()) {
     case at::MemoryFormat::Contiguous: {
-      AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_pool2d_backward", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward", [&] {
         cpu_max_pool_backward<scalar_t>(grad_input, grad_output, indices);
       });
       break;
     }
     case at::MemoryFormat::ChannelsLast: {
-      AT_DISPATCH_FLOATING_TYPES(grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
+      AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::BFloat16, grad_output.scalar_type(), "max_pool2d_backward_channels_last", [&] {
         cpu_max_pool_backward_channels_last<scalar_t>(grad_input, grad_output, indices);
       });
       break;
diff --git a/test/test_nn.py b/test/test_nn.py
index 115a6fb..deead0e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -15030,6 +15030,32 @@
         helper(10, 512, 31, 31, 3, stride=2)
         helper(1, 129, 8, 8, 3, stride=2)
 
+    @onlyCPU
+    def test_max_pool2d_bfloat16(self, device):
+        def helper(n, c, h, w, kernel_size, stride, memory_format):
+            input = torch.randn(n, c, h, w, dtype=torch.float32, device=device).bfloat16()
+            input = input.to(memory_format=memory_format).requires_grad_()
+            pool = torch.nn.MaxPool2d(kernel_size, stride, return_indices=True).to(device)
+
+            input2 = input.detach().clone().float().requires_grad_(True)
+
+            out, ind = pool(input)
+            out.sum().backward()
+            out2, ind2 = pool(input2)
+            out2.sum().backward()
+
+            self.assertTrue(out.is_contiguous(memory_format=memory_format))
+            self.assertEqual(out.dtype, torch.bfloat16)
+            self.assertEqual(input.grad.dtype, torch.bfloat16)
+            self.assertEqual(out, out2.bfloat16())
+            self.assertEqual(ind, ind2)
+            self.assertEqual(input.grad, input2.grad.bfloat16())
+
+        helper(4, 30, 8, 8, 7, 1, torch.contiguous_format)
+        helper(4, 65, 8, 8, 7, 1, torch.channels_last)
+        helper(1, 19, 20, 10, 8, 2, torch.contiguous_format)
+        helper(1, 19, 20, 10, 8, 2, torch.channels_last)
+
     @onlyCUDA
     def test_max_pool2d_indices(self, device):
         def helper(n, c, h, w, ks):
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index b78ac9a..17656b8 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11185,6 +11185,7 @@
            # TODO: add shape checks
            assert_jit_shape_analysis=False,
            dtypes=floating_types(),
+           dtypesIfCPU=floating_types_and(torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            skips=(
                # Pre-existing condition; Needs to be fixed
@@ -11203,6 +11204,7 @@
            check_batched_forward_grad=False,
            assert_jit_shape_analysis=True,
            dtypes=floating_types(),
+           dtypesIfCPU=floating_types_and(torch.bfloat16),
            dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_max_pool),
     OpInfo('nn.functional.max_pool3d',