Fix embedding quantization issue when memory format is not `contiguous` (#82605)
Summary:
The current implementation of embedding quantization assumes that the input `Tensor` is `contiguous` in memory.
To guarantee this, the input `weight` is converted to `contiguous` format via either
```
const auto weight_contig =
weight.expect_contiguous(weight.suggest_memory_format());
```
or
```
Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
```
However, the `USE_FBGEMM = true` branch still reads from `weight` instead of `weight_contig`, so it produces wrong results when the input data is not `contiguous`.
Example: N2297477
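For illustration, a minimal Python sketch (not part of this PR; the shapes, values, and the `torch.testing.assert_close` check are arbitrary choices) of how a non-contiguous weight can arise from `torch.split` and reach the byte prepack op, mirroring the new test case below:
```
import torch

# A contiguous 3D float weight (batches x embeddings x dim).
weights = torch.rand(2, 10, 12) + 1

# Splitting along dim 1 yields views that share storage with `weights`
# and are therefore NOT contiguous in memory.
non_contig = torch.split(weights, 1, dim=1)[0]
assert not non_contig.is_contiguous()

# Before this fix, the FBGEMM path read the raw (non-contiguous) buffer,
# so packing the view and packing a contiguous copy could disagree.
packed_view = torch.ops.quantized.embedding_bag_byte_prepack(non_contig)
packed_copy = torch.ops.quantized.embedding_bag_byte_prepack(non_contig.contiguous())
torch.testing.assert_close(packed_view, packed_copy)  # passes with this fix
```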
Test Plan:
```
buck1 test mode/dev //caffe2/test:quantization -- --exact 'caffe2/test:quantization - test_embedding_bag_byte_unpack (quantization.core.test_quantized_op.TestQuantizedEmbeddingOps)'
buck1 test mode/dev //caffe2/test:quantization -- --exact 'caffe2/test:quantization - test_embedding_bag_2bit_unpack (quantization.core.test_quantized_op.TestQuantizedEmbeddingOps)'
buck1 test mode/dev //caffe2/test:quantization -- --exact 'caffe2/test:quantization - test_embedding_bag_4bit_unpack (quantization.core.test_quantized_op.TestQuantizedEmbeddingOps)'
```
Differential Revision: D38302116
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82605
Approved by: https://github.com/houseroad
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
index b32fcf0..748e89f 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -139,7 +139,8 @@
//
// Python example examining a packed 8bit zero_point and scale:
//
-// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]], dtype=np.float32))
+// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
+// dtype=np.float32))
// >> x_packed = torch.ops.quantized.embedding_bag_byte_prepack(x)
//
// # Pull out and examine packed scales, zero_points and values
@@ -228,8 +229,9 @@
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
- if (weight.scalar_type() == at::ScalarType::Half) {
- const auto weight_data = static_cast<fbgemm::float16*>(weight.data_ptr());
+ if (weight_contig->scalar_type() == at::ScalarType::Half) {
+ const auto weight_data =
+ static_cast<fbgemm::float16*>(weight_contig->data_ptr());
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
@@ -240,7 +242,7 @@
output_data + start_idx * output_columns);
});
} else {
- const auto weight_data = weight.data_ptr<float>();
+ const auto weight_data = weight_contig->data_ptr<float>();
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
@@ -346,8 +348,9 @@
#ifdef USE_FBGEMM
if (!optimized_qparams) {
- if (weight.scalar_type() == at::ScalarType::Half) {
- const auto weight_data = static_cast<fbgemm::float16*>(weight.data_ptr());
+ if (weight_contig.scalar_type() == at::ScalarType::Half) {
+ const auto weight_data =
+ static_cast<fbgemm::float16*>(weight_contig.data_ptr());
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
@@ -359,7 +362,7 @@
output_data + start_idx * output_shape[1]);
});
} else {
- const auto weight_data = weight.data_ptr<float>();
+ const auto weight_data = weight_contig.data_ptr<float>();
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 41b735a..23cbffe 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -3785,10 +3785,10 @@
@unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
class TestQuantizedEmbeddingOps(TestCase):
- def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams,
- num_batches, data_type=np.float32):
- weights = torch.from_numpy((np.random.random_sample((
- num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(data_type))
+
+ def _test_embedding_bag_unpack_impl(self, pack_fn, unpack_fn, bit_rate, optimized_qparams, weights):
+ data_type = weights.dtype
+
qtype = torch.quint8
if bit_rate == 8:
w_packed = pack_fn(weights)
@@ -3796,13 +3796,13 @@
w_packed = pack_fn(weights, optimized_qparams=optimized_qparams)
w_unpacked = unpack_fn(w_packed)
- if (bit_rate == 8 or bit_rate == 4) and data_type != np.float16:
+ if (bit_rate == 8 or bit_rate == 4) and data_type != torch.float16:
# torch.quantize_per_channel does not support float16 yet.
obs_weights = weights
# Combine 3D embeddings (e.g. stacked combination of embeddings)
# in a dimension orthogonal to channels.
- if(num_batches > 1):
+ if (len(obs_weights.shape) > 2):
stacked_shape = list(weights.size())
stacked_shape[1] *= stacked_shape[0]
obs_weights = weights.reshape(stacked_shape[1:])
@@ -3826,13 +3826,13 @@
# compare against C2 to ensure numerical equivalency.
from caffe2.python import core, workspace
- conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == np.float32 else "HalfFloatToFused8BitRowwiseQuantized"
+ conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == torch.float32 else "HalfFloatToFused8BitRowwiseQuantized"
reverse_conversion_op = None
if bit_rate == 4:
- conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused4BitRowwiseQuantized"
+ conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused4BitRowwiseQuantized"
reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat"
elif bit_rate == 2:
- conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused2BitRowwiseQuantized"
+ conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == torch.float32 else "HalfToFused2BitRowwiseQuantized"
reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat"
def get_c2_weights(weights, engine_str):
@@ -3862,13 +3862,35 @@
engine = "GREEDY"
else:
engine = ""
- w_packed_c2, w_unpacked_c2 = get_c2_weights(weights, engine)
+
+            # C2 quantization needs the memory format of the Tensor to be `contiguous`, otherwise it will
+            # throw exceptions. torch.clone() makes the memory format `contiguous`.
+ c2_copy = torch.clone(weights)
+ w_packed_c2, w_unpacked_c2 = get_c2_weights(c2_copy, engine)
# Compare packed weights against C2.
np.testing.assert_allclose(w_packed.numpy(), w_packed_c2.numpy(), atol=1e-6, rtol=1e-6)
# Compare unpacked weights against C2
np.testing.assert_allclose(w_unpacked.numpy(), w_unpacked_c2.numpy(), atol=1e-6, rtol=1e-6)
+
+ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate,
+ optimized_qparams, num_batches, data_type=np.float32):
+
+ # when num_batches = 1, it will create a 2D tensor
+ unsplit_weight = torch.from_numpy((np.random.random_sample((
+ num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(np.float32))
+
+ # test unsplit weight (memory format is `contiguous`)
+ self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, unsplit_weight)
+
+ # test split weights (memory format is not `contiguous`)
+ split_dim = len(unsplit_weight.shape) - 2
+ split_weights = torch.split(unsplit_weight, 1, dim=split_dim)
+ for weight in split_weights:
+ self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight)
+
+
""" Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """
@unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
@given(num_embeddings=st.integers(10, 100),
@@ -3892,6 +3914,7 @@
pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack
+ # 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1
self._test_embedding_bag_unpack_fn(
pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type)
@@ -3905,6 +3928,7 @@
pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack
+ # 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1
self._test_embedding_bag_unpack_fn(
pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type)