blob: 81833258762b10a5c45f59b28563e0d799fa4d7b [file] [log] [blame]
# Owner(s): ["module: onnx"]
import contextlib
import dataclasses
import io
import typing
import unittest
from typing import AbstractSet, Tuple
import torch
from torch.onnx import errors
from torch.onnx._internal import diagnostics
from torch.onnx._internal.diagnostics import infra
from torch.testing._internal import common_utils
def _assert_has_diagnostics(
    engine: infra.DiagnosticEngine,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Fail unless every expected (rule, level) pair appears in ``engine``'s SARIF log.

    A pair matches a SARIF result when ``rule.id`` equals the result's ``rule_id``
    and ``level.name.lower()`` equals the result's ``level``. Runs whose
    ``results`` is ``None`` are skipped.

    Raises:
        AssertionError: Listing the pairs that were not found alongside every
            (rule_id, level) pair that was actually recorded.
    """
    log = engine.sarif_log()

    missing = set()
    for expected_rule, expected_level in rule_level_pairs:
        missing.add((expected_rule.id, expected_level.name.lower()))

    observed = [
        (sarif_result.rule_id, sarif_result.level)
        for sarif_run in log.runs
        if sarif_run.results is not None
        for sarif_result in sarif_run.results
    ]
    missing.difference_update(observed)

    if not missing:
        return
    raise AssertionError(
        f"Expected diagnostic results of rule id and level pair {missing} not found. "
        f"Actual diagnostic results: {observed}"
    )
@contextlib.contextmanager
def assert_all_diagnostics(
    test_suite: unittest.TestCase,
    engine: infra.DiagnosticEngine,
    rule_level_pairs: AbstractSet[Tuple[infra.Rule, infra.Level]],
):
    """Context manager to assert that all diagnostics are emitted.

    Usage:
        with assert_all_diagnostics(
            self,
            diagnostics.engine,
            {(rule, infra.Level.ERROR)},
        ):
            torch.onnx.export(...)

    An ``errors.OnnxExporterError`` raised inside the ``with`` block is swallowed,
    but only if at least one expected pair has level ``infra.Level.ERROR``.
    Otherwise the test fails. Any other exception propagates unchanged, and the
    diagnostics check is skipped so that the original failure is not masked by
    a secondary ``AssertionError``.

    Args:
        test_suite: The test suite instance.
        engine: The diagnostic engine.
        rule_level_pairs: A set of rule and level pairs to assert.

    Returns:
        A context manager.

    Raises:
        AssertionError: If not all diagnostics are emitted, or if an exporter
            error was raised although no ERROR-level diagnostic was expected.
    """
    try:
        yield
    except errors.OnnxExporterError:
        # An exporter error is acceptable only when an ERROR diagnostic is expected.
        test_suite.assertIn(
            infra.Level.ERROR, {level for _, level in rule_level_pairs}
        )
    # Reached only after a normal exit or an accepted exporter error. Running the
    # check from a `finally` would replace any other in-flight exception with a
    # less informative AssertionError.
    _assert_has_diagnostics(engine, rule_level_pairs)
def assert_diagnostic(
    test_suite: unittest.TestCase,
    engine: infra.DiagnosticEngine,
    rule: infra.Rule,
    level: infra.Level,
):
    """Context manager to assert that a single diagnostic is emitted.

    Shorthand for ``assert_all_diagnostics`` with the single pair ``(rule, level)``.

    Usage:
        with assert_diagnostic(
            self,
            diagnostics.engine,
            rule,
            infra.Level.ERROR,
        ):
            torch.onnx.export(...)

    Args:
        test_suite: The test suite instance.
        engine: The diagnostic engine.
        rule: The rule to assert.
        level: The level to assert.

    Returns:
        A context manager.

    Raises:
        AssertionError: When the ``with`` block exits, if the diagnostic was not
            emitted.
    """
    return assert_all_diagnostics(test_suite, engine, {(rule, level)})
class TestOnnxDiagnostics(common_utils.TestCase):
    """Checks the diagnostics reported while running ``torch.onnx.export``."""

    def setUp(self):
        # Begin each test with an empty global engine so diagnostics left over
        # from earlier tests cannot satisfy the assertions made here.
        diagnostics.engine.clear()
        self._sample_rule = diagnostics.rules.missing_custom_symbolic_function
        super().setUp()

    def _trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp(
        self,
    ) -> diagnostics.ExportDiagnostic:
        """Export a model using a custom op and return the resulting warning.

        Searches the most recent diagnostic context for a diagnostic with rule
        ``node_missing_onnx_shape_inference`` and level WARNING.

        Raises:
            AssertionError: If no such diagnostic was recorded.
        """

        class CustomAdd(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y

            @staticmethod
            def symbolic(g, x, y):
                return g.op("custom::CustomAdd", x, y)

        class M(torch.nn.Module):
            def forward(self, x):
                return CustomAdd.apply(x, x)

        # The custom op is intended to trigger the missing-shape-inference warning.
        wanted_rule = diagnostics.rules.node_missing_onnx_shape_inference
        torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO())
        latest_context = diagnostics.engine.contexts[-1]
        match = next(
            (
                candidate
                for candidate in latest_context.diagnostics
                if candidate.rule == wanted_rule
                and candidate.level == diagnostics.levels.WARNING
            ),
            None,
        )
        if match is None:
            raise AssertionError("No diagnostic found.")
        return typing.cast(diagnostics.ExportDiagnostic, match)

    def test_assert_diagnostic_raises_when_diagnostic_not_found(self):
        # The block is empty, so the expected warning can never be reported.
        with self.assertRaises(AssertionError), assert_diagnostic(
            self,
            diagnostics.engine,
            diagnostics.rules.node_missing_onnx_shape_inference,
            diagnostics.levels.WARNING,
        ):
            pass

    def test_cpp_diagnose_emits_warning(self):
        expectation = assert_diagnostic(
            self,
            diagnostics.engine,
            rule=diagnostics.rules.node_missing_onnx_shape_inference,
            level=diagnostics.levels.WARNING,
        )
        with expectation:
            # Exporting the custom op reports the missing shape inference warning.
            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()

    def test_py_diagnose_emits_error(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.diagonal(x)

        with assert_diagnostic(
            self,
            diagnostics.engine,
            rule=diagnostics.rules.operator_supported_in_newer_opset_version,
            level=diagnostics.levels.ERROR,
        ):
            # Exporting at opset 9 is expected to report that the operator needs
            # a newer opset version.
            torch.onnx.export(M(), torch.randn(3, 4), io.BytesIO(), opset_version=9)

    def test_diagnostics_engine_records_diagnosis_reported_outside_of_export(
        self,
    ):
        rule = self._sample_rule
        level = diagnostics.levels.ERROR
        with assert_diagnostic(self, diagnostics.engine, rule, level):
            diagnostics.context.diagnose(rule, level)

    def test_diagnostics_records_python_call_stack(self):
        diagnostic = diagnostics.ExportDiagnostic(self._sample_rule, diagnostics.levels.NOTE)  # fmt: skip
        # Do not break the above line, otherwise it will not work with Python-3.8+
        call_stack = diagnostic.python_call_stack
        assert call_stack is not None  # for mypy
        self.assertGreater(len(call_stack.frames), 0)
        # The first recorded frame is expected to point at the construction line above.
        location = call_stack.frames[0].location
        assert location.snippet is not None  # for mypy
        self.assertIn("self._sample_rule", location.snippet)
        assert location.uri is not None  # for mypy
        self.assertIn("test_diagnostics.py", location.uri)

    def test_diagnostics_records_cpp_call_stack(self):
        diagnostic = (
            self._trigger_node_missing_onnx_shape_inference_warning_diagnostic_from_cpp()
        )
        cpp_stack = diagnostic.cpp_call_stack
        assert cpp_stack is not None  # for mypy
        self.assertGreater(len(cpp_stack.frames), 0)
        messages = [cpp_frame.location.message for cpp_frame in cpp_stack.frames]
        # node missing onnx shape inference warning only comes from ToONNX (_jit_pass_onnx)
        # after node-level shape type inference and processed symbolic_fn output type
        mentions_node_to_onnx = any(
            isinstance(text, str) and "torch::jit::NodeToONNX" in text
            for text in messages
        )
        self.assertTrue(mentions_node_to_onnx)
@dataclasses.dataclass
class _RuleCollectionForTest(infra.RuleCollection):
    """Rule collection holding the single rule used by ``TestDiagnosticsInfra``."""

    # Rule "1"; its default message template contains no placeholders.
    rule_without_message_args: infra.Rule = infra.Rule(
        "1",
        "rule-without-message-args",
        message_default_template="rule message",
    )
class TestDiagnosticsInfra(common_utils.TestCase):
    """Exercises the generic diagnostics infra with a private engine."""

    def setUp(self):
        self.engine = infra.DiagnosticEngine()
        self.rules = _RuleCollectionForTest()
        # Enter an outer "test" context and hand its exit callback to addCleanup,
        # so the context stays open for the whole test and is closed afterwards.
        with contextlib.ExitStack() as exit_stack:
            outer_context_manager = self.engine.create_diagnostic_context(
                "test", "1.0.0"
            )
            self.context = exit_stack.enter_context(outer_context_manager)
            self.addCleanup(exit_stack.pop_all().close)
        return super().setUp()

    def test_diagnostics_engine_records_diagnosis_reported_in_nested_contexts(
        self,
    ):
        def results_per_run():
            # Number of SARIF results in each run, in the order the engine reports them.
            return [len(run.results) for run in self.engine.sarif_log().runs]

        with self.engine.create_diagnostic_context("inner_test", "1.0.1") as inner:
            inner.diagnose(self.rules.rule_without_message_args, infra.Level.WARNING)
            # The outer run has nothing yet; the inner run holds the warning.
            self.assertEqual(results_per_run(), [0, 1])
        self.context.diagnose(self.rules.rule_without_message_args, infra.Level.ERROR)
        # Both runs remain in the log; the outer one now holds the error.
        self.assertEqual(results_per_run(), [1, 1])

    def test_diagnostics_engine_records_diagnosis_with_custom_rules(self):
        first = infra.Rule(
            "1",
            "custom-rule",
            message_default_template="custom rule message",
        )
        second = infra.Rule(
            "2",
            "custom-rule-2",
            message_default_template="custom rule message 2",
        )
        custom_rules = infra.RuleCollection.custom_collection_from_list(
            "CustomRuleCollection", [first, second]
        )
        warning_rule = custom_rules.custom_rule  # type: ignore[attr-defined]
        error_rule = custom_rules.custom_rule_2  # type: ignore[attr-defined]
        expected_pairs = {
            (warning_rule, infra.Level.WARNING),
            (error_rule, infra.Level.ERROR),
        }
        with self.engine.create_diagnostic_context(
            "custom_rules", "1.0"
        ) as diagnostic_context, assert_all_diagnostics(
            self, self.engine, expected_pairs
        ):
            diagnostic_context.diagnose(warning_rule, infra.Level.WARNING)
            diagnostic_context.diagnose(error_rule, infra.Level.ERROR)
if __name__ == "__main__":
    # When executed as a script, run this module's tests through
    # torch.testing._internal.common_utils.
    common_utils.run_tests()