[ONNX] move CheckerError from torch.onnx.utils to torch.onnx (#66644)
Summary:
This moves it to where the user would expect it to be based on the
documentation and all the other public classes in the torch.onnx module.
Also rename it from ONNXCheckerError, since the qualified name
torch.onnx.ONNXCheckerError is otherwise redundant.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66644
Reviewed By: malfet
Differential Revision: D31662559
Pulled By: msaroufim
fbshipit-source-id: bc8a57b99c2980490ede3974279d1124228a7406
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 6a114b2..ba4319b 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -36,9 +36,8 @@
from collections import OrderedDict
from torch.nn.utils.rnn import PackedSequence
-from torch.onnx import register_custom_op_symbolic, unregister_custom_op_symbolic
+from torch.onnx import CheckerError, register_custom_op_symbolic, unregister_custom_op_symbolic
from torch.onnx.symbolic_helper import _unimplemented
-from torch.onnx.utils import ONNXCheckerError
def flatten_tuples(elem):
@@ -9965,7 +9964,7 @@
f = io.BytesIO()
try:
- with self.assertRaises(ONNXCheckerError) as cm:
+ with self.assertRaises(CheckerError) as cm:
torch.onnx.export(test_model, (x, y), f)
finally:
unregister_custom_op_symbolic("::add", 1)
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index c7e7487..8faab8d 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -23,6 +23,10 @@
DIRECTORY = 4
+class CheckerError(Exception):
+ pass
+
+
def _export(*args, **kwargs):
from torch.onnx import utils
result = utils._export(*args, **kwargs)
@@ -309,7 +313,7 @@
This argument is ignored unless ``operator_export_type=OperatorExportTypes.ONNX``.
Raises:
- ONNXCheckerError: If the ONNX checker detects an invalid ONNX graph. Will still export the
+ CheckerError: If the ONNX checker detects an invalid ONNX graph. Will still export the
model to the file ``f`` even if this is raised.
"""
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index d6359ae..c4ce274 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -17,7 +17,7 @@
import warnings
from torch._six import string_classes
from torch.jit import _unique_state_dict
-from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode
+from torch.onnx import ONNX_ARCHIVE_MODEL_PROTO_NAME, ExportTypes, OperatorExportTypes, TrainingMode, CheckerError
from torch._C import ListType, OptionalType, _propagate_and_assign_input_shapes, _check_onnx_proto
from typing import List, Tuple, Union
@@ -26,9 +26,6 @@
__IN_ONNX_EXPORT = False
-class ONNXCheckerError(Exception):
- pass
-
def is_in_onnx_export():
global __IN_ONNX_EXPORT
return __IN_ONNX_EXPORT
@@ -89,7 +86,7 @@
if enable_onnx_checker is not None:
warnings.warn("'enable_onnx_checker' is deprecated and ignored. It will be removed in "
"the next PyTorch release. To proceed despite ONNX checker failures, "
- "catch torch.onnx.ONNXCheckerError.")
+ "catch torch.onnx.CheckerError.")
if _retain_param_name is not None:
warnings.warn("'_retain_param_name' is deprecated and ignored. "
"It will be removed in the next PyTorch release.")
@@ -779,7 +776,7 @@
try:
_check_onnx_proto(proto)
except RuntimeError as e:
- raise ONNXCheckerError(e)
+ raise CheckerError(e)
finally:
assert __IN_ONNX_EXPORT
__IN_ONNX_EXPORT = False