Vectorized quantized relu/relu6 (#25496)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25496
Benchmark Script
```
import torch, time
sizes = [
(1, 56, 56, 256),
(1, 28, 28, 512),
(1, 14, 14, 1024),
(1, 7, 7, 2048),
]
NITER = 1000
for dtype in [torch.qint8, torch.quint8, torch.qint32]:
print('*****', str(dtype), '*****')
print('\t*****relu*****')
print('\tsize',
'time (float ms)',
'time (quant ms)',
'quant / float',
sep='\t')
for size in sizes:
# NHWC
x = torch.rand(*size)
# NCHW
x = x.permute([0, 2, 3, 1])
# Test float
s = time.time()
for i in range(NITER):
torch.relu(x)
time_per_iter_float = (time.time() - s) / NITER
# Test quantized
q_x = torch.quantize_linear(x, 0.5, 1, dtype)
s = time.time()
for i in range(NITER):
torch.relu(q_x)
time_per_iter_quant = (time.time() - s) / NITER
print('\t',
size,
time_per_iter_float * 1000,
time_per_iter_quant * 1000,
time_per_iter_quant / time_per_iter_float,
sep='\t')
print('\t*****relu6*****')
print('\tsize',
'time (float ms)',
'time (quant ms)',
'quant / float',
sep='\t')
for size in sizes:
# NHWC
x = torch.rand(*size)
# NCHW
x = x.permute([0, 2, 3, 1])
# Test float relu6
s = time.time()
for i in range(NITER):
torch._C._nn.hardtanh(x, 0., 6.)
time_per_iter_float_6 = (time.time() - s) / NITER
# Test quantized relu6
q_x = torch.quantize_linear(x, 0.5, 1, dtype)
s = time.time()
for i in range(NITER):
torch.ops.quantized.relu6(q_x)
time_per_iter_quant_6 = (time.time() - s) / NITER
print('\t',
size,
time_per_iter_float_6 * 1000,
time_per_iter_quant_6 * 1000,
time_per_iter_quant_6 / time_per_iter_float_6,
sep='\t')
```
Before this change (AVX2)
```
$ OMP_NUM_THREADS=1 python relu_bench.py
***** torch.qint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.28845906257629395 0.32473158836364746 1.1257458353479874
(1, 28, 28, 512) 0.12658190727233887 0.1621997356414795 1.2813816692816096
(1, 14, 14, 1024) 0.060466766357421875 0.08151435852050781 1.3480852943031985
(1, 7, 7, 2048) 0.021933555603027344 0.04172706604003906 1.9024305404582809
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.0264298915863037 0.4686436653137207 0.45657640054641424
(1, 28, 28, 512) 0.4577608108520508 0.23253798484802246 0.5079901541051298
(1, 14, 14, 1024) 0.22967290878295898 0.11695981025695801 0.509245129853278
(1, 7, 7, 2048) 0.12731575965881348 0.060141801834106445 0.4723830105187069
***** torch.quint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.28515172004699707 0.32268643379211426 1.1316306762551913
(1, 28, 28, 512) 0.1268613338470459 0.1618938446044922 1.2761480562681475
(1, 14, 14, 1024) 0.06022787094116211 0.08164644241333008 1.355625578946535
(1, 7, 7, 2048) 0.018331527709960938 0.04460000991821289 2.432967433149516
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.027123212814331 0.5206699371337891 0.50692062124382
(1, 28, 28, 512) 0.4589383602142334 0.25958728790283203 0.565625605542444
(1, 14, 14, 1024) 0.23261427879333496 0.13058066368103027 0.561361341867771
(1, 7, 7, 2048) 0.13072657585144043 0.06684517860412598 0.5113358027528374
***** torch.qint32 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.285900354385376 0.44794583320617676 1.5667900593168678
(1, 28, 28, 512) 0.12691712379455566 0.21081137657165527 1.6610160258035915
(1, 14, 14, 1024) 0.05957603454589844 0.10731720924377441 1.8013486473507283
(1, 7, 7, 2048) 0.01675701141357422 0.05678510665893555 3.388737123669683
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.0314903259277344 0.6447939872741699 0.6251090980366052
(1, 28, 28, 512) 0.4572310447692871 0.3106963634490967 0.6795172090859886
(1, 14, 14, 1024) 0.2294166088104248 0.1586904525756836 0.6917130080447454
(1, 7, 7, 2048) 0.12760710716247559 0.07992196083068848 0.6263127705647926
```
After this change (AVX2)
```
$ OMP_NUM_THREADS=1 python relu_bench.py
***** torch.qint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.2889232635498047 0.06460881233215332 0.22361928056034167
(1, 28, 28, 512) 0.13853216171264648 0.013955354690551758 0.10073729102343015
(1, 14, 14, 1024) 0.0721442699432373 0.007253408432006836 0.10054032617855548
(1, 7, 7, 2048) 0.015225648880004883 0.004289150238037109 0.28170557930505313
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.042311191558838 0.06422209739685059 0.061615089540392104
(1, 28, 28, 512) 0.46384429931640625 0.01335287094116211 0.028787399049295198
(1, 14, 14, 1024) 0.2301616668701172 0.007760286331176758 0.033716675920477994
(1, 7, 7, 2048) 0.12573981285095215 0.004631757736206055 0.03683604763827976
***** torch.quint8 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.2877991199493408 0.0571134090423584 0.1984488661828141
(1, 28, 28, 512) 0.12664175033569336 0.013076543807983398 0.10325618347283565
(1, 14, 14, 1024) 0.06389951705932617 0.005294084548950195 0.08285014961904974
(1, 7, 7, 2048) 0.016280174255371094 0.003660917282104492 0.22486966199988284
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.0244698524475098 0.05978655815124512 0.05835853344870231
(1, 28, 28, 512) 0.454937219619751 0.013289213180541992 0.02921109244842504
(1, 14, 14, 1024) 0.22972846031188965 0.0077877044677734375 0.03389960676705229
(1, 7, 7, 2048) 0.125657320022583 0.0045795440673828125 0.03644470586003093
***** torch.qint32 *****
*****relu*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 0.28399205207824707 0.2665698528289795 0.9386525111468004
(1, 28, 28, 512) 0.12665152549743652 0.12166023254394531 0.9605903447756557
(1, 14, 14, 1024) 0.0598299503326416 0.059305429458618164 0.9912331387355795
(1, 7, 7, 2048) 0.014290809631347656 0.012906551361083984 0.9031364698031366
*****relu6*****
size time (float ms) time (quant ms) quant / float
(1, 56, 56, 256) 1.020923376083374 0.27229976654052734 0.2667191024513184
(1, 28, 28, 512) 0.4564201831817627 0.12390279769897462 0.2714665176181136
(1, 14, 14, 1024) 0.23244047164916992 0.05935955047607422 0.25537527976482316
(1, 7, 7, 2048) 0.1271505355834961 0.014976024627685547 0.11778184463762029
```
Test Plan: Imported from OSS
Differential Revision: D17141891
Pulled By: jamesr66a
fbshipit-source-id: 14b8c3330017c518a6b385780a449ca51efef0ce
diff --git a/aten/src/ATen/cpu/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec256/vec256_qint.h
index b02c94c..b64c611 100644
--- a/aten/src/ATen/cpu/vec256/vec256_qint.h
+++ b/aten/src/ATen/cpu/vec256/vec256_qint.h
@@ -128,6 +128,48 @@
return retval;
}
+ Vec256<c10::qint8> relu(Vec256<c10::qint8> zero_point) {
+#ifdef __AVX2__
+ return _mm256_max_epi8(vals, zero_point.vals);
+#else
+ // Pray the compiler can autovectorize this
+ int8_t int_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+ int8_t zero_point_vals[size()];
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+ int8_t result_vals[size()];
+ for (size_t i = 0; i < size(); ++i) {
+ result_vals[i] = std::max<int8_t>(int_vals[i], zero_point_vals[i]);
+ }
+ return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+ }
+
+ Vec256<c10::qint8> relu6(
+ Vec256<c10::qint8> zero_point,
+ Vec256<c10::qint8> q_six) {
+#ifdef __AVX2__
+ return _mm256_min_epi8(
+ _mm256_max_epi8(vals, zero_point.vals), q_six.vals);
+#else
+ // Pray the compiler can autovectorize this
+ int8_t int_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+ int8_t zero_point_vals[size()];
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+ int8_t q_six_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
+ int8_t result_vals[size()];
+ for (size_t i = 0; i < size(); ++i) {
+ result_vals[i] = std::min<int8_t>(
+ std::max<int8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
+ }
+ return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+ }
+
void dump() const {
for (size_t i = 0; i < size(); ++i) {
std::cout << (int)((value_type*)&vals)[i] << " ";
@@ -137,6 +179,8 @@
private:
Vec256() {}
+ Vec256(__m256i vals_) : vals(vals_) {}
+
// Load from memory constructor
Vec256(const void* ptr) {
vals = _mm256_loadu_si256((const __m256i*)ptr);
@@ -228,6 +272,48 @@
return retval;
}
+ Vec256<c10::quint8> relu(Vec256<c10::quint8> zero_point) {
+#ifdef __AVX2__
+ return _mm256_max_epu8(vals, zero_point.vals);
+#else
+ // Pray the compiler can autovectorize this
+ uint8_t int_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+ uint8_t zero_point_vals[size()];
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+ uint8_t result_vals[size()];
+ for (size_t i = 0; i < size(); ++i) {
+ result_vals[i] = std::max<uint8_t>(int_vals[i], zero_point_vals[i]);
+ }
+ return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+ }
+
+ Vec256<c10::quint8> relu6(
+ Vec256<c10::quint8> zero_point,
+ Vec256<c10::quint8> q_six) {
+#ifdef __AVX2__
+ // quint8 lanes are unsigned: use epu8 min/max (epi8 would treat >=128 as negative)
+ return _mm256_min_epu8(
+ _mm256_max_epu8(vals, zero_point.vals), q_six.vals);
+#else
+ // Pray the compiler can autovectorize this
+ uint8_t int_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+ uint8_t zero_point_vals[size()];
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+ uint8_t q_six_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
+ uint8_t result_vals[size()];
+ for (size_t i = 0; i < size(); ++i) {
+ result_vals[i] = std::min<uint8_t>(
+ std::max<uint8_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
+ }
+ return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+ }
+
void dump() const {
for (size_t i = 0; i < size(); ++i) {
std::cout << (int)((value_type*)&vals)[i] << " ";
@@ -237,6 +323,8 @@
private:
Vec256() {}
+ Vec256(__m256i vals_) : vals(vals_) {}
+
// Load from memory constructor
Vec256(const void* ptr) {
vals = _mm256_loadu_si256((const __m256i*)ptr);
@@ -295,6 +383,48 @@
return retval;
}
+ Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) {
+#ifdef __AVX2__
+ return _mm256_max_epi32(vals, zero_point.vals);
+#else
+ // Pray the compiler can autovectorize this
+ int32_t int_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+ int32_t zero_point_vals[size()];
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+ int32_t result_vals[size()];
+ for (size_t i = 0; i < size(); ++i) {
+ result_vals[i] = std::max<int32_t>(int_vals[i], zero_point_vals[i]);
+ }
+ return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+ }
+
+ Vec256<c10::qint32> relu6(
+ Vec256<c10::qint32> zero_point,
+ Vec256<c10::qint32> q_six) {
+#ifdef __AVX2__
+ // qint32 lanes are 32-bit: use epi32 min/max (epi8 would clamp per byte)
+ return _mm256_min_epi32(
+ _mm256_max_epi32(vals, zero_point.vals), q_six.vals);
+#else
+ // Pray the compiler can autovectorize this
+ int32_t int_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&int_vals), vals);
+ int32_t zero_point_vals[size()];
+ _mm256_storeu_si256(
+ reinterpret_cast<__m256i*>(&zero_point_vals), zero_point.vals);
+ int32_t q_six_vals[size()];
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(&q_six_vals), q_six.vals);
+ int32_t result_vals[size()];
+ for (size_t i = 0; i < size(); ++i) {
+ result_vals[i] = std::min<int32_t>(
+ std::max<int32_t>(int_vals[i], zero_point_vals[i]), q_six_vals[i]);
+ }
+ return _mm256_loadu_si256(reinterpret_cast<__m256i*>(&result_vals));
+#endif
+ }
+
void dump() const {
for (size_t i = 0; i < 8; ++i) {
std::cout << ((int32_t*)&vals)[i] << " ";
@@ -304,6 +434,8 @@
private:
Vec256() {}
+ Vec256(__m256i vals_) : vals(vals_) {}
+
// Load from memory constructor
Vec256(const void* ptr) {
vals = _mm256_loadu_si256((const __m256i*)ptr);
@@ -368,6 +500,9 @@
}
std::cout << std::endl;
}
+
+ protected:
+ Vec256QuantizedConverter() {}
};
template <>
@@ -406,6 +541,28 @@
return Vec256<c10::qint8>::loadu(qvals);
}
+
+ Vec256<c10::qint8> relu(Vec256<c10::qint8> zero_point) {
+ Vec256<c10::qint8> retval;
+ for (size_t i = 0; i < size(); ++i) {
+ retval.vals[i] = std::max<value_type>(vals[i], zero_point.vals[i]);
+ }
+ return retval;
+ }
+
+ Vec256<c10::qint8> relu6(
+ Vec256<c10::qint8> zero_point,
+ Vec256<c10::qint8> q_six) {
+ Vec256<c10::qint8> retval;
+ for (size_t i = 0; i < size(); ++i) {
+ retval.vals[i] = std::min<value_type>(
+ std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
+ }
+ return retval;
+ }
+
+ private:
+ Vec256() {}
};
template <>
@@ -444,6 +601,28 @@
return Vec256<c10::quint8>::loadu(qvals);
}
+
+ Vec256<c10::quint8> relu(Vec256<c10::quint8> zero_point) {
+ Vec256<c10::quint8> retval;
+ for (size_t i = 0; i < size(); ++i) {
+ retval.vals[i] = std::max<value_type>(vals[i], zero_point.vals[i]);
+ }
+ return retval;
+ }
+
+ Vec256<c10::quint8> relu6(
+ Vec256<c10::quint8> zero_point,
+ Vec256<c10::quint8> q_six) {
+ Vec256<c10::quint8> retval;
+ for (size_t i = 0; i < size(); ++i) {
+ retval.vals[i] = std::min<value_type>(
+ std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
+ }
+ return retval;
+ }
+
+ private:
+ Vec256() {}
};
template <>
@@ -482,6 +661,28 @@
return Vec256<c10::qint32>::loadu(qvals);
}
+
+ Vec256<c10::qint32> relu(Vec256<c10::qint32> zero_point) {
+ Vec256<c10::qint32> retval;
+ for (size_t i = 0; i < size(); ++i) {
+ retval.vals[i] = std::max<value_type>(vals[i], zero_point.vals[i]);
+ }
+ return retval;
+ }
+
+ Vec256<c10::qint32> relu6(
+ Vec256<c10::qint32> zero_point,
+ Vec256<c10::qint32> q_six) {
+ Vec256<c10::qint32> retval;
+ for (size_t i = 0; i < size(); ++i) {
+ retval.vals[i] = std::min<value_type>(
+ std::max<value_type>(vals[i], zero_point.vals[i]), q_six.vals[i]);
+ }
+ return retval;
+ }
+
+ private:
+ Vec256() {}
};
#endif // defined(__AVX__) && !defined(_MSC_VER)
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index 40a7a22..aa0a0c6 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -19,21 +19,30 @@
qx.q_scale(),
qx.q_zero_point(),
qx.suggest_memory_format());
+ using Vec = Vec256<scalar_t>;
auto iter = TensorIterator::unary_op(qy, qx);
- cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
- return scalar_t(std::max<underlying_t>(value.val_, zero_point));
- });
+ auto zero_point_vec = Vec(scalar_t(zero_point));
+ cpu_kernel_vec(
+ iter,
+ [&](scalar_t value) -> scalar_t {
+ return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+ },
+ [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
});
return qy;
}
-
Tensor& quantized_relu_(Tensor& qx) {
const auto zero_point = qx.q_zero_point();
AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qrelu", [&]() {
+ using Vec = Vec256<scalar_t>;
auto iter = TensorIterator::unary_op(qx, qx);
- cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
- return scalar_t(std::max<underlying_t>(value.val_, zero_point));
- });
+ auto zero_point_vec = Vec(scalar_t(zero_point));
+ cpu_kernel_vec(
+ iter,
+ [&](scalar_t value) -> scalar_t {
+ return scalar_t(std::max<underlying_t>(value.val_, zero_point));
+ },
+ [&](Vec value) -> Vec { return value.relu(zero_point_vec); });
});
return qx;
}
@@ -49,13 +58,20 @@
qx.q_scale(),
qx.q_zero_point(),
qx.suggest_memory_format());
+ using Vec = Vec256<scalar_t>;
auto iter = TensorIterator::unary_op(qy, qx);
scalar_t six = at::quantize_val<scalar_t>(qx.q_scale(), qx.q_zero_point(),
6.0);
- cpu_kernel(iter, [&](scalar_t value) -> scalar_t {
- underlying_t relu_val = std::max<underlying_t>(value.val_, zero_point);
- return scalar_t(std::min<underlying_t>(relu_val, six.val_));
- });
+ auto zero_point_vec = Vec(scalar_t(zero_point));
+ auto six_vec = Vec(six);
+ cpu_kernel_vec(
+ iter,
+ [&](scalar_t value) -> scalar_t {
+ underlying_t relu_val =
+ std::max<underlying_t>(value.val_, zero_point);
+ return scalar_t(std::min<underlying_t>(relu_val, six.val_));
+ },
+ [&](Vec val) { return val.relu6(zero_point_vec, six_vec); });
});
return qy;
}