[quant] skip tests without fbgemm support (#47800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47800
Fixes #47748
Test Plan:
python test/test_quantization.py
Imported from OSS
Reviewed By: vkuzo
Differential Revision: D24904885
fbshipit-source-id: 76d27659e73c7f60b3fcc25606657ee9305117be
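The operator change makes the byte embedding_bag path raise a clear TORCH_CHECK error on builds without FBGEMM instead of falling through, and the test changes gate the FBGEMM-only embedding tests with `skipIfNoFBGEMM` so such builds skip them rather than fail. As a minimal sketch of the gating pattern, assuming `skipIfNoFBGEMM` is importable from `torch.testing._internal.common_quantization` as in the PyTorch test suite (the test class and method names below are illustrative only):

```python
import unittest

import torch
from torch.testing._internal.common_quantization import skipIfNoFBGEMM


class EmbeddingBagQuantTest(unittest.TestCase):
    @skipIfNoFBGEMM
    def test_requires_fbgemm(self):
        # Runs only when the build supports the fbgemm quantized engine;
        # otherwise the decorator marks the test as skipped instead of
        # letting it fail on the FBGEMM-only operator.
        self.assertIn("fbgemm", torch.backends.quantized.supported_engines)


if __name__ == "__main__":
    unittest.main()
```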
diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
index 13f98cd..6071b94 100644
--- a/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag.cpp
@@ -305,9 +305,12 @@
         success,
         "FBGEMM GenerateEmbeddingSpMDMRowWiseSparse kernel failed for 8-bit input");
   }
+  return output;
 #endif
   // TODO add default (non-FBGEMM) implementation.
-  return output;
+  TORCH_CHECK(
+      false,
+      "embedding_bag_byte expects FBGEMM support. This PyTorch installation was not built with FBGEMM operators");
 }

 at::Tensor embedding_bag_byte_helper(
diff --git a/test/quantization/test_quantize.py b/test/quantization/test_quantize.py
index 597dbb2..96b4546 100644
--- a/test/quantization/test_quantize.py
+++ b/test/quantization/test_quantize.py
@@ -536,6 +536,7 @@
         self.checkQuantizedLinear(model.fc)

+    @skipIfNoFBGEMM
     def test_quantized_embedding_bag(self):
         r""" Test the post-training quantization flow, serialization and scripting
         of embedding_bag modules
diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index 44bbf11..dc7b3f9 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -3032,6 +3032,7 @@
             FunctionalLinear(weight, bias), x,
             "quantized::linear_dynamic", tracing=tracing, dynamic=True)

+    @skipIfNoFBGEMM
     def test_embedding_bag(self):
         class M(torch.nn.Module):
             def __init__(self, weights):
diff --git a/test/quantization/test_quantized_module.py b/test/quantization/test_quantized_module.py
index 60cb1f3..f07f919 100644
--- a/test/quantization/test_quantized_module.py
+++ b/test/quantization/test_quantized_module.py
@@ -726,6 +726,7 @@
         embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
         set_qconfig=st.booleans(),
     )
+    @skipIfNoFBGEMM
     def test_embedding_api(self, num_embeddings, embedding_dim, set_qconfig):
         num_lengths = np.random.randint(1, 6)
         lengths = np.random.randint(0, 21, size=num_lengths).astype(np.int32)
diff --git a/test/quantization/test_quantized_op.py b/test/quantization/test_quantized_op.py
index 12871f6..3397205 100644
--- a/test/quantization/test_quantized_op.py
+++ b/test/quantization/test_quantized_op.py
@@ -3105,6 +3105,7 @@
""" Tests the correctness of the quantized embedding lookup operator """
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0))
+ @skipIfNoFBGEMM
def test_embedding_byte(self, num_embeddings, embedding_dim):
quant_op = torch.ops.quantized.embedding_byte
prepack_op = torch.ops.quantized.embedding_bag_prepack