CUDA Support for Learnable Fake Quantize Per Tensor (GPU) (#41127)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41127

This diff adds GPU support for the learnable fake quantize per tensor operators: CUDA implementations of the backward kernels for the scale and zero point gradients are added and registered, and the summed gradients are moved onto the devices of their corresponding parameters.

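For reference, here is the per-element gradient math the new CUDA kernels compute, written as a minimal PyTorch sketch (the helper below is illustrative only and not part of this diff):

```python
import torch

def _learnable_fq_grads_per_tensor(x, dy, scale, zero_point, quant_min, quant_max):
    """Per-element gradients w.r.t. scale and zero point (reference only)."""
    inv_scale = 1.0 / scale
    xq = torch.clamp(zero_point + torch.round(x * inv_scale), quant_min, quant_max)
    x_fq = (xq - zero_point) * scale
    at_min, at_max = xq == quant_min, xq == quant_max
    # dScale: (q - zero_point) at the clamp boundaries, (x_fq - x) / scale inside.
    dscale = torch.where(
        at_min, dy * (quant_min - zero_point),
        torch.where(at_max, dy * (quant_max - zero_point),
                    dy * (x_fq - x) * inv_scale))
    # dZeroPoint: -scale at the clamp boundaries, zero inside.
    dzero = torch.where(at_min | at_max, -dy * scale, torch.zeros_like(dy))
    return dscale, dzero
```
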
Test Plan: On a devvm, run `buck test //caffe2/test:quantization -- learnable` to test both the forward and backward passes of the learnable per tensor fake quantize kernels. The tests also exercise the `cuda` versions when a GPU is available.
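(Outside of buck, the new tests can presumably also be run directly from a source checkout, e.g. `pytest test/quantization/test_workflow_module.py -k learnable`, assuming a pytest-based setup.)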

Reviewed By: z-a-f

Differential Revision: D22435037

fbshipit-source-id: 515afde13dd224d21fd47fb7cb027ee8d704cbdd
diff --git a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
index 24b0796..7146ace 100644
--- a/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
+++ b/aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -69,8 +69,60 @@
     });
 }
 
+void _fake_quantize_grad_learnable_scale_tensor_kernel_cuda(
+    Tensor& input_grad,
+    const Tensor& input,
+    const Tensor& output_grad,
+    float scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max) {
+  // scalar type of this function is guaranteed to be float
+  float inv_scale = 1.0f / scale;
+  float grad_small = quant_min - zero_point;  // dScale when x quantizes to quant_min
+  float grad_big = quant_max - zero_point;    // dScale when x quantizes to quant_max
+
+  auto iter = TensorIterator::binary_op(input_grad, input, output_grad);
+  gpu_kernel(iter,
+    [=] GPU_LAMBDA (float x, float dx) -> float {
+      int64_t xq = static_cast<int64_t>(zero_point + std::nearbyint(x * inv_scale));
+      xq = std::max(std::min(xq, quant_max), quant_min);
+      if (xq == quant_min) {
+        return dx * grad_small;
+      } else if (xq == quant_max) {
+        return dx * grad_big;
+      }
+      float x_fq = static_cast<float>((xq - zero_point) * scale);
+      return dx * (x_fq - x) * inv_scale;  // interior: d(x_fq)/d(scale) = (x_fq - x) / scale
+    });
+}
+
+void _fake_quantize_grad_learnable_zero_point_tensor_kernel_cuda(
+    Tensor& input_grad,
+    const Tensor& input,
+    const Tensor& output_grad,
+    float scale,
+    int64_t zero_point,
+    int64_t quant_min,
+    int64_t quant_max) {
+  // scalar type of this function is guaranteed to be float
+  float inv_scale = 1.0f / scale;
+  auto iter = TensorIterator::binary_op(input_grad, input, output_grad);
+  gpu_kernel(iter,
+    [=] GPU_LAMBDA (float x, float dx) -> float {
+      int64_t xq = static_cast<int64_t>(zero_point + std::nearbyint(x * inv_scale));
+      xq = std::max(std::min(xq, quant_max), quant_min);
+      if (xq == quant_min || xq == quant_max) {
+        return -dx * scale;  // dZeroPoint is -scale at the clamp boundaries
+      }
+      return 0.0f;  // and zero in the interior
+    });
+}
+
 REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel_cuda);
 REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel_cuda);
+REGISTER_DISPATCH(fake_quant_grad_learnable_scale_tensor_stub, &_fake_quantize_grad_learnable_scale_tensor_kernel_cuda);
+REGISTER_DISPATCH(fake_quant_grad_learnable_zero_point_tensor_stub, &_fake_quantize_grad_learnable_zero_point_tensor_kernel_cuda);
 
 // Fake quantize per channel
 
diff --git a/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp b/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp
index db4b087..5a71657 100644
--- a/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp
+++ b/aten/src/ATen/native/quantized/fake_quant_per_tensor_affine.cpp
@@ -176,8 +176,8 @@
     zero_point.device().type(), dZeroPoint_vec, X, dX, scale_val, zero_point_val, quant_min, quant_max);
 
   // The total sums over the scale and zero point gradient vectors are what will be returned in the end.
-  auto dScale = dScale_vec.sum().unsqueeze(0);
-  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0);
+  auto dScale = dScale_vec.sum().unsqueeze(0).to(scale.device());
+  auto dZeroPoint = dZeroPoint_vec.sum().unsqueeze(0).to(zero_point.device());
 
   return std::make_tuple(dX, dScale, dZeroPoint);
 }
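
Context for the `.to(...)` change above: the gradients returned for `scale` and `zero_point` must live on the same devices as those tensors, which is not guaranteed when the per-element gradient vectors are reduced elsewhere. A minimal standalone illustration (assumes a CUDA-enabled build; not part of this diff):

```python
import torch

if torch.cuda.is_available():
    scale = torch.tensor([0.1], device='cuda')
    dScale_vec = torch.randn(8)  # per-element scale grads, on CPU here
    # Reduce, then move the result onto scale's device before returning it.
    dScale = dScale_vec.sum().unsqueeze(0).to(scale.device)
    assert dScale.device == scale.device
```
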
diff --git a/test/quantization/test_workflow_module.py b/test/quantization/test_workflow_module.py
index 72b8180..bbe3903 100644
--- a/test/quantization/test_workflow_module.py
+++ b/test/quantization/test_workflow_module.py
@@ -589,6 +589,17 @@
         self._test_learnable_forward_per_tensor(
             X, 'cpu', scale_base, zero_point_base)
 
+    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
+                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
+                       qparams=hu.qparams(dtypes=torch.quint8)))
+    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
+    def test_learnable_forward_per_tensor_cuda(self, X):
+        X, (_, _, _) = X
+        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
+        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
+        self._test_learnable_forward_per_tensor(
+            X, 'cuda', scale_base, zero_point_base)
+
     def _test_learnable_backward_per_tensor(self, X, device, scale_base, zero_point_base):
         r"""Tests the backward method with additional backprop support for scale and zero point.
         """
@@ -597,7 +608,7 @@
         for n_bits in (4, 8):
             quant_min, quant_max = 0, 2 ** n_bits - 1
 
-            X = X_base.clone().float()
+            X = X_base.clone().float().to(device)
             X.requires_grad_()
             scale_base = scale_base.to(device)
             zero_point_base = zero_point_base.to(device)
@@ -644,6 +655,18 @@
         self._test_learnable_backward_per_tensor(
             X, 'cpu', scale_base, zero_point_base)
 
+    @given(X=hu.tensor(shapes=hu.array_shapes(1, 5,),
+                       elements=hu.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False),
+                       qparams=hu.qparams(dtypes=torch.quint8)))
+    @unittest.skipIf(not TEST_CUDA, "No GPU is available.")
+    def test_learnable_backward_per_tensor_cuda(self, X):
+        torch.random.manual_seed(NP_RANDOM_SEED)
+        X, (_, _, _) = X
+        scale_base = torch.normal(mean=0, std=1, size=(1,)).clamp(1e-4, 100)
+        zero_point_base = torch.normal(mean=0, std=128, size=(1,))
+        self._test_learnable_backward_per_tensor(
+            X, 'cuda', scale_base, zero_point_base)
+
     @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
            X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                        qparams=hu.qparams(dtypes=torch.quint8)))