[quant] Set sparse to False for embedding_bag ops in graph mode (#45997)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45997
The sparse field used in the float module is for sparse gradients, which is not applicable
to inference. The sparse field in the quantized ops denotes pruned weights, so graph mode now sets it to false.
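For context, the float module's sparse flag only changes the gradient layout produced
during training; the forward (inference) path is identical either way. A minimal
eager-mode illustration:

import torch

emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                            sparse=True, mode='sum')
out = emb(torch.tensor([0, 1, 2]), torch.tensor([0]))
out.sum().backward()
# sparse=True only affects the backward pass: the weight gradient
# becomes a sparse tensor. The forward output is unchanged.
print(emb.weight.grad.is_sparse)  # True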
Test Plan:
python test/test_quantization.py TestQuantizeDynamicJitOps.test_embedding_bag
Imported from OSS
Reviewed By: qizzzh
Differential Revision: D24176543
fbshipit-source-id: a05b4ff949e0375462ae411947f68076e1b460d2
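For reference, a sketch of the graph-mode flow the test plan exercises, mirroring the
updated test below. The quantize_dynamic_jit import path and the qconfig name are
assumptions; both have moved between releases:

import torch
# qconfig/import names are assumptions; they have been renamed across releases
from torch.quantization import float_qparams_weight_only_qconfig
from torch.quantization.quantize_jit import quantize_dynamic_jit

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12,
                                               include_last_offset=True, sparse=True,
                                               mode='sum')

    def forward(self, indices, offsets):
        return self.embedding(indices, offsets)

m = torch.jit.script(M())
# With this patch, sparse=True on the float module no longer leaks into the
# quantized op: the lowered graph always passes pruned_weights=False.
m = quantize_dynamic_jit(m, {'embedding': float_qparams_weight_only_qconfig})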
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index a0fad9b..44bbf11 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -3039,14 +3039,14 @@
self.embedding1 = torch.nn.EmbeddingBag(num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
- sparse=False,
+ sparse=True,
_weight=weights,
mode='sum')
self.embedding2 = torch.nn.EmbeddingBag(num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
- sparse=False,
+ sparse=True,
_weight=weights,
mode='sum')
@@ -3077,6 +3077,7 @@
FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets") \
.check_next("quantized::embedding_bag_byte_rowwise_offsets") \
.run(m.graph)
+ m(*dummy_inputs)
diff --git a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
index fd49cf6..0a71b5e 100644
--- a/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
+++ b/torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -330,6 +330,8 @@
// Create and insert quantized embedding op.
Value* none = g->insertConstant(IValue());
Value* zero = g->insertConstant(IValue(0));
+ bool sparse_param = false;
+ auto sparse_const = g->insertConstant(sparse_param);
if (is_aten_op) {
TORCH_CHECK(
@@ -340,6 +342,10 @@
for (auto i = 1; i < inputs_size - 1; ++i) {
qembedding_bag_inputs.push_back(embedding_bag_inputs[i]);
}
+ // The sparse field in the float operator denotes sparse gradients, which is
+ // not applicable at inference. In the quantized ops it denotes pruned weights.
+ // We don't support pruning in the graph mode API yet, so set the field to false.
+ qembedding_bag_inputs[5] = sparse_const;
} else {
TORCH_CHECK(
inputs_size == 11,
@@ -348,8 +354,8 @@
qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets
qembedding_bag_inputs.push_back(
embedding_bag_inputs[6]); // scale_grad_by_freq
- qembedding_bag_inputs.push_back(zero); // zero
- qembedding_bag_inputs.push_back(embedding_bag_inputs[8]); // sparse
+ qembedding_bag_inputs.push_back(zero); // mode
+ qembedding_bag_inputs.push_back(sparse_const); // pruned_weights
qembedding_bag_inputs.push_back(
embedding_bag_inputs[9]); // per_sample_weights
}
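For reference, what the lowered graph call reduces to when invoked directly. This is a
sketch; the positional signature of the quantized op (scale_grad_by_freq, mode,
pruned_weights, per_sample_weights, compressed_indices_mapping, include_last_offset
after weight/indices/offsets) is an assumption based on the op registration at the
time of this change:

import torch

weight = torch.randn(10, 12)
packed = torch.ops.quantized.embedding_bag_byte_prepack(weight)
indices = torch.tensor([0, 1, 2, 3], dtype=torch.long)
offsets = torch.tensor([0, 2, 4], dtype=torch.long)
out = torch.ops.quantized.embedding_bag_byte_rowwise_offsets(
    packed, indices, offsets,
    False,  # scale_grad_by_freq: gradient-only, ignored at inference
    0,      # mode: 0 == sum
    False,  # pruned_weights: the field this patch hardwires to false
    None,   # per_sample_weights
    None,   # compressed_indices_mapping: only used with pruned weights
    True)   # include_last_offset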