[pyper][emb][quantization] Support emb trained in FP16 (#60736)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60736
Add support for embedding ops with float16 input weights, utilizing the new kernel functions added in fbgemm: https://github.com/pytorch/FBGEMM/pull/616
Test Plan: `buck test caffe2/test/:quantization -- test_embedding_bag`
Reviewed By: supriyar
Differential Revision: D29392320
fbshipit-source-id: 0a120b3a58b6cf1d84961831097e9581ffd2b591
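With this change, the quantized embedding prepack ops accept float16 weights directly instead of requiring a cast to float32 first. A minimal usage sketch (shapes and values are illustrative, not from this PR; assumes a build with the quantized embedding ops registered):

```python
import torch

# Hypothetical FP16 embedding table: (num_embeddings, embedding_dim).
weights_fp16 = torch.randn(10, 16, dtype=torch.float16)

# 8-bit rowwise prepack: each row gains 2 * sizeof(float) bytes holding the
# per-row FP32 scale and zero point, so 16 columns become 16 + 8 = 24.
packed_byte = torch.ops.quantized.embedding_bag_byte_prepack(weights_fp16)
print(packed_byte.shape)  # torch.Size([10, 24])

# The 4-bit and 2-bit prepack ops take the same FP16 input after this change.
packed_4bit = torch.ops.quantized.embedding_bag_4bit_prepack(
    weights_fp16, optimized_qparams=False)
packed_2bit = torch.ops.quantized.embedding_bag_2bit_prepack(
    weights_fp16, optimized_qparams=False)
```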
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
index ed41b85..edd1e2c 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -1,3 +1,4 @@
+#include <c10/core/ScalarType.h>
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/native/quantized/cpu/embedding_packed_params.h>
@@ -197,6 +198,10 @@
// packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
// assert(packed_weights.size() == torch.Size([2, 10, 11]))
+ TORCH_CHECK(
+ weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half,
+ "'embedding_bag_byte_prepack' only supports float32 or float16.");
+
const auto weight_sizes = weight.sizes();
const auto cols_dim = weight_sizes.size() - 1;
const int32_t embedding_rows = c10::size_to_dim_(cols_dim, weight_sizes);
@@ -205,7 +210,6 @@
const int32_t output_columns = embedding_cols + 2 * sizeof(float);
Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
- const float* weight_data = weight_contig.data_ptr<float>();
// Adjust output dimensions to account for FP32 scale and zero_points.
std::vector<int64_t> output_shape = weight_sizes.vec();
output_shape[cols_dim] = output_columns;
@@ -218,17 +222,34 @@
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
-
- at::parallel_for(
- 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
- for (int64_t row = start_idx; row < end_idx; ++row) {
- fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
- weight_data + row * embedding_cols, 1,
- embedding_cols, output_data + row * output_columns);
- }
- });
+ if (weight.scalar_type() == at::ScalarType::Half) {
+ const auto weight_data = static_cast<fbgemm::float16*>(weight_contig.data_ptr());
+ at::parallel_for(
+ 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+ for (int64_t row = start_idx; row < end_idx; ++row) {
+ fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<fbgemm::float16>(
+ weight_data + row * embedding_cols, 1,
+ embedding_cols, output_data + row * output_columns);
+ }
+ });
+ } else {
+ const auto weight_data = weight_contig.data_ptr<float>();
+ at::parallel_for(
+ 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+ for (int64_t row = start_idx; row < end_idx; ++row) {
+ fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
+ weight_data + row * embedding_cols, 1,
+ embedding_cols, output_data + row * output_columns);
+ }
+ });
+ }
#else
+ const auto float_weight = weight_contig.scalar_type() == at::ScalarType::Half
+ ? weight_contig.to(at::ScalarType::Float)
+ : weight_contig;
+ const auto weight_data = float_weight.data_ptr<float>();
constexpr float kEpsilon = 1e-8f;
for (std::size_t row = 0; row < embedding_rows; ++row) {
const float* input_row = weight_data + row * embedding_cols;
@@ -262,12 +283,15 @@
const bool optimized_qparams,
const int64_t nbins,
const double ratio) {
+ TORCH_CHECK(
+ weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half,
+ "'qembeddingbag_nbit_prepack' only supports float32 or float16.");
+
int64_t embedding_rows = weight.size(0);
int64_t embedding_cols = weight.size(1);
Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());
- const auto weight_data = weight.data_ptr<float>();
TORCH_CHECK(
bit_width == 4 || bit_width == 2,
"bit_width must be either 2 or 4 to use 'qembeddingbag_nbit_prepack'."
@@ -299,18 +323,35 @@
#ifdef USE_FBGEMM
if (!optimized_qparams) {
- at::parallel_for(
- 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
- for (int64_t row = start_idx; row < end_idx; ++row) {
- fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
- bit_width, weight_data + row * embedding_cols, 1,
- embedding_cols, output_data + row * output_shape[1]);
- }
- });
+ if (weight.scalar_type() == at::ScalarType::Half) {
+ const auto weight_data = static_cast<fbgemm::float16*>(weight_contig.data_ptr());
+ at::parallel_for(
+ 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+ for (int64_t row = start_idx; row < end_idx; ++row) {
+ fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<fbgemm::float16>(
+ bit_width, weight_data + row * embedding_cols, 1,
+ embedding_cols, output_data + row * output_shape[1]);
+ }
+ });
+ } else {
+ const auto weight_data = weight_contig.data_ptr<float>();
+ at::parallel_for(
+ 0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
+ for (int64_t row = start_idx; row < end_idx; ++row) {
+ fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
+ bit_width, weight_data + row * embedding_cols, 1,
+ embedding_cols, output_data + row * output_shape[1]);
+ }
+ });
+ }
} else {
#endif // USE_FBGEMM
const auto output_columns = output.size(output.dim() - 1);
-
+ const auto float_weight = weight_contig.scalar_type() == at::ScalarType::Half
+ ? weight_contig.to(at::ScalarType::Float)
+ : weight_contig;
+ const auto weight_data = float_weight.data_ptr<float>();
for (int row = 0; row < embedding_rows; ++row) {
const float* input_row = weight_data + row * embedding_cols;
std::uint8_t* output_row = output_data + row * output_columns;
@@ -320,7 +361,7 @@
if (optimized_qparams) {
at::Tensor xmax_tensor, xmin_tensor;
std::tie(xmax_tensor, xmin_tensor) = at::choose_qparams_optimized(
- weight_contig[row], embedding_cols, nbins, ratio, bit_width);
+ float_weight[row], embedding_cols, nbins, ratio, bit_width);
TORCH_CHECK(
xmax_tensor.numel() == 1 && xmin_tensor.numel() == 1,
"Expected choose_qparams_optimized to return min/max tensors of size 1");
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 1aff284..517b23e 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -3142,9 +3142,9 @@
@unittest.skipIf(sys.platform == "darwin", "Known test failure on Mac.")
class TestQuantizedEmbeddingOps(TestCase):
def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams,
- num_batches):
+ num_batches, data_type=np.float32):
weights = torch.from_numpy((np.random.random_sample((
- num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(np.float32))
+ num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(data_type))
qtype = torch.quint8
if bit_rate == 8:
w_packed = pack_fn(weights)
@@ -3152,7 +3152,9 @@
w_packed = pack_fn(weights, optimized_qparams=optimized_qparams)
w_unpacked = unpack_fn(w_packed)
- if bit_rate == 8 or bit_rate == 4:
+ if (bit_rate == 8 or bit_rate == 4) and data_type != np.float16:
+ # torch.quantize_per_channel does not support float16 yet.
+
obs_weights = weights
# Combine 3D embeddings (e.g. stacked combination of embeddings)
# in a dimension orthogonal to channels.
@@ -3180,13 +3182,13 @@
# compare against C2 to ensure numerical equivalency.
from caffe2.python import core, workspace
- conversion_op = "FloatToFused8BitRowwiseQuantized"
+ conversion_op = "FloatToFused8BitRowwiseQuantized" if data_type == np.float32 else "HalfFloatToFused8BitRowwiseQuantized"
reverse_conversion_op = None
if bit_rate == 4:
- conversion_op = "FloatToFused4BitRowwiseQuantized"
+ conversion_op = "FloatToFused4BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused4BitRowwiseQuantized"
reverse_conversion_op = "Fused4BitRowwiseQuantizedToFloat"
elif bit_rate == 2:
- conversion_op = "FloatToFused2BitRowwiseQuantized"
+ conversion_op = "FloatToFused2BitRowwiseQuantized" if data_type == np.float32 else "HalfToFused2BitRowwiseQuantized"
reverse_conversion_op = "Fused2BitRowwiseQuantizedToFloat"
def get_c2_weights(weights, engine_str):
@@ -3226,32 +3228,38 @@
""" Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
- num_batches=st.integers(1, 5))
- def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches):
+ num_batches=st.integers(1, 5),
+ data_type=st.sampled_from([np.float32, np.float16]),)
+ def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches, data_type):
pack_fn = torch.ops.quantized.embedding_bag_byte_prepack
unpack_fn = torch.ops.quantized.embedding_bag_byte_unpack
- self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches)
+ self._test_embedding_bag_unpack_fn(
+ pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches, data_type=data_type)
""" Tests the correctness of the embedding_bag_4bit pack/unpack op against C2 """
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
- optimized_qparams=st.booleans(),)
- def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams):
+ optimized_qparams=st.booleans(),
+ data_type=st.sampled_from([np.float32, np.float16]),)
+ def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack
- self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1)
+ self._test_embedding_bag_unpack_fn(
+ pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type)
""" Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
- optimized_qparams=st.booleans(),)
- def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams):
+ optimized_qparams=st.booleans(),
+ data_type=st.sampled_from([np.float32, np.float16]),)
+ def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack
- self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1)
+ self._test_embedding_bag_unpack_fn(
+ pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type)
def embedding_bag_rowwise_offsets_run(
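The parametrized tests above now sample data_type from float32/float16 and, for FP16, compare against the corresponding Half* Caffe2 conversion ops. A simplified round-trip check in the same spirit (tolerance is an assumed value, not taken from the test):

```python
import numpy as np
import torch

# FP16 table with values in [1, 2), as in _test_embedding_bag_unpack_fn above.
weights = torch.from_numpy(
    (np.random.random_sample((10, 16)) + 1).astype(np.float16))

packed = torch.ops.quantized.embedding_bag_byte_prepack(weights)
unpacked = torch.ops.quantized.embedding_bag_byte_unpack(packed)

# 8-bit rowwise quantization of a ~1.0-wide range loses roughly half a
# quantization step (~0.002) per element, plus FP16 rounding error, so a loose
# tolerance suffices (assumed, not from the PR).
assert torch.allclose(unpacked, weights.float(), atol=0.05)
```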