Fix embedding quantization issue when memory format is not `contiguous` (#82605)

Summary:
The current implementation of embedding quantization assumes that the input `Tensor` has a `contiguous` memory layout.

To guarantee that, we convert the input `weight` to `contiguous` format via

```
  const auto weight_contig =
      weight.expect_contiguous(weight.suggest_memory_format());
```
or
```
  Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
```

However, the `USE_FBGEMM` branch still reads from `weight` instead of `weight_contig`, which produces wrong results when the input data is not `contiguous`.

Example: N2297477
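
The sketch below (not part of this PR) illustrates the failure mode, mirroring the split-based check added to the test: splitting a stacked weight tensor along an inner dimension yields non-contiguous views, and packing such a view should give the same result as packing a contiguous copy of the same data.

```
import torch

# Two stacked embedding tables; splitting along dim 1 yields views whose
# memory layout is no longer contiguous.
weights = torch.rand(2, 10, 16) + 1
view = torch.split(weights, 1, dim=1)[0]
assert not view.is_contiguous()

# With the fix, packing the non-contiguous view matches packing a contiguous
# copy of the same data; previously the FBGEMM path could read the wrong memory.
packed_view = torch.ops.quantized.embedding_bag_byte_prepack(view)
packed_copy = torch.ops.quantized.embedding_bag_byte_prepack(view.contiguous())
torch.testing.assert_close(packed_view, packed_copy)
```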

Test Plan:
```
buck1 test mode/dev //caffe2/test:quantization -- --exact 'caffe2/test:quantization - test_embedding_bag_byte_unpack (quantization.core.test_quantized_op.TestQuantizedEmbeddingOps)'

buck1 test mode/dev //caffe2/test:quantization -- --exact 'caffe2/test:quantization - test_embedding_bag_2bit_unpack (quantization.core.test_quantized_op.TestQuantizedEmbeddingOps)'

buck1 test mode/dev //caffe2/test:quantization -- --exact 'caffe2/test:quantization - test_embedding_bag_4bit_unpack (quantization.core.test_quantized_op.TestQuantizedEmbeddingOps)'
```

Differential Revision: D38302116

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82605
Approved by: https://github.com/houseroad
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
index b32fcf0..748e89f 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -139,7 +139,8 @@
 //
 // Python example examining a packed 8bit zero_point and scale:
 //
-// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]], dtype=np.float32))
+// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
+// dtype=np.float32))
 // >> x_packed = torch.ops.quantized.embedding_bag_byte_prepack(x)
 //
 // # Pull out and examine packed scales, zero_points and values
@@ -228,8 +229,9 @@
   auto* output_data = output.data_ptr<uint8_t>();
 
 #ifdef USE_FBGEMM
-  if (weight.scalar_type() == at::ScalarType::Half) {
-    const auto weight_data = static_cast<fbgemm::float16*>(weight.data_ptr());
+  if (weight_contig->scalar_type() == at::ScalarType::Half) {
+    const auto weight_data =
+        static_cast<fbgemm::float16*>(weight_contig->data_ptr());
     at::parallel_for(
         0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
           fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
@@ -240,7 +242,7 @@
               output_data + start_idx * output_columns);
         });
   } else {
-    const auto weight_data = weight.data_ptr<float>();
+    const auto weight_data = weight_contig->data_ptr<float>();
     at::parallel_for(
         0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
           fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
@@ -346,8 +348,9 @@
 
 #ifdef USE_FBGEMM
   if (!optimized_qparams) {
-    if (weight.scalar_type() == at::ScalarType::Half) {
-      const auto weight_data = static_cast<fbgemm::float16*>(weight.data_ptr());
+    if (weight_contig.scalar_type() == at::ScalarType::Half) {
+      const auto weight_data =
+          static_cast<fbgemm::float16*>(weight_contig.data_ptr());
       at::parallel_for(
           0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
             fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
@@ -359,7 +362,7 @@
                 output_data + start_idx * output_shape[1]);
           });
     } else {
-      const auto weight_data = weight.data_ptr<float>();
+      const auto weight_data = weight_contig.data_ptr<float>();
       at::parallel_for(
           0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
             fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 41b735a..23cbffe 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -3785,10 +3785,10 @@
 
 @unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
 class TestQuantizedEmbeddingOps(TestCase):
-    def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams,
-                                      num_batches, data_type=np.float32):
-        weights = torch.from_numpy((np.random.random_sample((
-            num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(data_type))
+
+    def _test_embedding_bag_unpack_impl(self, pack_fn, unpack_fn, bit_rate, optimized_qparams, weights):
+        data_type = weights.dtype
+
         qtype = torch.quint8
         if bit_rate == 8:
             w_packed = pack_fn(weights)
@@ -3796,13 +3796,13 @@
             w_packed = pack_fn(weights, optimized_qparams=optimized_qparams)
         w_unpacked = unpack_fn(w_packed)
 
-        if (bit_rate == 8 or bit_rate == 4) and data_type != np.float16:
+        if (bit_rate == 8 or bit_rate == 4) and data_type != torch.float16:
             # torch.quantize_per_channel does not support float16 yet.
 
             obs_weights = weights
             # Combine 3D embeddings (e.g. stacked combination of embeddings)
             # in a dimension orthogonal to channels.
-            if(num_batches > 1):
+            if (len(obs_weights.shape) > 2):
                 stacked_shape = list(weights.size())
                 stacked_shape[1] *= stacked_shape[0]
                 obs_weights = weights.reshape(stacked_shape[1:])
@@ -3826,13 +3826,13 @@
 
         # compare against C2 to ensure numerical equivalency.
         from caffe2.python import core, workspace
-        conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == np.float32 else "HalfFloatToFused8BitRowwiseQuantized"
+        conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == torch.float32 else "HalfFloatToFused8BitRowwiseQuantized"
         reverse_conversion_op = None
         if bit_rate == 4:
-            conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused4BitRowwiseQuantized"
+            conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused4BitRowwiseQuantized"
             reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat"
         elif bit_rate == 2:
-            conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused2BitRowwiseQuantized"
+            conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused2BitRowwiseQuantized"
             reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat"
 
         def get_c2_weights(weights, engine_str):
@@ -3862,13 +3862,35 @@
             engine = "GREEDY"
         else:
             engine = ""
-        w_packed_c2, w_unpacked_c2 = get_c2_weights(weights, engine)
+
+        # C2 quantization needs the memory format of the Tensor to be `contiguous`, otherwise it
+        # will throw an exception. torch.clone() makes the memory format `contiguous`.
+        c2_copy = torch.clone(weights)
+        w_packed_c2, w_unpacked_c2 = get_c2_weights(c2_copy, engine)
 
         # Compare packed weights against C2.
         np.testing.assert_allclose(w_packed.numpy(), w_packed_c2.numpy(), atol=1e-6, rtol=1e-6)
         # Compare unpacked weights against C2
         np.testing.assert_allclose(w_unpacked.numpy(), w_unpacked_c2.numpy(), atol=1e-6, rtol=1e-6)
 
+
+    def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate,
+                                      optimized_qparams, num_batches, data_type=np.float32):
+
+        # when num_batches = 1, it will create a 2D tensor
+        unsplit_weight = torch.from_numpy((np.random.random_sample((
+            num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(np.float32))
+
+        # test unsplit weight (memory format is `contiguous`)
+        self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, unsplit_weight)
+
+        # test split weights (memory format is not `contiguous`)
+        split_dim = len(unsplit_weight.shape) - 2
+        split_weights = torch.split(unsplit_weight, 1, dim=split_dim)
+        for weight in split_weights:
+            self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight)
+
+
     """ Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """
     @unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
     @given(num_embeddings=st.integers(10, 100),
@@ -3892,6 +3914,7 @@
         pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack
         unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack
 
+        # 4-bit and 2-bit quantization currently only works for 2D Tensors, so we set num_batches to 1
         self._test_embedding_bag_unpack_fn(
             pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type)
 
@@ -3905,6 +3928,7 @@
         pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
         unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack
 
+        # 4-bit and 2-bit quantization currently only works for 2D Tensors, so we set num_batches to 1
         self._test_embedding_bag_unpack_fn(
             pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type)