[quant][fx][graphmode] Re-enable torchvision test (#48602)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48602
Test Plan: Imported from OSS
Reviewed By: vkuzo
Differential Revision: D25224917
fbshipit-source-id: efc73f425253c4eb7ae51064b6760416097f0437
diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index 27064c4..7e4048b 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -2103,15 +2103,14 @@
original_out = model(input_value)
is_not_tuple_out = not isinstance(original_out, tuple)
script_out = script(input_value)
- self.assertEqual(
- (original_out - script_out).abs().max(), 0,
- 'Result of original graph module and script module does not match')
# set to train just before quantization
+ prepare_fx_fn = prepare_fx
if mode != 'static':
model.train()
+ prepare_fx_fn = prepare_qat_fx
- prepared = prepare_fx(model, qconfig_dict)
+ prepared = prepare_fx_fn(model, qconfig_dict)
if mode == 'ddp':
mp.spawn(run_ddp,
@@ -2207,15 +2206,11 @@
quantized_model_list = set(quantized_model_list) - no_pretrained_model
# test eager and graph consistency
model_list = quantized_model_list
- # slice need to be fixed in symbolic tracing(https://github.com/pytorch/pytorch/issues/43511)
- model_list = set(model_list) - {'googlenet', 'inception_v3'}
- # getattr should not be used as node name(https://github.com/pytorch/pytorch/issues/43522)
- model_list -= {'shufflenet_v2_x1_0', 'mobilenet_v2'}
-
+ # inception_v3 is not symbolically traceable: https://github.com/pytorch/pytorch/issues/48813
+ model_list = set(model_list) - {'inception_v3'}
# mobilenet: dropout error RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'QUInt8'
# inception_v3: looks like there is some problem with AuxLogits
- quantized_not_working = [('qat', 'mobilenet_v2'),
- ('qat', 'inception_v3'),
+ quantized_not_working = [('qat', 'inception_v3'),
('static', 'inception_v3')]
fx_eager_not_matching = ['googlenet', # because _transform_input is not quantized in eager
@@ -2257,7 +2252,6 @@
@skip_if_no_torchvision
@skip_if_not_multigpu
@skipIfNoFBGEMM
- @unittest.skip('TODO: not working yet due to https://github.com/pytorch/pytorch/issues/43513')
def test_resnet18_ddp(self):
from torchvision import models
from torchvision.models import quantization as quantized_models
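
The test change above replaces the unconditional prepare_fx call with a dispatch: QAT modes put the model in train mode and go through prepare_qat_fx, while static quantization keeps prepare_fx on an eval-mode model. A minimal sketch of that pattern, assuming a toy module and default FBGEMM qconfigs rather than the test's torchvision fixtures:

```
import torch.nn as nn
from torch.quantization import get_default_qconfig, get_default_qat_qconfig
from torch.quantization.quantize_fx import prepare_fx, prepare_qat_fx

def prepare_for_mode(model: nn.Module, mode: str):
    # Mirrors the test's dispatch: prepare_fx requires eval mode,
    # prepare_qat_fx requires train mode (see the asserts in quantize_fx.py).
    if mode == 'static':
        model.eval()
        return prepare_fx(model, {'': get_default_qconfig('fbgemm')})
    model.train()
    return prepare_qat_fx(model, {'': get_default_qat_qconfig('fbgemm')})

# Illustrative usage with a stand-in model, not the torchvision ones the test runs.
toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
prepared = prepare_for_mode(toy, 'qat')
```
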
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index ba1f58a..77f598e 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -254,7 +254,7 @@
```
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
- assert not model.training, 'prepare_fx only works for models in' + \
+ assert not model.training, 'prepare_fx only works for models in ' + \
'eval mode'
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
@@ -291,7 +291,7 @@
```
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_qat_fx")
- assert model.training, 'prepare_qat_fx only works for models in' + \
+ assert model.training, 'prepare_qat_fx only works for models in ' + \
'train mode'
return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
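
The quantize_fx.py hunks only fix missing spaces in the assertion messages. A short sketch of the user-visible behavior those asserts enforce, assuming a toy module and the same qconfig_dict shape as above:

```
import torch.nn as nn
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

model = nn.Sequential(nn.Conv2d(3, 8, 3))
qconfig_dict = {'': get_default_qconfig('fbgemm')}

model.train()
try:
    prepare_fx(model, qconfig_dict)  # fails: model must be in eval mode
except AssertionError as e:
    # With the fix, the message reads '...models in eval mode'
    # instead of the concatenation bug '...models ineval mode'.
    print(e)

model.eval()
prepared = prepare_fx(model, qconfig_dict)  # succeeds
```
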
diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py
index 2a1e2b6..e44d5df 100644
--- a/torch/testing/_internal/common_distributed.py
+++ b/torch/testing/_internal/common_distributed.py
@@ -63,15 +63,17 @@
def skip_if_not_multigpu(func):
"""Multi-GPU tests requires at least 2 GPUS. Skip if this is not met."""
- @wraps(func)
- def wrapper(*args, **kwargs):
- if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
- return func(*args, **kwargs)
- message = "Need at least {} CUDA devices".format(2)
- TEST_SKIPS["multi-gpu"] = TestSkip(75, message)
- sys.exit(TEST_SKIPS['multi-gpu'].exit_code)
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
+ return func(*args, **kwargs)
+ message = "Need at least {} CUDA devices".format(2)
+ TEST_SKIPS["multi-gpu"] = TestSkip(75, message)
+ sys.exit(TEST_SKIPS['multi-gpu'].exit_code)
+ return wrapper
- return wrapper
+ return decorator
def require_n_gpus_for_nccl_backend(n, backend):
def decorator(func):
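
The common_distributed.py hunk restructures skip_if_not_multigpu from a plain decorator into a decorator factory, the same shape as require_n_gpus_for_nccl_backend below it. A generic sketch of the two shapes, with illustrative names (neither is PyTorch's actual API):

```
import sys
from functools import wraps
import torch

# Plain decorator: applied bare, as @skip_without_two_gpus.
def skip_without_two_gpus(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
            return func(*args, **kwargs)
        sys.exit(75)  # skip exit code, as in TestSkip above
    return wrapper

# Decorator factory: one more level of nesting, applied with parentheses,
# as @skip_without_gpus(2); the outer call returns the decorator that
# actually wraps the test.
def skip_without_gpus(n):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if torch.cuda.is_available() and torch.cuda.device_count() >= n:
                return func(*args, **kwargs)
            sys.exit(75)
        return wrapper
    return decorator
```

Note the call-site difference: a factory must be applied with parentheses; applied bare, it would replace the test with the inner decorator and the test body would never run.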