[PyTorch Edge] Use Integer Subtraction (Instead of Float) in Non-FBGEMM Dequantization (#67115)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67115
This matches what FBGEMM does (https://fburl.com/code/vjrdn6tj → https://fburl.com/code/btkdn24l)
Benchmark results for the mobile vision transformer model (as described in D31066997, with the config from rebasing onto v4 of D31869106):
This diff (v18):
- NET latency: 109.866
- https://our.intern.facebook.com/intern/aibench/details/536304563225483
This diff before using vsubl (v14, rebased onto v22 of D31205883, the previous diff in this stack):
- NET latency: 115.887
- https://our.intern.facebook.com/intern/aibench/details/906978557243297
Before this diff (v22 of D31205883):
- NET latency: 116.449
- https://our.intern.facebook.com/intern/aibench/details/870678436773989
ghstack-source-id: 142166375
Test Plan: Phabricator tests pass; quantized_test passes on a Pixel 3a, and the mobile vision transformer model (as described in D31066997) runs correctly.
Reviewed By: kimishpatel
Differential Revision: D31483135
fbshipit-source-id: fbef00cad6087b49900d21c3dd3b6fd432f64e94
diff --git a/aten/src/ATen/native/quantized/affine_quantizer_base.cpp b/aten/src/ATen/native/quantized/affine_quantizer_base.cpp
index bba09ae..dc58f60 100644
--- a/aten/src/ATen/native/quantized/affine_quantizer_base.cpp
+++ b/aten/src/ATen/native/quantized/affine_quantizer_base.cpp
@@ -153,9 +153,7 @@
template <typename T>
TORCH_API float dequantize_val(double scale, int64_t zero_point, T value) {
- // We need to convert the qint8 value to float to ensure the subtraction
- // subexpression returns a float
- return (static_cast<float>(value.val_) - zero_point) * scale;
+ return static_cast<float>(scale) * (value.val_ - static_cast<int32_t>(zero_point));
}
#endif // USE_FBGEMM
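The scalar change above keeps the arithmetic in integer space until a single int-to-float conversion. A minimal sketch of the two formulations (hypothetical function names, for illustration only):

    #include <cstdint>

    // Old formulation: widen the quantized value to float first, so the
    // subtraction happens in floating point.
    float dequant_float_sub(double scale, int64_t zero_point, int8_t q) {
      return (static_cast<float>(q) - zero_point) * scale;
    }

    // New formulation (matching FBGEMM): subtract in int32, then do one
    // int -> float conversion before scaling.
    float dequant_int_sub(double scale, int64_t zero_point, int8_t q) {
      return static_cast<float>(scale) * (q - static_cast<int32_t>(zero_point));
    }

Both compute scale * (q - zero_point); doing the subtraction in integer space is what lets the vectorized kernels below fold the widen and the zero-point subtraction into a single vsubl instruction.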
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index 8ee1db6..30e8e15 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -2816,26 +2816,30 @@
const int8_t* in_underlying = reinterpret_cast<const int8_t*>(in);
const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
- const float32x4_t minus_scale_times_zero_point_fp32x4 =
- vdupq_n_f32(-scale * zero_point);
+ // Zero point is restricted to the range of a signed 8-bit integer
+ const int8x8_t zero_point_s8x8 = vget_low_s8(vdupq_n_s8(static_cast<int8_t>(zero_point)));
int i;
for (i = 0; i + 16 <= N; i += 16) {
const int8x16_t vin_s8 = vld1q_s8(in_underlying);
- const int16x8_t vin_low_s16 = vmovl_s8(vget_low_s8(vin_s8)); // 0 ... 7
- const int16x8_t vin_high_s16 = VMOVL_HIGH_S8(vin_s8); // 8 ... 15
+ // Extract the lower or upper half to int16x8 and subtract the zero point.
+ // Each input element and the zero point are restricted to the range of
+ // a signed 8-bit integer, so the difference fits in a signed 16-bit
+ // integer
+ const int16x8_t minus_zp_low_s16 = vsubl_s8(vget_low_s8(vin_s8), zero_point_s8x8); // 0 ... 7
+ const int16x8_t minus_zp_high_s16 = vsubl_s8(vget_high_s8(vin_s8), zero_point_s8x8); // 8 ... 15
- const int32x4_t vin_s32_low_low = vmovl_s16(vget_low_s16(vin_low_s16)); // 0 ... 3
- const int32x4_t vin_s32_low_high = VMOVL_HIGH_S16(vin_low_s16); // 4 ... 7
- const int32x4_t vin_s32_high_low = vmovl_s16(vget_low_s16(vin_high_s16)); // 8 ... 11
- const int32x4_t vin_s32_high_high = VMOVL_HIGH_S16(vin_high_s16); // 12 ... 15
+ const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
+ const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
+ const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
+ const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
- // Store (... * scale) + (-scale * zero point)) int32 -> fp32
- vst1q_f32(out, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_s32(vin_s32_low_low), scale_fp32x4));
- vst1q_f32(out + 4, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_s32(vin_s32_low_high), scale_fp32x4));
- vst1q_f32(out + 8, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_s32(vin_s32_high_low), scale_fp32x4));
- vst1q_f32(out + 12, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_s32(vin_s32_high_high), scale_fp32x4));
+ // Multiply by scale and store (int32 -> fp32)
+ vst1q_f32(out, vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
+ vst1q_f32(out + 4, vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
+ vst1q_f32(out + 8, vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
+ vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
out += 16;
in += 16;
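For readability, the signed loop above amounts to the following standalone sketch (hypothetical function name, assuming an AArch64 target with <arm_neon.h>): vsubl_s8 widens to 16 bits and subtracts the zero point in one instruction, replacing the previous vmovl widen plus the fused multiply-add against a precomputed -scale * zero_point.

    #include <arm_neon.h>
    #include <cstdint>

    // Dequantize 8 int8 values: out[i] = scale * (in[i] - zero_point)
    void dequantize8_s8(const int8_t* in, float* out, float scale,
                        int8_t zero_point) {
      const float32x4_t vscale = vdupq_n_f32(scale);
      const int8x8_t vzp = vdup_n_s8(zero_point);
      // Widening subtract: int8 -> int16 and zero-point removal in one step
      const int16x8_t diff = vsubl_s8(vld1_s8(in), vzp);
      // Widen to int32, convert to fp32, scale, and store
      const int32x4_t lo = vmovl_s16(vget_low_s16(diff));
      const int32x4_t hi = vmovl_s16(vget_high_s16(diff));
      vst1q_f32(out, vmulq_f32(vcvtq_f32_s32(lo), vscale));
      vst1q_f32(out + 4, vmulq_f32(vcvtq_f32_s32(hi), vscale));
    }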
@@ -2857,26 +2861,30 @@
const uint8_t* in_underlying = reinterpret_cast<const uint8_t*>(in);
const float32x4_t scale_fp32x4 = vdupq_n_f32(scale);
- const float32x4_t minus_scale_times_zero_point_fp32x4 =
- vdupq_n_f32(-scale * zero_point);
+ // Zero point is restricted to the range of an unsigned 8-bit integer
+ const uint8x8_t zero_point_u8x8 = vget_low_u8(vdupq_n_u8(static_cast<uint8_t>(zero_point)));
int i;
for (i = 0; i + 16 <= N; i += 16) {
const uint8x16_t vin_u8 = vld1q_u8(in_underlying);
- const uint16x8_t vin_low_u16 = vmovl_u8(vget_low_u8(vin_u8)); // 0 ... 7
- const uint16x8_t vin_high_u16 = VMOVL_HIGH_U8(vin_u8); // 8 ... 15
+ // Extract the lower or upper half to uint16x8 and subtract the zero point.
+ // Each input element and the zero point are restricted to the range of
+ // an unsigned 8-bit integer, so the difference fits in a signed 16-bit
+ // integer
+ const int16x8_t minus_zp_low_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(vin_u8), zero_point_u8x8)); // 0 ... 7
+ const int16x8_t minus_zp_high_s16 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(vin_u8), zero_point_u8x8)); // 8 ... 15
- const uint32x4_t vin_u32_low_low = vmovl_u16(vget_low_u16(vin_low_u16)); // 0 ... 3
- const uint32x4_t vin_u32_low_high = VMOVL_HIGH_U16(vin_low_u16); // 4 ... 7
- const uint32x4_t vin_u32_high_low = vmovl_u16(vget_low_u16(vin_high_u16)); // 8 ... 11
- const uint32x4_t vin_u32_high_high = VMOVL_HIGH_U16(vin_high_u16); // 12 ... 15
+ const int32x4_t minus_zp_low_low = vmovl_s16(vget_low_s16(minus_zp_low_s16)); // 0 ... 3
+ const int32x4_t minus_zp_low_high = VMOVL_HIGH_S16(minus_zp_low_s16); // 4 ... 7
+ const int32x4_t minus_zp_high_low = vmovl_s16(vget_low_s16(minus_zp_high_s16)); // 8 ... 11
+ const int32x4_t minus_zp_high_high = VMOVL_HIGH_S16(minus_zp_high_s16); // 12 ... 15
- // Store (... * scale) + (-scale * zero point)) uint32 -> fp32
- vst1q_f32(out, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_u32(vin_u32_low_low), scale_fp32x4));
- vst1q_f32(out + 4, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_u32(vin_u32_low_high), scale_fp32x4));
- vst1q_f32(out + 8, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_u32(vin_u32_high_low), scale_fp32x4));
- vst1q_f32(out + 12, vmlaq_f32(minus_scale_times_zero_point_fp32x4, vcvtq_f32_u32(vin_u32_high_high), scale_fp32x4));
+ // Multiply by scale and store (int32 -> fp32)
+ vst1q_f32(out, vmulq_f32(vcvtq_f32_s32(minus_zp_low_low), scale_fp32x4));
+ vst1q_f32(out + 4, vmulq_f32(vcvtq_f32_s32(minus_zp_low_high), scale_fp32x4));
+ vst1q_f32(out + 8, vmulq_f32(vcvtq_f32_s32(minus_zp_high_low), scale_fp32x4));
+ vst1q_f32(out + 12, vmulq_f32(vcvtq_f32_s32(minus_zp_high_high), scale_fp32x4));
out += 16;
in += 16;
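The unsigned path depends on a wraparound property that is easy to miss: vsubl_u8 returns a uint16x8_t, and when in < zero_point the widened difference wraps modulo 2^16. Reinterpreting those bits as int16x8_t (vreinterpretq_s16_u16) recovers the true signed difference, because |in - zero_point| <= 255 always fits in a signed 16-bit integer. A scalar sketch of the same identity (illustration only; the int16_t cast is two's-complement on all platforms this code targets):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int in = 0; in <= 255; ++in) {
        for (int zp = 0; zp <= 255; ++zp) {
          // Unsigned 16-bit subtraction wraps modulo 2^16 when in < zp ...
          const uint16_t wrapped = static_cast<uint16_t>(in - zp);
          // ... but reading the same bits as signed recovers the difference,
          // because |in - zp| <= 255 < 32768.
          assert(static_cast<int16_t>(wrapped) == in - zp);
        }
      }
      return 0;
    }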