equal_quantized_cpu requires both inputs to be quantized tensors (#95875)
**Summary**
Fixes the issue https://github.com/pytorch/pytorch/issues/95291: `equal_quantized_cpu` requires both inputs to be quantized tensors, so it now returns `false` when either input is non-quantized instead of checking only the second argument.
**Test Plan**
```
python -m pytest test_quantization.py -k test_quantized_equal
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95875
Approved by: https://github.com/vkuzo, https://github.com/jgong5
diff --git a/aten/src/ATen/native/quantized/QTensor.cpp b/aten/src/ATen/native/quantized/QTensor.cpp
index eeb8a64..6a43bde 100644
--- a/aten/src/ATen/native/quantized/QTensor.cpp
+++ b/aten/src/ATen/native/quantized/QTensor.cpp
@@ -232,7 +232,7 @@
TORCH_CHECK(
self.device().type() == kCPU && other.device().type() == kCPU,
"quantized_equal is implemented only for the QuantizedCPU backend");
- if (!other.is_quantized()) {
+ if (!self.is_quantized() || !other.is_quantized()) {
return false;
}
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index ed37552..c7009d5 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -2337,6 +2337,13 @@
self.assertEqual(qX.equal(qX), equal_ref(qX, qX))
self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2))
+ """Tests quantized equal op with input of non-quantized tensor."""
+ def test_quantized_equal(self,):
+ x = torch.rand(1)
+ y = torch.quantize_per_tensor(x, scale=0.5, zero_point=0, dtype=torch.qint8)
+ self.assertTrue(not torch.equal(x, y))
+ self.assertTrue(not torch.equal(y, x))
+
@skipIfNoFBGEMM
def test_group_norm(self):
# hypothesis is flaky for this test, create test cases manually