blob: 13c0e0f396d4cf2fcb6ec6772e00553f48395ccc [file] [log] [blame]
# Owner(s): ["module: onnx"]
"""Unit tests for the internal registration wrapper module."""
from typing import Optional, Sequence

from torch.onnx import errors
from torch.onnx._internal import registration
from torch.testing._internal import common_utils
@common_utils.instantiate_parametrized_tests
class TestGlobalHelpers(common_utils.TestCase):
@common_utils.parametrize(
"available_opsets, target, expected",
[
((7, 8, 9, 10, 11), 16, 11),
((7, 8, 9, 10, 11), 11, 11),
((7, 8, 9, 10, 11), 10, 10),
((7, 8, 9, 10, 11), 9, 9),
((7, 8, 9, 10, 11), 8, 8),
((7, 8, 9, 10, 11), 7, 7),
((9, 10, 16), 16, 16),
((9, 10, 16), 15, 10),
((9, 10, 16), 10, 10),
((9, 10, 16), 9, 9),
((9, 10, 16), 8, 9),
((9, 10, 16), 7, 9),
((7, 9, 10, 16), 16, 16),
((7, 9, 10, 16), 10, 10),
((7, 9, 10, 16), 9, 9),
((7, 9, 10, 16), 8, 9),
((7, 9, 10, 16), 7, 7),
([17], 16, None), # New op added in 17
([9], 9, 9),
([9], 8, 9),
([], 16, None),
([], 9, None),
([], 8, None),
# Ops registered at opset 1 found as a fallback when target >= 9
([1], 16, 1),
],
)
def test_dispatch_opset_version_returns_correct_version(
self, available_opsets: Sequence[int], target: int, expected: int
):
actual = registration._dispatch_opset_version(target, available_opsets)
self.assertEqual(actual, expected)
class TestOverrideDict(common_utils.TestCase):
def setUp(self):
self.override_dict: registration.OverrideDict[
str, int
] = registration.OverrideDict()
def test_get_item_returns_base_value_when_no_override(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.assertEqual(self.override_dict["a"], 42)
self.assertEqual(self.override_dict["b"], 0)
self.assertEqual(len(self.override_dict), 2)
def test_get_item_returns_overridden_value_when_override(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(self.override_dict["b"], 0)
self.assertEqual(self.override_dict["c"], 1)
self.assertEqual(len(self.override_dict), 3)
def test_get_item_raises_key_error_when_not_found(self):
self.override_dict.set_base("a", 42)
with self.assertRaises(KeyError):
self.override_dict["nonexistent_key"]
def test_get_returns_overridden_value_when_override(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertEqual(self.override_dict.get("a"), 100)
self.assertEqual(self.override_dict.get("b"), 0)
self.assertEqual(self.override_dict.get("c"), 1)
self.assertEqual(len(self.override_dict), 3)
def test_get_returns_none_when_not_found(self):
self.override_dict.set_base("a", 42)
self.assertEqual(self.override_dict.get("nonexistent_key"), None)
def test_in_base_returns_true_for_base_value(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertIn("a", self.override_dict)
self.assertIn("b", self.override_dict)
self.assertIn("c", self.override_dict)
self.assertTrue(self.override_dict.in_base("a"))
self.assertTrue(self.override_dict.in_base("b"))
self.assertFalse(self.override_dict.in_base("c"))
self.assertFalse(self.override_dict.in_base("nonexistent_key"))
def test_overridden_returns_true_for_overridden_value(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertTrue(self.override_dict.overridden("a"))
self.assertFalse(self.override_dict.overridden("b"))
self.assertTrue(self.override_dict.overridden("c"))
self.assertFalse(self.override_dict.overridden("nonexistent_key"))
def test_remove_override_removes_overridden_value(self):
self.override_dict.set_base("a", 42)
self.override_dict.set_base("b", 0)
self.override_dict.override("a", 100)
self.override_dict.override("c", 1)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(self.override_dict["c"], 1)
self.override_dict.remove_override("a")
self.override_dict.remove_override("c")
self.assertEqual(self.override_dict["a"], 42)
self.assertEqual(self.override_dict.get("c"), None)
self.assertFalse(self.override_dict.overridden("a"))
self.assertFalse(self.override_dict.overridden("c"))
def test_remove_override_removes_overridden_key(self):
self.override_dict.override("a", 100)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(len(self.override_dict), 1)
self.override_dict.remove_override("a")
self.assertEqual(len(self.override_dict), 0)
self.assertNotIn("a", self.override_dict)
def test_overriden_key_precededs_base_key_regardless_of_insert_order(self):
self.override_dict.set_base("a", 42)
self.override_dict.override("a", 100)
self.override_dict.set_base("a", 0)
self.assertEqual(self.override_dict["a"], 100)
self.assertEqual(len(self.override_dict), 1)
def test_bool_is_true_when_not_empty(self):
if self.override_dict:
self.fail("OverrideDict should be false when empty")
self.override_dict.override("a", 1)
if not self.override_dict:
self.fail("OverrideDict should be true when not empty")
self.override_dict.set_base("a", 42)
if not self.override_dict:
self.fail("OverrideDict should be true when not empty")
self.override_dict.remove_override("a")
if not self.override_dict:
self.fail("OverrideDict should be true when not empty")
class TestRegistrationDecorators(common_utils.TestCase):
def tearDown(self) -> None:
registration.registry._registry.pop("test::test_op", None)
def test_onnx_symbolic_registers_function(self):
self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))
@registration.onnx_symbolic("test::test_op", opset=9)
def test(g, x):
return g.op("test", x)
self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))
function_group = registration.registry.get_function_group("test::test_op")
assert function_group is not None
self.assertEqual(function_group.get(9), test)
def test_onnx_symbolic_registers_function_applied_decorator_when_provided(self):
wrapper_called = False
def decorator(func):
def wrapper(*args, **kwargs):
nonlocal wrapper_called
wrapper_called = True
return func(*args, **kwargs)
return wrapper
@registration.onnx_symbolic("test::test_op", opset=9, decorate=[decorator])
def test():
return
function_group = registration.registry.get_function_group("test::test_op")
assert function_group is not None
registered_function = function_group[9]
self.assertFalse(wrapper_called)
registered_function()
self.assertTrue(wrapper_called)
def test_onnx_symbolic_raises_warning_when_overriding_function(self):
self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))
@registration.onnx_symbolic("test::test_op", opset=9)
def test1():
return
with self.assertWarnsRegex(
errors.OnnxExporterWarning,
"Symbolic function 'test::test_op' already registered",
):
@registration.onnx_symbolic("test::test_op", opset=9)
def test2():
return
def test_custom_onnx_symbolic_registers_custom_function(self):
self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))
@registration.custom_onnx_symbolic("test::test_op", opset=9)
def test(g, x):
return g.op("test", x)
self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))
function_group = registration.registry.get_function_group("test::test_op")
assert function_group is not None
self.assertEqual(function_group.get(9), test)
def test_custom_onnx_symbolic_overrides_existing_function(self):
self.assertFalse(registration.registry.is_registered_op("test::test_op", 9))
@registration.onnx_symbolic("test::test_op", opset=9)
def test_original():
return "original"
self.assertTrue(registration.registry.is_registered_op("test::test_op", 9))
@registration.custom_onnx_symbolic("test::test_op", opset=9)
def test_custom():
return "custom"
function_group = registration.registry.get_function_group("test::test_op")
assert function_group is not None
self.assertEqual(function_group.get(9), test_custom)
# When this file is executed directly as a script, hand off to the test
# entry point provided by torch.testing._internal.common_utils.
if __name__ == "__main__":
    common_utils.run_tests()