blob: e94c031d1bde8987db0ed85144bc8b9b0accf61d [file] [log] [blame]
# Owner(s): ["module: onnx"]
"""Unit tests for the internal registration wrapper module."""
from __future__ import annotations
import operator
from typing import TypeVar, Union
import onnxscript # type: ignore[import]
from onnxscript import BFLOAT16, DOUBLE, FLOAT, FLOAT16 # type: ignore[import]
from onnxscript.function_libs.torch_lib import ops # type: ignore[import]
from onnxscript.onnx_opset import opset15 as op # type: ignore[import]
import torch
import torch.fx
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.fx import (
analysis,
diagnostics,
onnxfunction_dispatcher,
registration,
)
from torch.testing._internal import common_utils
# Type variable bounded by the four onnxscript float tensor types; it annotates
# the inputs and outputs of the onnxscript functions defined inside the
# TestDispatcher tests.
# TODO: this can only be global. https://github.com/microsoft/onnxscript/issues/805
TCustomFloat = TypeVar("TCustomFloat", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
class TestRegistration(common_utils.TestCase):
    """Tests that ``torch.onnx.OnnxRegistry`` stores and returns registered functions."""

    def setUp(self) -> None:
        """Create a fresh registry and a custom onnxscript opset for every test."""
        self.registry = torch.onnx.OnnxRegistry()
        self.custom_domain = onnxscript.values.Opset(domain="custom", version=1)

    def tearDown(self) -> None:
        """Drop the key built from namespace ``test`` / op name ``test_op``."""
        internal_name_instance = registration.OpName.from_name_parts(
            namespace="test", op_name="test_op"
        )
        # A default is supplied so tests that never created this key do not
        # fail during cleanup.
        self.registry._registry.pop(internal_name_instance, None)

    def test_register_custom_op_registers_custom_function(self):
        """``register_op`` makes the function visible to both lookup APIs."""
        self.assertFalse(self.registry.is_registered_op("test", "test_op", "default"))

        @onnxscript.script(self.custom_domain)
        def custom_add(x, y):
            return op.Add(x, y)

        self.registry.register_op(custom_add, "test", "test_op", "default")
        self.assertTrue(self.registry.is_registered_op("test", "test_op", "default"))

        # Test on get_ops
        function_group = self.registry.get_op_functions("test", "test_op", "default")
        self.assertIsNotNone(function_group)
        self.assertEqual({func.onnx_function for func in function_group}, {custom_add})  # type: ignore[arg-type]

    def test_custom_onnx_symbolic_joins_existing_function(self):
        """A later ``register_op`` call appends to an already registered op."""
        self.assertFalse(self.registry.is_registered_op("test", "test_op"))

        @onnxscript.script(self.custom_domain)
        def test_original(x, y):
            return op.Add(x, y)

        # default has to be specified, as we are not using the registration.OpName
        internal_name_instance = registration.OpName.from_name_parts(
            namespace="test", op_name="test_op", overload="default"
        )
        symbolic_fn = registration.ONNXFunction(
            test_original, op_full_name=internal_name_instance.qualified_name()
        )
        self.registry._register(internal_name_instance, symbolic_fn)
        self.assertTrue(self.registry.is_registered_op("test", "test_op"))

        @onnxscript.script(self.custom_domain)
        def test_custom(x, y):
            return op.Add(x, y)

        self.registry.register_op(test_custom, "test", "test_op")

        function_group = self.registry.get_op_functions("test", "test_op")
        # Plain assert kept on purpose: it narrows the Optional for type checkers.
        assert function_group is not None
        # The order does matter (list)
        self.assertEqual(
            [func.onnx_function for func in function_group],
            [test_original, test_custom],
        )

    def test_unsupported_nodes_analysis_with_missing_aten_op(self):
        """The analysis raises and names every FX node whose op is unregistered."""
        # NOTE: simulate unsupported nodes
        removed_op_names = [
            registration.OpName.from_name_parts(
                namespace="aten", op_name="mul", overload="Tensor"
            ),
            registration.OpName.from_name_parts(namespace="aten", op_name="mul"),
            registration.OpName.from_name_parts(
                namespace="aten", op_name="add", overload="Tensor"
            ),
            registration.OpName.from_name_parts(namespace="aten", op_name="add"),
        ]
        for removed_op_name in removed_op_names:
            # No default on purpose: a missing key raises KeyError, which shows
            # that the registry contents changed underneath this test.
            self.registry._registry.pop(removed_op_name)

        diagnostic_context = diagnostics.DiagnosticContext(
            "torch.onnx.dynamo_export", torch.__version__
        )
        dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
            self.registry, diagnostic_context
        )

        # Build the graph: output = add(mul(x, x), mul(x, x)).
        graph: torch.fx.Graph = torch.fx.Graph()
        x: torch.fx.Node = graph.create_node("placeholder", "x")
        x.meta["val"] = torch.tensor(3.0)
        mul_node = graph.create_node(
            "call_function", target=torch.ops.aten.mul.Tensor, args=(x, x)
        )
        add_node = graph.create_node(
            "call_function", target=torch.ops.aten.add.Tensor, args=(mul_node, mul_node)
        )
        graph.output(add_node)
        module = torch.fx.GraphModule(torch.nn.Module(), graph)

        # Run the analysis once and inspect the captured exception. A separate
        # try/except would pass silently if nothing were raised.
        with self.assertRaises(infra.RuntimeErrorWithDiagnostic) as raised:
            analysis.UnsupportedFxNodesAnalysis(
                diagnostic_context, module, dispatcher
            ).analyze(infra.levels.ERROR)
        self.assertIn(
            "Unsupported FX nodes: {'call_function': ['aten.mul.Tensor', 'aten.add.Tensor']}.",
            raised.exception.diagnostic.message,
        )
