Vectorized quantized relu/relu6 (#25496)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25496
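
This change adds AVX2-vectorized implementations of quantized relu and relu6 for qint8, quint8, and qint32: the `Vec256` specializations in `vec256_qint.h` gain `relu`/`relu6` member functions, and the kernels in `qrelu.cpp` switch from `cpu_kernel` to `cpu_kernel_vec` so the vectorized path is used where possible. As a scalar sketch of the semantics the kernels implement (illustrative names, not code from this diff): quantized relu clamps the raw integer value at the zero point, and relu6 additionally clamps at the quantized representation of 6.0.

```
#include <algorithm>

// Scalar reference for the vectorized kernels (a sketch, not the actual
// PyTorch code). `zero_point` is the quantized representation of 0.0 and
// `q_six` the quantized representation of 6.0.
template <typename T>
T qrelu_ref(T v, T zero_point) {
  return std::max(v, zero_point);
}

template <typename T>
T qrelu6_ref(T v, T zero_point, T q_six) {
  return std::min(std::max(v, zero_point), q_six);
}
```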

Benchmark Script

```
import torch, time

sizes = [
    (1, 56, 56, 256),
    (1, 28, 28, 512),
    (1, 14, 14, 1024),
    (1, 7, 7, 2048),
]

NITER = 1000

for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('*****', str(dtype), '*****')

    print('\t*****relu*****')
    print('\tsize',
          'time (float ms)',
          'time (quant ms)',
          'quant / float',
          sep='\t')
    for size in sizes:
        # NHWC
        x = torch.rand(*size)

        # Permute to a non-contiguous layout
        x = x.permute([0, 2, 3, 1])

        # Test float
        s = time.time()
        for i in range(NITER):
            torch.relu(x)
        time_per_iter_float = (time.time() - s) / NITER

        # Test quantized
        q_x = torch.quantize_linear(x, 0.5, 1, dtype)

        s = time.time()
        for i in range(NITER):
            torch.relu(q_x)
        time_per_iter_quant = (time.time() - s) / NITER

        print('\t',
              size,
              time_per_iter_float * 1000,
              time_per_iter_quant * 1000,
              time_per_iter_quant / time_per_iter_float,
              sep='\t')

    print('\t*****relu6*****')
    print('\tsize',
          'time (float ms)',
          'time (quant ms)',
          'quant / float',
          sep='\t')
    for size in sizes:
        # NHWC
        x = torch.rand(*size)

        # Permute to a non-contiguous layout
        x = x.permute([0, 2, 3, 1])

        # Test float relu6
        s = time.time()
        for i in range(NITER):
            torch._C._nn.hardtanh(x, 0., 6.)
        time_per_iter_float_6 = (time.time() - s) / NITER

        # Test quantized relu6
        q_x = torch.quantize_linear(x, 0.5, 1, dtype)

        s = time.time()
        for i in range(NITER):
            torch.ops.quantized.relu6(q_x)
        time_per_iter_quant_6 = (time.time() - s) / NITER

        print('\t',
              size,
              time_per_iter_float_6 * 1000,
              time_per_iter_quant_6 * 1000,
              time_per_iter_quant_6 / time_per_iter_float_6,
              sep='\t')

```

Before this change (AVX2)

```
$ OMP_NUM_THREADS=1 python relu_bench.py
***** torch.qint8 *****
	*****relu*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	0.28845906257629395	0.32473158836364746	1.1257458353479874
		(1, 28, 28, 512)	0.12658190727233887	0.1621997356414795	1.2813816692816096
		(1, 14, 14, 1024)	0.060466766357421875	0.08151435852050781	1.3480852943031985
		(1, 7, 7, 2048)	0.021933555603027344	0.04172706604003906	1.9024305404582809
	*****relu6*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	1.0264298915863037	0.4686436653137207	0.45657640054641424
		(1, 28, 28, 512)	0.4577608108520508	0.23253798484802246	0.5079901541051298
		(1, 14, 14, 1024)	0.22967290878295898	0.11695981025695801	0.509245129853278
		(1, 7, 7, 2048)	0.12731575965881348	0.060141801834106445	0.4723830105187069
***** torch.quint8 *****
	*****relu*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	0.28515172004699707	0.32268643379211426	1.1316306762551913
		(1, 28, 28, 512)	0.1268613338470459	0.1618938446044922	1.2761480562681475
		(1, 14, 14, 1024)	0.06022787094116211	0.08164644241333008	1.355625578946535
		(1, 7, 7, 2048)	0.018331527709960938	0.04460000991821289	2.432967433149516
	*****relu6*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	1.027123212814331	0.5206699371337891	0.50692062124382
		(1, 28, 28, 512)	0.4589383602142334	0.25958728790283203	0.565625605542444
		(1, 14, 14, 1024)	0.23261427879333496	0.13058066368103027	0.561361341867771
		(1, 7, 7, 2048)	0.13072657585144043	0.06684517860412598	0.5113358027528374
***** torch.qint32 *****
	*****relu*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	0.285900354385376	0.44794583320617676	1.5667900593168678
		(1, 28, 28, 512)	0.12691712379455566	0.21081137657165527	1.6610160258035915
		(1, 14, 14, 1024)	0.05957603454589844	0.10731720924377441	1.8013486473507283
		(1, 7, 7, 2048)	0.01675701141357422	0.05678510665893555	3.388737123669683
	*****relu6*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	1.0314903259277344	0.6447939872741699	0.6251090980366052
		(1, 28, 28, 512)	0.4572310447692871	0.3106963634490967	0.6795172090859886
		(1, 14, 14, 1024)	0.2294166088104248	0.1586904525756836	0.6917130080447454
		(1, 7, 7, 2048)	0.12760710716247559	0.07992196083068848	0.6263127705647926

```

After this change (AVX2)

```
$ OMP_NUM_THREADS=1 python relu_bench.py

***** torch.qint8 *****
	*****relu*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	0.2889232635498047	0.06460881233215332	0.22361928056034167
		(1, 28, 28, 512)	0.13853216171264648	0.013955354690551758	0.10073729102343015
		(1, 14, 14, 1024)	0.0721442699432373	0.007253408432006836	0.10054032617855548
		(1, 7, 7, 2048)	0.015225648880004883	0.004289150238037109	0.28170557930505313
	*****relu6*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	1.042311191558838	0.06422209739685059	0.061615089540392104
		(1, 28, 28, 512)	0.46384429931640625	0.01335287094116211	0.028787399049295198
		(1, 14, 14, 1024)	0.2301616668701172	0.007760286331176758	0.033716675920477994
		(1, 7, 7, 2048)	0.12573981285095215	0.004631757736206055	0.03683604763827976
***** torch.quint8 *****
	*****relu*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	0.2877991199493408	0.0571134090423584	0.1984488661828141
		(1, 28, 28, 512)	0.12664175033569336	0.013076543807983398	0.10325618347283565
		(1, 14, 14, 1024)	0.06389951705932617	0.005294084548950195	0.08285014961904974
		(1, 7, 7, 2048)	0.016280174255371094	0.003660917282104492	0.22486966199988284
	*****relu6*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	1.0244698524475098	0.05978655815124512	0.05835853344870231
		(1, 28, 28, 512)	0.454937219619751	0.013289213180541992	0.02921109244842504
		(1, 14, 14, 1024)	0.22972846031188965	0.0077877044677734375	0.03389960676705229
		(1, 7, 7, 2048)	0.125657320022583	0.0045795440673828125	0.03644470586003093
***** torch.qint32 *****
	*****relu*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	0.28399205207824707	0.2665698528289795	0.9386525111468004
		(1, 28, 28, 512)	0.12665152549743652	0.12166023254394531	0.9605903447756557
		(1, 14, 14, 1024)	0.0598299503326416	0.059305429458618164	0.9912331387355795
		(1, 7, 7, 2048)	0.014290809631347656	0.012906551361083984	0.9031364698031366
	*****relu6*****
	size	time (float ms)	time (quant ms)	quant / float
		(1, 56, 56, 256)	1.020923376083374	0.27229976654052734	0.2667191024513184
		(1, 28, 28, 512)	0.4564201831817627	0.12390279769897462	0.2714665176181136
		(1, 14, 14, 1024)	0.23244047164916992	0.05935955047607422	0.25537527976482316
		(1, 7, 7, 2048)	0.1271505355834961	0.014976024627685547	0.11778184463762029

```
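
For the (1, 56, 56, 256) qint8 case, quantized relu drops from roughly 0.32 ms to 0.065 ms per iteration (about 5x) and quantized relu6 from roughly 0.47 ms to 0.064 ms (about 7x); qint32 sees smaller gains since each 256-bit vector holds only 8 lanes.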

Test Plan: Imported from OSS

Differential Revision: D17141891

Pulled By: jamesr66a

fbshipit-source-id: 14b8c3330017c518a6b385780a449ca51efef0ce
diff --git a/aten/src/ATen/cpu/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec256/vec256_qint.h
index b02c94c..b64c611 100644
--- a/aten/src/ATen/cpu/vec256/vec256_qint.h
+++ b/aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -128,6 +128,48 @@
         return retval;
     }
 
+    Vec256<c10::qint8> relu(Vec256<c10::qint8> zero_point) {
+#ifdef __AVX2__
+      return _mm256_max_epi8(vals, zero_point.vals);
+#else
+      // Pray the compiler can autovectorize this
+      int8_t int_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+      int8_t zero_point_vals[size()];
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+      int8_t result_vals[size()];
+      for (size_t i = 0; i < size(); ++i) {
+        result_vals[i] = std::max<int8_t>(int_vals[i], zero_point_vals[i]);
+      }
+      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+    }
+
+    Vec256<c10::qint8> relu6(
+        Vec256<c10::qint8> zero_point,
+        Vec256<c10::qint8> q_six) {
+#ifdef __AVX2__
+      return _mm256_min_epi8(
+          _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
+#else
+      // Pray the compiler can autovectorize this
+      int8_t int_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+      int8_t zero_point_vals[size()];
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+      int8_t q_six_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
+      int8_t result_vals[size()];
+      for (size_t i = 0; i < size(); ++i) {
+        result_vals[i] = std::min<int8_t>(
+            std::max<int8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
+      }
+      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+    }
+
     void dump() const {
         for (size_t i = 0; i < size(); ++i) {
             std::cout << (int)((value_type*)&vals)[i] << " ";
@@ -137,6 +179,8 @@
  private:
     Vec256() {}
 
+    Vec256(__m256i vals_) : vals(vals_) {}
+
     // Load from memory constructor
     Vec256(const void* ptr) {
         vals = _mm256_loadu_si256((const __m256i*)ptr);
@@ -228,6 +272,48 @@
         return retval;
     }
 
+    Vec256<c10::quint8> relu(Vec256<c10::quint8> zero_point) {
+#ifdef __AVX2__
+      return _mm256_max_epu8(vals, zero_point.vals);
+#else
+      // Pray the compiler can autovectorize this
+      uint8_t int_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+      uint8_t zero_point_vals[size()];
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+      uint8_t result_vals[size()];
+      for (size_t i = 0; i < size(); ++i) {
+        result_vals[i] = std::max<uint8_t>(int_vals[i], zero_point_vals[i]);
+      }
+      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+    }
+
+    Vec256<c10::quint8> relu6(
+        Vec256<c10::quint8> zero_point,
+        Vec256<c10::quint8> q_six) {
+#ifdef __AVX2__
+      return _mm256_min_epu8(
+          _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
+#else
+      // Pray the compiler can autovectorize this
+      uint8_t int_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+      uint8_t zero_point_vals[size()];
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+      uint8_t q_six_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
+      uint8_t result_vals[size()];
+      for (size_t i = 0; i < size(); ++i) {
+        result_vals[i] = std::min<uint8_t>(
+            std::max<uint8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
+      }
+      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+    }
+
     void dump() const {
         for (size_t i = 0; i < size(); ++i) {
             std::cout << (int)((value_type*)&vals)[i] << " ";
@@ -237,6 +323,8 @@
  private:
     Vec256() {}
 
+    Vec256(__m256i vals_) : vals(vals_) {}
+
     // Load from memory constructor
     Vec256(const void* ptr) {
         vals = _mm256_loadu_si256((const __m256i*)ptr);
@@ -295,6 +383,48 @@
         return retval;
     }
 
+    Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) {
+#ifdef __AVX2__
+      return _mm256_max_epi32(vals, zero_point.vals);
+#else
+      // Pray the compiler can autovectorize this
+      int32_t int_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+      int32_t zero_point_vals[size()];
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+      int32_t result_vals[size()];
+      for (size_t i = 0; i < size(); ++i) {
+        result_vals[i] = std::max<int32_t>(int_vals[i], zero_point_vals[i]);
+      }
+      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+    }
+
+    Vec256<c10::qint32> relu6(
+        Vec256<c10::qint32> zero_point,
+        Vec256<c10::qint32> q_six) {
+#ifdef __AVX2__
+      return _mm256_min_epi32(
+          _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
+#else
+      // Pray the compiler can autovectorize this
+      int32_t int_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+      int32_t zero_point_vals[size()];
+      _mm256_storeu_si256(
+          reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+      int32_t q_six_vals[size()];
+      _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
+      int32_t result_vals[size()];
+      for (size_t i = 0; i < size(); ++i) {
+        result_vals[i] = std::min<int32_t>(
+            std::max<int32_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
+      }
+      return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+    }
+
     void dump() const {
         for (size_t i = 0; i < 8; ++i) {
           std::cout << ((int32_t*)&vals)[i] << " ";
@@ -304,6 +434,8 @@
  private:
     Vec256() {}
 
+    Vec256(__m256i vals_) : vals(vals_) {}
+
     // Load from memory constructor
     Vec256(const void* ptr) {
       vals = _mm256_loadu_si256((const __m256i*)ptr);
@@ -368,6 +500,9 @@
       }
       std::cout << std::endl;
   }
+
+ protected:
+  Vec256QuantizedConverter() {}
 };
 
 template <>
@@ -406,6 +541,28 @@
 
     return Vec256<c10::qint8>::loadu(qvals);
   }
+
+  Vec256<c10::qint8> relu(Vec256<c10::qint8> zero_point) {
+    Vec256<c10::qint8> retval;
+    for (size_t i = 0; i < size(); ++i) {
+      retval.vals[i] = std::max<value_type>(vals[i], zero_point.vals[i]);
+    }
+    return retval;
+  }
+
+  Vec256<c10::qint8> relu6(
+      Vec256<c10::qint8> zero_point,
+      Vec256<c10::qint8> q_six) {
+    Vec256<c10::qint8> retval;
+    for (size_t i = 0; i < size(); ++i) {
+      retval.vals[i] = std::min<value_type>(
+          std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+ private:
+  Vec256() {}
 };
 
 template <>
@@ -444,6 +601,28 @@
 
     return Vec256<c10::quint8>::loadu(qvals);
   }
+
+  Vec256<c10::quint8> relu(Vec256<c10::quint8> zero_point) {
+    Vec256<c10::quint8> retval;
+    for (size_t i = 0; i < size(); ++i) {
+      retval.vals[i] = std::max<value_type>(vals[i], zero_point.vals[i]);
+    }
+    return retval;
+  }
+
+  Vec256<c10::quint8> relu6(
+      Vec256<c10::quint8> zero_point,
+      Vec256<c10::quint8> q_six) {
+    Vec256<c10::quint8> retval;
+    for (size_t i = 0; i < size(); ++i) {
+      retval.vals[i] = std::min<value_type>(
+          std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+ private:
+  Vec256() {}
 };
 
 template <>
@@ -482,6 +661,28 @@
 
     return Vec256<c10::qint32>::loadu(qvals);
   }
+
+  Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) {
+    Vec256<c10::qint32> retval;
+    for (size_t i = 0; i < size(); ++i) {
+      retval.vals[i] = std::max<value_type>(vals[i], zero_point.vals[i]);
+    }
+    return retval;
+  }
+
+  Vec256<c10::qint32> relu6(
+      Vec256<c10::qint32> zero_point,
+      Vec256<c10::qint32> q_six) {
+    Vec256<c10::qint32> retval;
+    for (size_t i = 0; i < size(); ++i) {
+      retval.vals[i] = std::min<value_type>(
+          std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
+    }
+    return retval;
+  }
+
+ private:
+  Vec256() {}
 };
 
 #endif // defined(__AVX__) && !defined(_MSC_VER)
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index 40a7a22..aa0a0c6 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -19,21 +19,30 @@
         qx.q_scale(),
         qx.q_zero_point(),
         qx.suggest_memory_format());
+    using Vec = Vec256<scalar_t>;
     auto iter = TensorIterator::unary_op(qy, qx);
-    cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
-      return scalar_t(std::max<underlying_t>(value.val_, zero_point));
-    });
+    auto zero_point_vec = Vec(scalar_t(zero_point));
+    cpu_kernel_vec(
+        iter,
+        [&](scalar_t value) -> scalar_t {
+          return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+        },
+        [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
   });
   return qy;
 }
-
 Tensor& quantized_relu_(Tensor& qx) {
   const auto zero_point = qx.q_zero_point();
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
+    using Vec = Vec256<scalar_t>;
     auto iter = TensorIterator::unary_op(qx, qx);
-    cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
-      return scalar_t(std::max<underlying_t>(value.val_, zero_point));
-    });
+    auto zero_point_vec = Vec(scalar_t(zero_point));
+    cpu_kernel_vec(
+        iter,
+        [&](scalar_t value) -> scalar_t {
+          return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+        },
+        [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
   });
   return qx;
 }
@@ -49,13 +58,20 @@
         qx.q_scale(),
         qx.q_zero_point(),
         qx.suggest_memory_format());
+    using Vec = Vec256<scalar_t>;
     auto iter = TensorIterator::unary_op(qy, qx);
     scalar_t six = at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(),
                                               6.0);
-    cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
-      underlying_t relu_val = std::max<underlying_t>(value.val_, zero_point);
-      return scalar_t(std::min<underlying_t>(relu_val, six.val_));
-    });
+    auto zero_point_vec = Vec(scalar_t(zero_point));
+    auto six_vec = Vec(six);
+    cpu_kernel_vec(
+        iter,
+        [&](scalar_t value) -> scalar_t {
+          underlying_t relu_val =
+              std::max<underlying_t>(value.val_, zero_point);
+          return scalar_t(std::min<underlying_t>(relu_val, six.val_));
+        },
+        [&](Vec val) { return val.relu6(zero_point_vec, six_vec); });
   });
   return qy;
 }