[pyper][emb][quantization] Support emb trained in FP16 (#60736)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60736

Add support for embeddings with float16 input data type, using the new kernel functions added in fbgemm: https://github.com/pytorch/FBGEMM/pull/616
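
As a minimal sketch of the user-facing change (tensor shapes are illustrative only), the existing prepack ops can now be fed an FP16 weight tensor directly instead of requiring a cast to float32 first:

```python
import torch

# Minimal sketch: pack and unpack an FP16 embedding weight with the 8-bit
# (byte) rowwise prepack op; previously the weight had to be float32.
weights_fp16 = torch.randn(10, 16, dtype=torch.float16)
packed = torch.ops.quantized.embedding_bag_byte_prepack(weights_fp16)
unpacked = torch.ops.quantized.embedding_bag_byte_unpack(packed)  # float32 output

# The 4-bit prepack op accepts FP16 input as well.
packed_4bit = torch.ops.quantized.embedding_bag_4bit_prepack(weights_fp16)
```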

Test Plan: `buck test caffe2/test/:quantization -- test_embedding_bag`

Reviewed By: supriyar

Differential Revision: D29392320

fbshipit-source-id: 0a120b3a58b6cf1d84961831097e9581ffd2b591
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
index ed41b85..edd1e2c 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -1,3 +1,4 @@
+#include <c10/core/ScalarType.h>
 #include <ATen/ATen.h>
 #include <ATen/Parallel.h>
 #include <ATen/native/quantized/cpu/embedding_packed_params.h>
@@ -197,6 +198,10 @@
   // packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
   // assert(packed_weights.size() == torch.Size([2, 10, 11]))
 
+  TORCH_CHECK(
+    weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half,
+    "'embedding_bag_byte_prepack' only support float32 or float16.");
+
   const auto weight_sizes = weight.sizes();
   const auto cols_dim = weight_sizes.size() - 1;
   const int32_t embedding_rows = c10::size_to_dim_(cols_dim, weight_sizes);
@@ -205,7 +210,6 @@
   const int32_t output_columns = embedding_cols + 2 * sizeof(float);
   Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
 
-  const float* weight_data = weight_contig.data_ptr<float>();
   // Adjust output dimensions to account for FP32 scale and zero_points.
   std::vector<int64_t> output_shape = weight_sizes.vec();
   output_shape[cols_dim] = output_columns;
@@ -218,17 +222,34 @@
   auto* output_data = output.data_ptr<uint8_t>();
 
 #ifdef USE_FBGEMM
-
-  at::parallel_for(
-      0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
-        for (int64_t row = start_idx; row < end_idx; ++row) {
-          fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
-            weight_data + row * embedding_cols, 1,
-              embedding_cols, output_data + row * output_columns);
-        }
-      });
+  if (weight_contig.scalar_type() == at::ScalarType::Half) {
+    const auto weight_data = static_cast<fbgemm::float16*>(weight_contig.data_ptr());
+    at::parallel_for(
+        0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+          for (int64_t row = start_idx; row < end_idx; ++row) {
+            fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<fbgemm::float16>(
+              weight_data + row * embedding_cols, 1,
+                embedding_cols, output_data + row * output_columns);
+          }
+        });
+  }
+  else {
+    const auto weight_data = weight_contig.data_ptr<float>();
+    at::parallel_for(
+        0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+          for (int64_t row = start_idx; row < end_idx; ++row) {
+            fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
+              weight_data + row * embedding_cols, 1,
+                embedding_cols, output_data + row * output_columns);
+          }
+        });
+  }
 
 #else
+  const auto float_weight = weight_contig.scalar_type() == at::ScalarType::Half
+    ? weight_contig.to(at::ScalarType::Float)
+    : weight_contig;
+  const auto weight_data = float_weight.data_ptr<float>();
   constexpr float kEpsilon = 1e-8f;
   for (std::size_t row = 0; row < embedding_rows; ++row) {
     const float* input_row = weight_data + row * embedding_cols;
@@ -262,12 +283,15 @@
     const bool optimized_qparams,
     const int64_t nbins,
     const double ratio) {
+  TORCH_CHECK(
+    weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half,
+    "'qembeddingbag_nbit_prepack' only support float32 or float16.");
+
   int64_t embedding_rows = weight.size(0);
   int64_t embedding_cols = weight.size(1);
 
   Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
 
-  const auto weight_data = weight.data_ptr<float>();
   TORCH_CHECK(
       bit_width == 4 || bit_width == 2,
       "bit_width must be either 2 or 4 to use 'qembeddingbag_nbit_prepack'."
@@ -299,18 +323,35 @@
 
 #ifdef USE_FBGEMM
   if (!optimized_qparams) {
-    at::parallel_for(
-      0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
-        for (int64_t row = start_idx; row < end_idx; ++row) {
-          fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
-            bit_width, weight_data + row * embedding_cols, 1,
-            embedding_cols, output_data + row * output_shape[1]);
-        }
-      });
+    if (weight_contig.scalar_type() == at::ScalarType::Half) {
+      const auto weight_data = static_cast<fbgemm::float16*>(weight_contig.data_ptr());
+      at::parallel_for(
+        0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+          for (int64_t row = start_idx; row < end_idx; ++row) {
+            fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<fbgemm::float16>(
+              bit_width, weight_data + row * embedding_cols, 1,
+              embedding_cols, output_data + row * output_shape[1]);
+          }
+        });
+    }
+    else {
+      const auto weight_data = weight_contig.data_ptr<float>();
+      at::parallel_for(
+        0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+          for (int64_t row = start_idx; row < end_idx; ++row) {
+            fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
+              bit_width, weight_data + row * embedding_cols, 1,
+              embedding_cols, output_data + row * output_shape[1]);
+          }
+        });
+    }
   } else {
 #endif // USE_FBGEMM
     const auto output_columns = output.size(output.dim() - 1);
-
+    const auto float_weight = weight_contig.scalar_type() == at::ScalarType::Half
+        ? weight_contig.to(at::ScalarType::Float)
+        : weight_contig;
+    const auto weight_data = float_weight.data_ptr<float>();
     for (int row = 0; row < embedding_rows; ++row) {
       const float* input_row = weight_data + row * embedding_cols;
       std::uint8_t* output_row = output_data + row * output_columns;
@@ -320,7 +361,7 @@
       if (optimized_qparams) {
         at::Tensor xmax_tensor, xmin_tensor;
         std::tie(xmax_tensor, xmin_tensor) = at::choose_qparams_optimized(
-            weight_contig[row], embedding_cols, nbins, ratio, bit_width);
+            float_weight[row], embedding_cols, nbins, ratio, bit_width);
         TORCH_CHECK(
             xmax_tensor.numel() == 1 && xmin_tensor.numel() == 1,
             "Expected choose_qparams_optimized to return min/max tensors of size 1");
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 1aff284..517b23e 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -3142,9 +3142,9 @@
 @unittest.skipIf(sys.platform == "darwin", "Known test failure on Mac.")
 class TestQuantizedEmbeddingOps(TestCase):
     def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams,
-                                      num_batches):
+                                      num_batches, data_type=np.float32):
         weights = torch.from_numpy((np.random.random_sample((
-            num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(np.float32))
+            num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(data_type))
         qtype = torch.quint8
         if bit_rate == 8:
             w_packed = pack_fn(weights)
@@ -3152,7 +3152,9 @@
             w_packed = pack_fn(weights, optimized_qparams=optimized_qparams)
         w_unpacked = unpack_fn(w_packed)
 
-        if bit_rate == 8 or bit_rate == 4:
+        if (bit_rate == 8 or bit_rate == 4) and data_type != np.float16:
+            # torch.quantize_per_channel does not support float16 yet.
+
             obs_weights = weights
             # Combine 3D embeddings (e.g. stacked combination of embeddings)
             # in a dimension orthogonal to channels.
@@ -3180,13 +3182,13 @@
 
         # compare against C2 to ensure numerical equivalency.
         from caffe2.python import core, workspace
-        conversion_op = "FloatToFused8BitRowwiseQuantized"
+        conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == np.float32 else "HalfFloatToFused8BitRowwiseQuantized"
         reverse_conversion_op = None
         if bit_rate == 4:
-            conversion_op = "FloatToFused4BitRowwiseQuantized"
+            conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused4BitRowwiseQuantized"
             reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat"
         elif bit_rate == 2:
-            conversion_op = "FloatToFused2BitRowwiseQuantized"
+            conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused2BitRowwiseQuantized"
             reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat"
 
         def get_c2_weights(weights, engine_str):
@@ -3226,32 +3228,38 @@
     """ Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """
     @given(num_embeddings=st.integers(10, 100),
            embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
-           num_batches=st.integers(1, 5))
-    def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches):
+           num_batches=st.integers(1, 5),
+           data_type=st.sampled_from([np.float32, np.float16]),)
+    def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches, data_type):
         pack_fn = torch.ops.quantized.embedding_bag_byte_prepack
         unpack_fn = torch.ops.quantized.embedding_bag_byte_unpack
 
-        self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches)
+        self._test_embedding_bag_unpack_fn(
+            pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches, data_type=data_type)
 
     """ Tests the correctness of the embedding_bag_4bit pack/unpack op against C2 """
     @given(num_embeddings=st.integers(10, 100),
            embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
-           optimized_qparams=st.booleans(),)
-    def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams):
+           optimized_qparams=st.booleans(),
+           data_type=st.sampled_from([np.float32, np.float16]),)
+    def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
         pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack
         unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack
 
-        self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1)
+        self._test_embedding_bag_unpack_fn(
+            pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type)
 
     """ Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """
     @given(num_embeddings=st.integers(10, 100),
            embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
-           optimized_qparams=st.booleans(),)
-    def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams):
+           optimized_qparams=st.booleans(),
+           data_type=st.sampled_from([np.float32, np.float16]),)
+    def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
         pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
         unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack
 
-        self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1)
+        self._test_embedding_bag_unpack_fn(
+            pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type)
 
 
     def embedding_bag_rowwise_offsets_run(