def _call_function_node(
    name: str,
    target: torch.fx.node.Target,
    args: tuple = (),
    kwargs: dict | None = None,
) -> torch.fx.Node:
    """Build a ``call_function`` ``torch.fx.Node`` owned by a brand-new graph.

    Args:
        name: Name given to the node.
        target: Callable the node invokes.
        args: Positional arguments recorded on the node.
        kwargs: Keyword arguments recorded on the node; ``None`` means ``{}``.

    Returns:
        The freshly constructed node.
    """
    return torch.fx.Node(
        graph=torch.fx.Graph(),
        name=name,
        op="call_function",
        target=target,
        args=args,
        kwargs={} if kwargs is None else kwargs,
    )


@common_utils.instantiate_parametrized_tests
class TestDispatcher(common_utils.TestCase):
    """Tests for ``OnnxFunctionDispatcher`` name resolution and overload selection."""

    def setUp(self):
        """Create the registry, custom opset, diagnostic context and dispatcher."""
        self.registry = torch.onnx.OnnxRegistry()
        self.custom_domain = onnxscript.values.Opset(domain="custom", version=1)
        self.diagnostic_context = diagnostics.DiagnosticContext(
            "torch.onnx.dynamo_export", torch.__version__
        )
        self.dispatcher = onnxfunction_dispatcher.OnnxFunctionDispatcher(
            self.registry, self.diagnostic_context
        )

    @common_utils.parametrize(
        "node, expected_name",
        [
            common_utils.subtest(
                (
                    _call_function_node(
                        "aten::add.Tensor",
                        torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                        args=(torch.tensor(3), torch.tensor(4)),
                    ),
                    ("aten", "add", "Tensor"),
                ),
                name="get_Opoverload_name",
            ),
            common_utils.subtest(
                (
                    _call_function_node("aten::sym_size", torch.ops.aten.sym_size),
                    ("aten", "sym_size", None),
                ),
                name="get_Opoverloadpacket_name",
            ),
            common_utils.subtest(
                (
                    _call_function_node("builtin_add", operator.add, args=(1, 2)),
                    ("_operator", "add", None),
                ),
                name="get_builtin_op_name",
            ),
        ],
    )
    def test_get_aten_name_on_supported_fx_node(
        self, node: torch.fx.Node, expected_name: tuple[str, str, str | None]
    ):
        """``_get_aten_name`` returns the ``OpName`` built from the expected parts."""
        expected_name_class = registration.OpName.from_name_parts(*expected_name)
        self.assertEqual(
            self.dispatcher._get_aten_name(node, self.diagnostic_context),
            expected_name_class,
        )

    @common_utils.parametrize(
        "node",
        [
            common_utils.subtest(
                _call_function_node("aten::add", torch.ops.aten.add),
                name="unsupported_Opoverloadpacket_name",
            ),
            common_utils.subtest(
                _call_function_node("builtin_add", operator.add, args=("A", "B")),
                name="unsupported_input_dtypes_for_builtin_op",
            ),
            common_utils.subtest(
                _call_function_node("aten::made_up_node", lambda: None),
                name="unsupported_target_function",
            ),
        ],
    )
    def test_get_aten_name_on_unsupported_fx_node(self, node: torch.fx.Node):
        """``_get_aten_name`` raises ``RuntimeError`` for each unsupported node."""
        with self.assertRaises(RuntimeError):
            self.dispatcher._get_aten_name(node, self.diagnostic_context)

    def test_get_function_overloads_gives_overload_fall_back_default(self):
        """Overload lookup agrees for both add nodes and raises for an unknown op."""
        # Test fall back to default op name
        # NOTE: both nodes target ``aten.add.Tensor``; they differ only in
        # their name and args.
        node_overload = _call_function_node(
            "aten::add.Tensor",
            torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
            args=(torch.tensor(3), torch.tensor(4)),
        )
        node_overloadpacket = _call_function_node(
            "aten::add",
            torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
        )

        self.assertEqual(
            self.dispatcher.get_function_overloads(
                node_overload, self.diagnostic_context
            ),
            self.dispatcher.get_function_overloads(
                node_overloadpacket,
                self.diagnostic_context,
            ),
        )

        # Non-registered op
        unsupported_op_node = _call_function_node("aten::made_up_node", lambda: None)
        with self.assertRaises(RuntimeError):
            self.dispatcher.get_function_overloads(
                unsupported_op_node,
                self.diagnostic_context,
            )

    @common_utils.parametrize(
        "node",
        [
            common_utils.subtest(
                _call_function_node(
                    "aten::add.Tensor",
                    torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                ),
                name="nearest_match",
            ),
            common_utils.subtest(
                _call_function_node(
                    "aten::add.Tensor",
                    torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"alpha": 1},
                ),
                name="perfect_match_with_kwargs",
            ),
        ],
    )
    def test_find_the_perfect_or_nearest_match_onnxfunction_gives_custom_ops_precedence(
        self, node
    ):
        """The overload flagged ``is_custom=True`` wins over the default one."""

        @onnxscript.script(self.custom_domain)
        def test_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(self.custom_domain)
        def test_default_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        op_full_name = "test::test_op"

        custom_overloads = [
            registration.ONNXFunction(
                test_custom_op, op_full_name=op_full_name, is_custom=True
            )
        ]
        function_overloads = [
            registration.ONNXFunction(test_default_op, op_full_name=op_full_name)
        ] + custom_overloads

        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            function_overloads,
            node.args,
            node.kwargs,
            self.diagnostic_context,
        )
        self.assertEqual(symbolic_fn, test_custom_op)

    @common_utils.parametrize(
        "node",
        [
            common_utils.subtest(
                _call_function_node(
                    "aten::add.Tensor",
                    torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"attr": None},
                ),
                name="perfect_match_with_ignoring_none_attribute",
            ),
            common_utils.subtest(
                _call_function_node(
                    "aten::add.Tensor",
                    torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"unrelated": None},
                ),
                name="perfect_match_with_ignoring_unrelated_none_attribute",
            ),
        ],
    )
    def test_find_the_perfect_or_nearest_match_onnxfunction_ignores_attribute_with_none(
        self, node
    ):
        """With a ``None``-valued kwarg, the overload without ``attr`` is selected."""

        @onnxscript.script(self.custom_domain)
        def test_op_attribute(
            x: TCustomFloat, y: TCustomFloat, attr: int
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(self.custom_domain)
        def test_op(x: TCustomFloat, y: TCustomFloat) -> TCustomFloat:
            return op.Add(x, y)

        op_full_name = "test::test_op"

        function_overloads = [
            registration.ONNXFunction(test_op_attribute, op_full_name=op_full_name),
            registration.ONNXFunction(test_op, op_full_name=op_full_name),
        ]

        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            function_overloads,
            node.args,
            node.kwargs,
            self.diagnostic_context,
        )
        self.assertEqual(symbolic_fn, test_op)

    @common_utils.parametrize(
        "node",
        [
            common_utils.subtest(
                _call_function_node(
                    "aten::add.Tensor",
                    torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                ),
                name="nearest_match",
            ),
            common_utils.subtest(
                _call_function_node(
                    "aten::add.Tensor",
                    torch.ops.aten.add.Tensor,  # type: ignore[attr-defined]
                    args=(torch.tensor(3.0), torch.tensor(4.0)),
                    kwargs={"alpha": 1},
                ),
                name="perfect_match_with_kwargs",
            ),
        ],
    )
    def test_find_the_perfect_or_nearest_match_onnxfunction_gives_tie_breaks_to_registered_order(
        self, node
    ):
        """Among identically-typed custom overloads, the last one listed is chosen."""

        # The definition order (second, third, first) deliberately differs from
        # the order in ``function_overloads`` below.
        @onnxscript.script(self.custom_domain)
        def test_second_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(self.custom_domain)
        def test_third_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        @onnxscript.script(self.custom_domain)
        def test_first_custom_op(
            x: TCustomFloat, y: TCustomFloat, alpha: int = 1
        ) -> TCustomFloat:
            return op.Add(x, y)

        op_full_name = "aten::add"

        function_overloads = [
            registration.ONNXFunction(
                test_first_custom_op, op_full_name=op_full_name, is_custom=True
            ),
            registration.ONNXFunction(
                test_second_custom_op, op_full_name=op_full_name, is_custom=True
            ),
            registration.ONNXFunction(
                test_third_custom_op, op_full_name=op_full_name, is_custom=True
            ),
        ]

        symbolic_fn = self.dispatcher._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            function_overloads,
            node.args,
            node.kwargs,
            self.diagnostic_context,
        )
        self.assertEqual(symbolic_fn, test_third_custom_op)
