Remove fbgemm_is_cpu_supported in favor of torch.backends.quantized.supported_qengines (#26840)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26840
Cleans up the top-level namespace: replaces torch.fbgemm_is_cpu_supported() with torch.backends.quantized.supported_engines. Also makes cosmetic changes to torch.backends.quantized.
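
For reference, a minimal sketch of the new user-facing API (which engines appear
depends on build flags such as USE_FBGEMM / USE_PYTORCH_QNNPACK and on CPU support):

    import torch

    # Engines compiled into this build, e.g. ['qnnpack', 'none', 'fbgemm'].
    print(torch.backends.quantized.supported_engines)

    # Replaces the old torch.fbgemm_is_cpu_supported() check.
    if 'fbgemm' in torch.backends.quantized.supported_engines:
        torch.backends.quantized.engine = 'fbgemm'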
Test Plan: Imported from OSS
Differential Revision: D17604403
Pulled By: dzhulgakov
fbshipit-source-id: c55af277ea7319d962a82a6120f65ccd47a60abc
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
index 2b28ad4..ea7cd4a 100644
--- a/aten/src/ATen/Context.cpp
+++ b/aten/src/ATen/Context.cpp
@@ -4,26 +4,30 @@
#include <c10/core/TensorOptions.h>
-#include <thread>
#include <mutex>
#include <sstream>
-#include <string>
#include <stdexcept>
+#include <string>
+#include <thread>
#include <ATen/Tensor.h>
#include <ATen/cpu/FlushDenormal.h>
-#include <TH/TH.h> // for USE_LAPACK
+#include <TH/TH.h> // for USE_LAPACK
+
+#ifdef USE_FBGEMM
+#include "fbgemm/Fbgemm.h"
+#endif // USE_FBGEMM
namespace at {
Context::Context()
-: thc_state(nullptr, [](THCState* p){ /* no-op */ } )
-, thh_state(nullptr, [](THHState* p){ /* no-op */ } ) {}
+ : thc_state(nullptr, [](THCState* p) { /* no-op */ }),
+ thh_state(nullptr, [](THHState* p) { /* no-op */ }) {}
// TODO: This could be bad juju if someone calls globalContext() in the
// destructor of an object with static lifetime.
-Context & globalContext() {
+Context& globalContext() {
static Context globalContext_;
return globalContext_;
}
@@ -96,7 +100,8 @@
}
at::QEngine Context::qEngine() const {
- return quantized_engine;
+ // If not explicitly set, use the last available engine
+ return quantized_engine.value_or(supportedQEngines().back());
}
void Context::setQEngine(at::QEngine e) {
@@ -108,16 +113,31 @@
TORCH_CHECK(false, "quantized engine ", toString(e), " is not supported");
}
-std::vector<at::QEngine> Context::supportedQEngines() const {
- static auto supported_qengines = {
- at::kNoQEngine,
- #ifdef USE_FBGEMM
- at::kFBGEMM,
- #endif
- #ifdef USE_PYTORCH_QNNPACK
- at::kQNNPACK,
- #endif
- };
+const std::vector<at::QEngine>& Context::supportedQEngines() const {
+ static auto supported_qengines = []() {
+ std::vector<at::QEngine> engines = {};
+ // Engines are listed in priority order: the later one wins.
+ // By default we prefer FBGEMM when running on the server side;
+ // QNNPACK has known issues on the server side, so it is never chosen as the default there.
+#ifdef C10_MOBILE
+ engines.push_back(at::kNoQEngine);
+#ifdef USE_PYTORCH_QNNPACK
+ engines.push_back(at::kQNNPACK);
+#endif
+#else // C10_MOBILE
+#ifdef USE_PYTORCH_QNNPACK
+ engines.push_back(at::kQNNPACK);
+#endif
+ engines.push_back(at::kNoQEngine);
+#endif // C10_MOBILE
+
+#ifdef USE_FBGEMM
+ if (fbgemm::fbgemmSupportedCPU()) {
+ engines.push_back(at::kFBGEMM);
+ }
+#endif
+ return engines;
+ }();
return supported_qengines;
}
@@ -143,4 +163,4 @@
};
REGISTER_LEGACY_TYPE_INIT(LegacyDeviceTypeInit);
-}
+} // namespace at
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 1bc180f..9fd58ff 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -108,7 +108,7 @@
void setDeterministicCuDNN(bool);
at::QEngine qEngine() const;
void setQEngine(at::QEngine e);
- std::vector<at::QEngine> supportedQEngines() const;
+ const std::vector<at::QEngine>& supportedQEngines() const;
private:
void initCUDAIfNeeded(DeviceType p) {
@@ -127,12 +127,7 @@
bool deterministic_cudnn = false;
bool benchmark_cudnn = false;
bool enabled_mkldnn = true;
- at::QEngine quantized_engine =
-#ifdef USE_FBGEMM
- at::kFBGEMM;
-#else
- at::kNoQEngine;
-#endif
+ c10::optional<at::QEngine> quantized_engine = c10::nullopt;
std::unique_ptr<THCState, void(*)(THCState*)> thc_state;
std::unique_ptr<THHState, void(*)(THHState*)> thh_state;
};
diff --git a/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp b/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp
index b6de5a5..95af67e 100644
--- a/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp
+++ b/aten/src/ATen/core/OpsAlreadyMovedToC10.cpp
@@ -219,7 +219,6 @@
{"aten::fbgemm_linear_fp16_weight", ""},
{"aten::fbgemm_pack_quantized_matrix", ""},
{"aten::fbgemm_pack_quantized_matrix", "KN"},
- {"aten::fbgemm_is_cpu_supported", ""},
{"aten::log", ""},
{"aten::log_", ""},
{"aten::log10", ""},
diff --git a/aten/src/ATen/native/QuantizedLinear.cpp b/aten/src/ATen/native/QuantizedLinear.cpp
index 7ae4397..288a49f 100644
--- a/aten/src/ATen/native/QuantizedLinear.cpp
+++ b/aten/src/ATen/native/QuantizedLinear.cpp
@@ -252,10 +252,6 @@
quantized, col_offsets, q_params.scale, q_params.zero_point);
}
-bool fbgemm_is_cpu_supported() {
- return fbgemm::fbgemmSupportedCPU();
-}
-
Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) {
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index db6a9ad..59fc8c9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1506,9 +1506,6 @@
- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor
use_c10_dispatcher: full
-- func: fbgemm_is_cpu_supported() -> bool
- use_c10_dispatcher: full
-
- func: linspace(Scalar start, Scalar end, int steps=100, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: linspace.out(Scalar start, Scalar end, int steps=100, *, Tensor(a!) out) -> Tensor(a!)
diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp
index 3cb4beb..5e7f114 100644
--- a/aten/src/ATen/native/quantized/cpu/qadd.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp
@@ -185,11 +185,12 @@
public:
Tensor operator()(Tensor qa, Tensor qb, double scale, int64_t zero_point) {
check_inputs(qa, qb);
- #ifdef USE_PYTORCH_QNNPACK
- if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+#ifdef USE_PYTORCH_QNNPACK
+ if (at::globalContext().qEngine() == at::QEngine::QNNPACK &&
+ qa.scalar_type() == kQUInt8 && qb.scalar_type() == kQUInt8) {
return qnnpack_add(qa, qb, scale, zero_point);
}
- #endif
+#endif
auto qc = at::_empty_affine_quantized(
qa.sizes(),
at::device(kCPU).dtype(qa.scalar_type()),
diff --git a/aten/src/ATen/native/quantized/cpu/qpool.cpp b/aten/src/ATen/native/quantized/cpu/qpool.cpp
index 18bc848..2f3602d 100644
--- a/aten/src/ATen/native/quantized/cpu/qpool.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qpool.cpp
@@ -401,7 +401,7 @@
std::vector<int64_t> dilation,
bool ceil_mode) {
#ifdef USE_PYTORCH_QNNPACK
- if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+ if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) {
return qnnpack_maxpool(qx, kernel_size, stride, padding, dilation, ceil_mode);
}
#endif
diff --git a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
index bfa1703..4ce423f 100644
--- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp
@@ -85,7 +85,7 @@
Tensor quantized_relu(const Tensor& qx) {
#ifdef USE_PYTORCH_QNNPACK
- if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
+ if (at::globalContext().qEngine() == at::QEngine::QNNPACK && qx.scalar_type() == kQUInt8) {
return qnnpack_relu(qx);
}
#endif
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index c077a03..f184382 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -11,6 +11,7 @@
white_list = [
('quantize', datetime.date(2019, 10, 1)),
('q_per_channel_axis', datetime.date(2019, 10, 1)),
+ ('fbgemm_is_cpu_supported', datetime.date(2019, 10, 1)),
]
diff --git a/test/common_quantized.py b/test/common_quantized.py
index 45d2bcd..31cc54f 100644
--- a/test/common_quantized.py
+++ b/test/common_quantized.py
@@ -65,12 +65,9 @@
@contextmanager
def enable_mobile_quantized_engine():
+ previous = torch.backends.quantized.engine
torch.backends.quantized.engine = 'qnnpack'
try:
yield
finally:
- qengines = torch.backends.quantized.get_supported_qengines()
- if 'fbgemm' in qengines:
- torch.backends.quantized.engine = 'fbgemm'
- else:
- torch.backends.quantized.engine = 'none'
+ torch.backends.quantized.engine = previous
diff --git a/test/test_jit.py b/test/test_jit.py
index 795e61c..c440ece 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -7210,9 +7210,9 @@
a = A()
self.assertEqual(a.with_docstring.__doc__, 'test str')
- @unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
- 'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
- ' with instruction set support avx2 or newer.')
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ 'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
+ ' with instruction set support avx2 or newer.')
def test_rnn_cell_quantized(self):
d_in, d_hid = 2, 2
@@ -7304,9 +7304,9 @@
for out, ref_out in zip(outs, ref_outs):
torch.testing.assert_allclose(out, ref_out)
- @unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
- 'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
- ' with instruction set support avx2 or newer.')
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ 'Quantized RNN requires FBGEMM. FBGEMM is only optimized for CPUs'
+ ' with instruction set support avx2 or newer.')
def test_rnn_quantized(self):
d_in, d_hid = 2, 2
@@ -12378,7 +12378,7 @@
traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
- if torch.fbgemm_is_cpu_supported():
+ if 'fbgemm' in torch.backends.quantized.supported_engines:
def test_quantization_modules(self):
K1, N1 = 2, 2
@@ -15189,7 +15189,7 @@
with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"):
torch.jit.script(fn)
- @unittest.skipIf(not torch.fbgemm_is_cpu_supported(), "requires FBGEMM")
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines, "requires FBGEMM")
def test_erase_class_tensor_shapes(self):
class Linear(torch.nn.Module):
def __init__(self, in_features, out_features):
@@ -16064,7 +16064,7 @@
def test_snli(self):
self._test_snli(self, device='cpu')
- if torch.fbgemm_is_cpu_supported():
+ if 'fbgemm' in torch.backends.quantized.supported_engines:
def test_snli_quantized(self):
self._test_snli(self, device='cpu', quantized=True)
@@ -16206,7 +16206,7 @@
def test_vae(self):
self._test_vae(self, device='cpu')
- if torch.fbgemm_is_cpu_supported():
+ if 'fbgemm' in torch.backends.quantized.supported_engines:
def test_vae_quantized(self):
self._test_vae(self, device='cpu', quantized=True)
diff --git a/test/test_nn.py b/test/test_nn.py
index 98c877c..ea5dc18 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2459,9 +2459,9 @@
# should be bitwise equal
self.assertEqual(input.grad, inputf.grad.to(dtype), prec=0)
- @unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
- 'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
- ' with instruction set support avx2 or newer.')
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ 'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
+ ' with instruction set support avx2 or newer.')
def test_fb_fc_packed(self):
X = np.random.rand(16, 16).astype(np.float32) - 0.5
W = np.random.rand(16, 16).astype(np.float32) - 0.5
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 198114c..497c647 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -30,11 +30,9 @@
import io
import copy
-@unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
-)
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class PostTrainingQuantTest(QuantizationTestCase):
def test_single_layer(self):
r"""Quantize SingleLayerLinearModel which has one Linear module, make sure it is swapped
@@ -292,11 +290,9 @@
checkQuantized(model)
-@unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
-)
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class PostTrainingDynamicQuantTest(QuantizationTestCase):
def test_single_layer(self):
r"""Dynamic Quantize SingleLayerLinearDynamicModel which has one Linear module,
@@ -569,11 +565,9 @@
for out, ref in zip(final_hiddens_fp16, ref_hid):
torch.testing.assert_allclose(out, ref)
-@unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
-)
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class QuantizationAwareTrainingTest(QuantizationTestCase):
def test_manual(self):
model = ManualLinearQATModel()
@@ -656,10 +650,9 @@
self.checkScriptable(self.qmodel_under_test, [(xq, xq)], check_save_load=True)
self.checkScriptable(self.model_under_test, [(xq.dequantize(), xq.dequantize())], check_save_load=True)
-@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
- 'Quantization requires FBGEMM. FBGEMM does not play'
- ' well with UBSAN at the moment, so we skip the test if'
- ' we are in a UBSAN environment.')
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class FusionTest(QuantizationTestCase):
def test_fuse_module_train(self):
model = ModelForFusion(default_qat_qconfig).train()
@@ -901,10 +894,9 @@
loaded = torch.jit.load(buf)
self.assertEqual(obs.calculate_qparams(), loaded.calculate_qparams())
-@unittest.skipIf(not torch.fbgemm_is_cpu_supported(),
- 'Quantization requires FBGEMM. FBGEMM does not play'
- ' well with UBSAN at the moment, so we skip the test if'
- ' we are in a UBSAN environment.')
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class RecordHistogramObserverTest(QuantizationTestCase):
def test_record_observer(self):
model = SingleLayerLinearModel()
diff --git a/test/test_quantized.py b/test/test_quantized.py
index 33f5416..7fbaeee 100644
--- a/test/test_quantized.py
+++ b/test/test_quantized.py
@@ -967,11 +967,9 @@
self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2))
-@unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
-)
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class TestDynamicQuantizedLinear(TestCase):
"""Tests the correctness of the dynamic quantized linear and linear_relu op."""
@no_deadline
@@ -1086,11 +1084,9 @@
self.assertEqual(Y_fp32, Y_fp32_ref,
message="torch.ops.quantized.linear_dynamic (fbgemm) results are off")
-@unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
-)
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class TestQuantizedLinear(unittest.TestCase):
"""Tests the correctness of the quantized linear and linear_relu op."""
@no_deadline
@@ -1264,11 +1260,9 @@
W_q.q_zero_point(), W_q_origin.q_zero_point())
-@unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
-)
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class TestQuantizedConv(unittest.TestCase):
"""Tests the correctness of quantized convolution op."""
@given(batch_size=st.integers(1, 3),
diff --git a/test/test_quantized_models.py b/test/test_quantized_models.py
index 6455fa3..d692196 100644
--- a/test/test_quantized_models.py
+++ b/test/test_quantized_models.py
@@ -1,8 +1,12 @@
import torch
import torch.jit
+import unittest
from common_utils import run_tests
from common_quantization import QuantizationTestCase, ModelMultipleOps
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
class ModelNumerics(QuantizationTestCase):
def test_float_quant_compare(self):
torch.manual_seed(42)
diff --git a/test/test_quantized_nn_mods.py b/test/test_quantized_nn_mods.py
index 702c26a..2af6f3b 100644
--- a/test/test_quantized_nn_mods.py
+++ b/test/test_quantized_nn_mods.py
@@ -34,11 +34,9 @@
self.assertEqual(qY, qY_hat)
@no_deadline
- @unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
- )
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
@given(
use_bias=st.booleans(),
)
@@ -89,11 +87,9 @@
class DynamicModuleAPITest(QuantizationTestCase):
@no_deadline
- @unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
- )
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
@given(
batch_size=st.integers(1, 5),
in_features=st.integers(16, 32),
@@ -209,11 +205,9 @@
@no_deadline
- @unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
- )
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
@given(
batch_size=st.integers(1, 5),
in_features=st.integers(16, 32),
@@ -341,11 +335,9 @@
self.assertEqual(rqr, rqr2)
@no_deadline
- @unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
- )
+ @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
@given(
use_bias=st.booleans(),
use_fused=st.booleans(),
diff --git a/test/test_quantizer.py b/test/test_quantizer.py
index bca792e..361d7d0 100644
--- a/test/test_quantizer.py
+++ b/test/test_quantizer.py
@@ -36,11 +36,9 @@
super(WeightObserver, self).__init__()
self.dtype = torch.qint8
-@unittest.skipIf(
- not torch.fbgemm_is_cpu_supported(),
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
-)
+@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
+ " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
+ " with instruction set support avx2 or newer.")
@unittest.skip("temoprarily disable the test")
class QuantizerTestCase(TestCase):
@_tmp_donotuse_dont_inline_everything
diff --git a/test/test_torch.py b/test/test_torch.py
index bba0460..3adf4dd 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2070,8 +2070,8 @@
test_inference(torch.float64)
test_inference(torch.float32)
- def test_qengnie(self):
- qengines = torch.backends.quantized.get_supported_qengines()
+ def test_qengine(self):
+ qengines = torch.backends.quantized.supported_engines
original_qe = torch.backends.quantized.engine
for qe in qengines:
torch.backends.quantized.engine = qe
@@ -5356,19 +5356,19 @@
self.assertEqual(s1.data_ptr() + 4, s2.data_ptr())
def test_load_unicode_error_msg(self):
- # This Pickle contains a Python 2 module with Unicode data and the
+ # This Pickle contains a Python 2 module with Unicode data and the
# loading should fail if the user explicitly specifies ascii encoding!
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
if sys.version_info >= (3, 0):
self.assertRaises(UnicodeDecodeError, lambda: torch.load(path, encoding='ascii'))
else:
# Just checks the module loaded
- self.assertIsNotNone(torch.load(path))
+ self.assertIsNotNone(torch.load(path))
def test_load_python2_unicode_module(self):
# This Pickle contains some Unicode data!
path = download_file('https://download.pytorch.org/test_data/legacy_conv2d.pt')
- self.assertIsNotNone(torch.load(path))
+ self.assertIsNotNone(torch.load(path))
def test_load_error_msg(self):
expected_err_msg = (".*You can only torch.load from a file that is seekable. " +
diff --git a/torch/backends/quantized/__init__.py b/torch/backends/quantized/__init__.py
index d2b8f08..dce924d 100644
--- a/torch/backends/quantized/__init__.py
+++ b/torch/backends/quantized/__init__.py
@@ -4,9 +4,9 @@
import types
# This function should correspond to the enums present in c10/core/QEngine.h
-def get_qengine_id(qengine):
+def _get_qengine_id(qengine):
# type: (str) -> int
- if qengine == 'none':
+ if qengine == 'none' or qengine == '' or qengine is None:
ret = 0
elif qengine == 'fbgemm':
ret = 1
@@ -18,25 +18,25 @@
return ret
# This function should correspond to the enums present in c10/core/QEngine.h
-def get_qengine_str(qengine):
+def _get_qengine_str(qengine):
# type: (int) -> str
all_engines = {0 : 'none', 1 : 'fbgemm', 2 : 'qnnpack'}
return all_engines.get(qengine)
-def get_supported_qengines():
- qengines = torch._C._supported_qengines()
- return [get_qengine_str(qe) for qe in qengines]
-
-class ContextProp(object):
- def __init__(self, getter, setter):
- self.getter = getter
- self.setter = setter
-
+class _QEngineProp(object):
def __get__(self, obj, objtype):
- return get_qengine_str(self.getter())
+ return _get_qengine_str(torch._C._get_qengine())
def __set__(self, obj, val):
- self.setter(get_qengine_id(val))
+ torch._C._set_qengine(_get_qengine_id(val))
+
+class _SupportedQEnginesProp(object):
+ def __get__(self, obj, objtype):
+ qengines = torch._C._supported_qengines()
+ return [_get_qengine_str(qe) for qe in qengines]
+
+ def __set__(self, obj, val):
+ raise RuntimeError("Assignment not supported")
class QuantizedEngine(types.ModuleType):
def __init__(self, m, name):
@@ -45,7 +45,9 @@
def __getattr__(self, attr):
return self.m.__getattribute__(attr)
- engine = ContextProp(torch._C._get_qengine, torch._C._set_qengine)
+
+ engine = _QEngineProp()
+ supported_engines = _SupportedQEnginesProp()
# This is the sys.modules replacement trick, see
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273