Improve batch_norm contiguous case's performance (#34530)

Summary:
For the batch_norm inference contiguous case, we can get better performance by manually vectorizing it.
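
For context, the kernel first folds the normalization into a per-channel affine transform and then vectorizes the inner loop with `Vec256` (see the diff below). A minimal PyTorch sketch of the identity being computed (illustration only, not the new C++ kernel):
```
import torch
import torch.nn as nn

torch.manual_seed(0)

c = 8
m = nn.BatchNorm2d(c).eval()
x = torch.randn(4, c, 16, 16)

with torch.no_grad():
    # Fold normalization into a per-channel scale/shift:
    #   alpha(c) = weight(c) / sqrt(var(c) + eps)
    #   beta(c)  = bias(c) - mean(c) * alpha(c)
    alpha = m.weight / torch.sqrt(m.running_var + m.eps)
    beta = m.bias - m.running_mean * alpha
    # output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
    ref = x * alpha.view(1, c, 1, 1) + beta.view(1, c, 1, 1)
    assert torch.allclose(m(x), ref, atol=1e-6)
```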
Test script:
```
import torch
import torch.nn as nn
import time

torch.manual_seed(0)

for n in [1, 10, 100]:
    for c in [1, 10, 100]:
        for hw in [1, 10, 200]:
            m = nn.BatchNorm2d(c, affine=False)
            m.eval()
            input = torch.randn(n, c, hw, hw)
            # warm up
            for i in range(200):
                output = m(input)
            fwd_t = 0
            for j in range(1000):
                t1 = time.time()
                output = m(input)
                t2 = time.time()
                fwd_t = fwd_t + (t2 - t1)

            fwd_avg = fwd_t / 1000 * 1000
            print("size = (%d, %d, %d, %d); compute time is %.4f(ms)" % (n, c, hw, hw, fwd_avg))
```

Before:
```
size = (1, 1, 1, 1); compute time is 0.0110(ms)
size = (1, 1, 10, 10); compute time is 0.0123(ms)
size = (1, 1, 200, 200); compute time is 0.8166(ms)
size = (1, 10, 1, 1); compute time is 0.0107(ms)
size = (1, 10, 10, 10); compute time is 0.0257(ms)
size = (1, 10, 200, 200); compute time is 8.7533(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.1619(ms)
size = (1, 100, 200, 200); compute time is 123.5674(ms)
size = (10, 1, 1, 1); compute time is 0.0109(ms)
size = (10, 1, 10, 10); compute time is 0.0123(ms)
size = (10, 1, 200, 200); compute time is 0.5629(ms)
size = (10, 10, 1, 1); compute time is 0.0107(ms)
size = (10, 10, 10, 10); compute time is 0.0253(ms)
size = (10, 10, 200, 200); compute time is 8.7817(ms)
size = (10, 100, 1, 1); compute time is 0.0120(ms)
size = (10, 100, 10, 10); compute time is 0.1655(ms)
size = (10, 100, 200, 200); compute time is 123.2488(ms)
size = (100, 1, 1, 1); compute time is 0.0109(ms)
size = (100, 1, 10, 10); compute time is 0.0123(ms)
size = (100, 1, 200, 200); compute time is 0.5740(ms)
size = (100, 10, 1, 1); compute time is 0.0108(ms)
size = (100, 10, 10, 10); compute time is 0.0257(ms)
size = (100, 10, 200, 200); compute time is 8.7201(ms)
size = (100, 100, 1, 1); compute time is 0.0122(ms)
size = (100, 100, 10, 10); compute time is 0.1628(ms)
size = (100, 100, 200, 200); compute time is 123.1739(ms)
```
After:
```
size = (1, 1, 1, 1); compute time is 0.0105(ms)
size = (1, 1, 10, 10); compute time is 0.0114(ms)
size = (1, 1, 200, 200); compute time is 0.5771(ms)
size = (1, 10, 1, 1); compute time is 0.0105(ms)
size = (1, 10, 10, 10); compute time is 0.0160(ms)
size = (1, 10, 200, 200); compute time is 6.9851(ms)
size = (1, 100, 1, 1); compute time is 0.0122(ms)
size = (1, 100, 10, 10); compute time is 0.0848(ms)
size = (1, 100, 200, 200); compute time is 98.6758(ms)
size = (10, 1, 1, 1); compute time is 0.0105(ms)
size = (10, 1, 10, 10); compute time is 0.0115(ms)
size = (10, 1, 200, 200); compute time is 0.2690(ms)
size = (10, 10, 1, 1); compute time is 0.0105(ms)
size = (10, 10, 10, 10); compute time is 0.0159(ms)
size = (10, 10, 200, 200); compute time is 6.6946(ms)
size = (10, 100, 1, 1); compute time is 0.0123(ms)
size = (10, 100, 10, 10); compute time is 0.0854(ms)
size = (10, 100, 200, 200); compute time is 98.7327(ms)
size = (100, 1, 1, 1); compute time is 0.0107(ms)
size = (100, 1, 10, 10); compute time is 0.0116(ms)
size = (100, 1, 200, 200); compute time is 0.2681(ms)
size = (100, 10, 1, 1); compute time is 0.0104(ms)
size = (100, 10, 10, 10); compute time is 0.0159(ms)
size = (100, 10, 200, 200); compute time is 6.7507(ms)
size = (100, 100, 1, 1); compute time is 0.0124(ms)
size = (100, 100, 10, 10); compute time is 0.0852(ms)
size = (100, 100, 200, 200); compute time is 98.6866(ms)
```
For a real model, ResNeXt101, we also get a **~20%** performance improvement for large batch sizes.
Test script:
```
import torch
import torchvision
import time

torch.manual_seed(0)
# torch.set_num_threads(1)

model = torchvision.models.resnext101_32x8d().eval()

for batch_size in [1, 64]:
    input = torch.randn(batch_size, 3, 224, 224)
    # warm up
    with torch.no_grad():
        for i in range(5):
            output = model(input)

        fwd_t = 0
        for i in range(10):
            t1 = time.time()
            output = model(input)
            t2 = time.time()
            fwd_t = fwd_t + (t2 - t1)

        time_fwd_avg = fwd_t / 10 * 1000
        print("Throughput of resnext101 with batch_size = %d is %10.2f (imgs/s)" % (batch_size, batch_size * 1000 / time_fwd_avg))
```
Before:
```
Throughput of resnext101 with batch_size = 1 is       7.89 (imgs/s)
Throughput of resnext101 with batch_size = 64 is      13.02 (imgs/s)

num_threads = 1
Throughput of resnext101 with batch_size = 1 is       2.97 (imgs/s)
Throughput of resnext101 with batch_size = 64 is       2.75 (imgs/s)
```
After:
```
Throughput of resnext101 with batch_size = 1 is       8.95 (imgs/s)
Throughput of resnext101 with batch_size = 64 is      15.52 (imgs/s)

num_threads = 1
Throughput of resnext101 with batch_size = 1 is       3.10 (imgs/s)
Throughput of resnext101 with batch_size = 64 is       2.88 (imgs/s)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34530

Differential Revision: D20479560

Pulled By: ngimel

fbshipit-source-id: 2e788ebcd814556116c90553ec61159eeffb3c16
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index bbed67b..bce9e9d 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -8,6 +8,7 @@
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
+#include <ATen/native/batch_norm.h>
 
 #include <vector>
 
@@ -15,6 +16,8 @@
 
 namespace at { namespace native {
 
+DEFINE_DISPATCH(batch_norm_cpu_inference_contiguous_stub);
+
 namespace {
   void check_dims_match_num_input_features(const char* arg_name, int64_t expected, int64_t actual){
     TORCH_CHECK(actual == expected,
@@ -87,59 +90,6 @@
   }
 }
 
-/// A fast path for CPU inference when all tensors are contiguous.
-/// This code achieves machine bandwidth peak without AVX support.
-/// If this changes for future architectures, we can move it to the cpu/
-/// directory.
-template<typename scalar_t>
-void batch_norm_cpu_inference_contiguous(Tensor& output, const Tensor& input,
-    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
-    const Tensor& mean, const Tensor& variance, double eps) {
-
-  int64_t n_batch = input.size(0);
-  int64_t n_channel = input.size(1);
-  int64_t image_size = input.numel() / n_batch / n_channel;
-
-  scalar_t* output_data = output.data_ptr<scalar_t>();
-  const scalar_t* input_data = input.data_ptr<scalar_t>();
-
-  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  scalar_t* alpha_data = alpha.data_ptr<scalar_t>();
-  scalar_t* beta_data = beta.data_ptr<scalar_t>();
-
-  batch_norm_cpu_inference_collect_linear_and_constant_terms<scalar_t>(
-      alpha_data, beta_data, n_channel, weight, bias, mean, variance, eps);
-
-  // Apply the linear terms to the input,
-  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
-  // No need to use parallel_for as this function is supposed to be
-  // memory-limited.
-  // Keep the loop struture simple to make sure compiler vectorization kicks in.
-  if (image_size != 1) {
-    for (int64_t n = 0; n < n_batch; ++n) {
-      for (int64_t c = 0; c < n_channel; ++c) {
-        for (int64_t i = 0; i < image_size; ++i) {
-          // Keep all the offset calculation within the inner loop for
-          // simplicity. Compilers are very good at hoisting the common part
-          // outside.
-          int64_t offset = n * n_channel * image_size + c * image_size + i;
-          output_data[offset] = input_data[offset] * alpha_data[c] +
-              beta_data[c];
-        }
-      }
-    }
-  } else {
-    // image_size == 1
-    for (int64_t n = 0; n < n_batch; ++n) {
-      for (int64_t c = 0; c < n_channel; ++c) {
-        int64_t offset = n * n_channel + c;
-        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
-      }
-    }
-  }
-}
-
 /// A fast path for CPU inference when all tensors are channels last contiguous.
 /// This code achieves machine bandwidth peak without AVX support.
 /// If this changes for future architectures, we can move it to the cpu/
@@ -207,8 +157,8 @@
       && running_var.is_contiguous()) {
 
     Tensor output = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-    batch_norm_cpu_inference_contiguous<scalar_t>(
-      output, input, weight, bias, running_mean, running_var, eps);
+    batch_norm_cpu_inference_contiguous_stub(kCPU, output, input, weight,
+        bias, running_mean, running_var, eps);
     return std::make_tuple(output, save_mean, save_invstd);
   }
 
diff --git a/aten/src/ATen/native/batch_norm.h b/aten/src/ATen/native/batch_norm.h
new file mode 100644
index 0000000..9ff05b7
--- /dev/null
+++ b/aten/src/ATen/native/batch_norm.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+
+namespace native {
+
+using batch_norm_fn = void (*)(Tensor&, const Tensor&, const Tensor&,
+    const Tensor&, const Tensor&, const Tensor&, double);
+
+DECLARE_DISPATCH(batch_norm_fn, batch_norm_cpu_inference_contiguous_stub);
+
+} // namespace native
+
+} // namespace at
+
diff --git a/aten/src/ATen/native/cpu/batch_norm_kernel.cpp b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp
new file mode 100644
index 0000000..ad720db
--- /dev/null
+++ b/aten/src/ATen/native/cpu/batch_norm_kernel.cpp
@@ -0,0 +1,114 @@
+#include <ATen/native/batch_norm.h>
+
+#include <ATen/ATen.h>
+#include <ATen/CPUApplyUtils.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cpu/Loops.h>
+
+namespace at { namespace native {
+namespace {
+
+using namespace vec256;
+
+template<typename scalar_t>
+void batch_norm_cpu_inference_collect_linear_and_constant_terms(
+    TensorAccessor<scalar_t, 1> alpha, TensorAccessor<scalar_t, 1> beta, int64_t n_channel,
+    const Tensor& weight /* optional */, const Tensor& bias /* optional */,
+    const Tensor& mean, const Tensor& variance, double eps) {
+
+  const scalar_t* weight_data = weight.defined() ? weight.data_ptr<scalar_t>() : nullptr;
+  const scalar_t* bias_data = bias.defined() ? bias.data_ptr<scalar_t>() : nullptr;
+  auto mean_data = mean.accessor<scalar_t, 1>();
+  auto var_data = variance.accessor<scalar_t, 1>();
+
+  /// Collect the linear and constant terms regarding the input.
+  /// output(n, c, h, w)
+  ///     = (input(n, c, h, w) - mean(c)) / sqrt(var(c) + eps) * weight(c)
+  ///         + bias(c)
+  ///     = input(n, c, h, w) * inv_var(c) * weight(c)
+  ///         - mean(c) * inv_var(c) * weight(c) + bias(c),
+  /// where inv_var(c) = 1 / sqrt(var(c) + eps).
+  /// So the linear term, alpha(c) = inv_var(c) * weight(c),
+  ///   the constant term beta(c) = bias(c) - mean(c) * inv_var(c) * weight(c)
+  /// Note that this is only a good idea if (input_size >> c), in degenerate
+  /// cases where image_size == 1 && batch_size == 1, it is slow.
+  for (int64_t c = 0; c < n_channel; c++) {
+    scalar_t inv_var = 1 / std::sqrt(var_data[c] + static_cast<scalar_t>(eps));
+    scalar_t weight_v = weight_data ? weight_data[c] : 1;
+    scalar_t bias_v = bias_data ? bias_data[c] : 0;
+    alpha[c] = inv_var * weight_v;
+    beta[c] = bias_v - mean_data[c] * alpha[c];
+  }
+}
+
+/// A fast path for CPU inference when all tensors are contiguous.
+template<typename scalar_t>
+void batch_norm_cpu_inference_contiguous_impl(Tensor& output,
+    const Tensor& input, const Tensor& weight, const Tensor& bias,
+    const Tensor& mean, const Tensor& variance, double eps) {
+
+  using Vec = Vec256<scalar_t>;
+  int64_t n_batch = input.size(0);
+  int64_t n_channel = input.size(1);
+  int64_t image_size = input.numel() / n_batch / n_channel;
+
+  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  auto alpha_data = alpha.accessor<scalar_t, 1>();
+  auto beta_data = beta.accessor<scalar_t, 1>();
+
+  batch_norm_cpu_inference_collect_linear_and_constant_terms<scalar_t>(
+     alpha_data, beta_data, n_channel, weight, bias, mean, variance, eps);
+
+  scalar_t* output_data = output.data_ptr<scalar_t>();
+  const scalar_t* input_data = input.data_ptr<scalar_t>();
+
+  // Apply the linear terms to the input,
+  // output(n, c, h, w) = input(n, c, h, w) * alpha(c) + beta(c)
+  // No need to use parallel_for as this function is supposed to be
+  // memory-limited.
+  if (image_size != 1) {
+    const int64_t n_offset = n_channel * image_size;
+    const int64_t loop_size = image_size - (image_size % Vec::size());
+    for (int64_t n = 0; n < n_batch; n++) { 
+      for (int64_t c = 0; c < n_channel; c++) {
+        const Vec alpha_vec(alpha_data[c]);
+        const Vec beta_vec(beta_data[c]);
+        int64_t offset = n * n_offset + c * image_size;
+        int64_t d = 0;
+        for (; d < loop_size; d += Vec::size()) {
+          Vec data_vec = Vec::loadu(input_data + offset + d);
+          Vec output_vec = data_vec * alpha_vec + beta_vec;
+          output_vec.store(output_data + offset + d);
+        }
+        if (image_size - d > 0) {
+          Vec data_vec = Vec::loadu(input_data + offset + d, image_size - d);
+          Vec output_vec = data_vec * alpha_vec + beta_vec;
+          output_vec.store(output_data + offset + d, image_size - d);
+        }
+      }
+    }
+  } else {
+    // image_size == 1
+    for (int64_t n = 0; n < n_batch; ++n) {
+      for (int64_t c = 0; c < n_channel; ++c) {
+        int64_t offset = n * n_channel + c;
+        output_data[offset] = input_data[offset] * alpha_data[c] + beta_data[c];
+      }
+    }
+  }
+}
+
+void batch_norm_cpu_inference_contiguous_kernel(Tensor& output, const Tensor& input,
+    const Tensor& weight, const Tensor& bias, const Tensor& mean, const Tensor& variance, double eps) {
+  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "batch_norm_cpu_inference_contiguous", [&] {
+    batch_norm_cpu_inference_contiguous_impl<scalar_t>(output, input, weight, bias, mean, variance, eps);
+  });
+}
+
+}// anonymous namespace
+
+REGISTER_DISPATCH(batch_norm_cpu_inference_contiguous_stub, &batch_norm_cpu_inference_contiguous_kernel);
+
+}} // namespace at::native