@common_utils.instantiate_parametrized_tests
class TestOpSchemaWrapper(common_utils.TestCase):
    """Tests for ``_OnnxSchemaChecker`` matching and ``_find_onnx_data_type``."""

    def setUp(self):
        # overload type: optional dtype
        self.onnx_function_new_full = ops.core.aten_new_full
        self.onnx_function_new_full_dtype = ops.core.aten_new_full_dtype

    def _assert_match_score(self, onnx_function, inputs, kwargs, score):
        """Record the matching score of ``onnx_function`` and compare it to ``score``."""
        op_schema_wrapper = onnxfunction_dispatcher._OnnxSchemaChecker(onnx_function)
        op_schema_wrapper._record_matching_score(inputs, kwargs)
        self.assertEqual(op_schema_wrapper.match_score, score)

    @common_utils.parametrize(
        "inputs, attributes, assertion",
        [
            common_utils.subtest(
                ([torch.randn(3, 4), torch.randn(3, 4)], {"alpha": 2.0}, True),
                name="perfect_match_with_kwargs",
            ),
            common_utils.subtest(
                (["A", "B"], {}, False),
                name="non_perfect_match_due_to_non_tensor_inputs",
            ),
            common_utils.subtest(
                ([torch.randn(3, 4), torch.randn(3, 4), torch.randn(3, 4)], {}, False),
                name="non_perfect_match_due_to_too_many_inputs",
            ),
            common_utils.subtest(
                ([torch.randn(3, 4), torch.randn(3, 4)], {"wrong_kwargs": 2.0}, False),
                name="non_perfect_match_due_to_wrong_kwargs",
            ),
        ],
    )
    def test_perfect_match_inputs(self, inputs, attributes, assertion):
        """``perfect_match_inputs`` on ``aten_add`` returns the expected boolean."""
        # OnnxFunction with default attributes
        dummy_diagnostic = diagnostics.Diagnostic(
            rule=diagnostics.rules.find_opschema_matched_symbolic_function,
            level=diagnostics.levels.WARNING,
        )
        op_schema_wrapper_add = onnxfunction_dispatcher._OnnxSchemaChecker(
            ops.core.aten_add
        )
        self.assertEqual(
            op_schema_wrapper_add.perfect_match_inputs(
                dummy_diagnostic, inputs, attributes
            ),
            assertion,
        )

    # NOTE: the parametrized argument is named ``onnx_function`` so it does not
    # shadow the module-level ``op`` opset import.
    @common_utils.parametrize(
        "inputs, kwargs, onnx_function, score",
        [
            common_utils.subtest(
                ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul, 2),
                name="match_2_inputs",
            ),
            common_utils.subtest(
                (
                    [
                        torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
                        torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
                    ],
                    {},
                    ops.core.aten_mul,
                    0,
                ),
                name="match_0_inputs",
            ),
            common_utils.subtest(
                ([torch.randn(3, 4), torch.randn(3, 4)], {}, ops.core.aten_mul_bool, 0),
                name="match_0_inputs_bool",
            ),
            common_utils.subtest(
                (
                    [
                        torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
                        torch.randint(0, 2, size=(3, 4), dtype=torch.int).bool(),
                    ],
                    {},
                    ops.core.aten_mul_bool,
                    2,
                ),
                name="match_2_inputs_bool",
            ),
        ],
    )
    def test_matching_score_system_on_overload_dtypes(
        self, inputs, kwargs, onnx_function, score
    ):
        """The score for ``aten_mul`` / ``aten_mul_bool`` matches the expected value."""
        self._assert_match_score(onnx_function, inputs, kwargs, score)

    @common_utils.parametrize(
        "inputs, kwargs, onnx_function, score",
        [
            common_utils.subtest(
                ([torch.randn(3, 4), torch.tensor(3)], {}, ops.core.aten_new_full, 2),
                name="match_2_inputs",
            ),
            common_utils.subtest(
                (
                    [torch.randn(3, 4), torch.tensor(3)],
                    {"dtype": 2},  # at this point, dtype should be converted to int
                    ops.core.aten_new_full_dtype,
                    2,
                ),
                name="match_2_input_and_match_1_kwargs_optional",
            ),
        ],
    )
    def test_matching_score_system_on_optional_dtypes(
        self, inputs, kwargs, onnx_function, score
    ):
        """The score for the ``aten_new_full`` overloads matches the expected value."""
        self._assert_match_score(onnx_function, inputs, kwargs, score)

    @common_utils.parametrize(
        "value, expected_onnx_str_dtype",
        [
            common_utils.subtest(
                (1, {"tensor(int64)", "tensor(int16)", "tensor(int32)"}),
                name="all_ints",
            ),
            common_utils.subtest(
                (1.0, {"tensor(float)", "tensor(double)", "tensor(float16)"}),
                name="all_floats",
            ),
            common_utils.subtest(
                (torch.tensor([True]), {"tensor(bool)"}),
                name="bool",
            ),
            common_utils.subtest(
                (torch.tensor([1], dtype=torch.int64), {"tensor(int64)"}),
                name="int64",
            ),
            common_utils.subtest(
                (torch.tensor([1], dtype=torch.int32), {"tensor(int32)"}),
                name="int32",
            ),
            common_utils.subtest(
                (torch.tensor([1], dtype=torch.int16), {"tensor(int16)"}),
                name="int16",
            ),
            common_utils.subtest(
                (torch.tensor([1], dtype=torch.float), {"tensor(float)"}),
                name="float",
            ),
            common_utils.subtest(
                (torch.tensor([1], dtype=torch.float16), {"tensor(float16)"}),
                name="float16",
            ),
            common_utils.subtest(
                (torch.tensor([1], dtype=torch.double), {"tensor(double)"}),
                name="double",
            ),
            common_utils.subtest((None, set()), name="None"),  # None allows no dtype
            common_utils.subtest(
                ([], set()), name="empty_list"
            ),  # Empty list allows no dtype
        ],
    )
    def test_find_onnx_data_type(self, value, expected_onnx_str_dtype):
        """``_find_onnx_data_type`` returns the expected set of ONNX dtype strings."""
        self.assertEqual(
            onnxfunction_dispatcher._find_onnx_data_type(value), expected_onnx_str_dtype
        )
if __name__ == "__main__":
    # Executed directly as a script: hand control to the PyTorch test runner.
    common_utils.run_tests()