[ONNX][TypePromo] Materialize type promotion rules (#104063)
This PR adds the materialized type promotion rule set and the type promotion rule infrastructure
that will be consumed by the ONNX exporter. The script that generates the rule set is included, and
a local unit test is added to verify that the materialized rule set is valid and up to date.
Full design doc and discussion at https://microsoft-my.sharepoint.com/:w:/p/bowbao/Edj2lF1oi0JIitT_3ntyuqkBo6ll7N6NJDmavM0lp_KkEA?e=OElyjR
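For a quick sense of how the rule set is meant to be consumed, below is a minimal usage sketch (illustrative only, not part of the added tests) built on the `TypePromotionTable` and `TypePromotionRule.preview_type_promotion` APIs introduced in this PR; the printed dtypes assume the standard DEFAULT elementwise promotion for `aten.add`:
```python
# Illustrative sketch: consult the materialized rule set for aten.add.
import torch
from torch.onnx._internal.fx.passes import type_promotion

table = type_promotion.TypePromotionTable()
rule = table.get_rule(torch.ops.aten.add)
assert rule is not None

int_tensor = torch.tensor([1, 2], dtype=torch.int64)
float_tensor = torch.tensor([1.0, 2.0], dtype=torch.float32)

snapshot = rule.preview_type_promotion((int_tensor, float_tensor), {})
# Both args are promoted to torch.float32, which is also the expected output dtype.
print(snapshot.args_dtypes, snapshot.out_dtype)
```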
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104063
Approved by: https://github.com/titaiwangms, https://github.com/justinchuby
diff --git a/test/onnx/pytorch_test_common.py b/test/onnx/pytorch_test_common.py
index 7bfac4e..27d00ea 100644
--- a/test/onnx/pytorch_test_common.py
+++ b/test/onnx/pytorch_test_common.py
@@ -211,6 +211,28 @@
return skip_dec
+def skip_in_ci(reason: str):
+ """Skip test in CI.
+
+ Args:
+ reason: The reason for skipping test in CI.
+
+ Returns:
+ A decorator for skipping test in CI.
+ """
+
+ def skip_dec(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if os.getenv("CI"):
+ raise unittest.SkipTest(f"Skip test in CI. {reason}")
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return skip_dec
+
+
def xfail(reason: str):
"""Expect failure.
diff --git a/test/onnx/test_fx_type_promotion.py b/test/onnx/test_fx_type_promotion.py
new file mode 100644
index 0000000..ce6e9a0
--- /dev/null
+++ b/test/onnx/test_fx_type_promotion.py
@@ -0,0 +1,26 @@
+# Owner(s): ["module: onnx"]
+
+import pytorch_test_common
+from torch.onnx._internal.fx.passes import type_promotion
+from torch.testing._internal import common_utils
+
+
+class TestGeneratedTypePromotionRuleSet(common_utils.TestCase):
+ @pytorch_test_common.skip_in_ci(
+ "Reduce noise in CI. "
+ "The test serves as a tool to validate if the generated rule set is current. "
+ )
+ def test_generated_rule_set_is_up_to_date(self):
+ generated_set = type_promotion._GENERATED_ATEN_TYPE_PROMOTION_RULE_SET
+ latest_set = (
+ type_promotion.TypePromotionRuleSetGenerator.generate_from_torch_refs()
+ )
+
+ self.assertEqual(generated_set, latest_set)
+
+ def test_initialize_type_promotion_table_succeeds(self):
+ type_promotion.TypePromotionTable()
+
+
+if __name__ == "__main__":
+ common_utils.run_tests()
diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py
new file mode 100644
index 0000000..64e9607
--- /dev/null
+++ b/torch/onnx/_internal/fx/passes/type_promotion.py
@@ -0,0 +1,935 @@
+# Owner(s): ["module: onnx"]
+from __future__ import annotations
+
+import dataclasses
+import inspect
+import logging
+from types import ModuleType
+
+from typing import Callable, Mapping, Optional, Sequence, Set
+
+import torch
+import torch._ops
+import torch.fx
+
+from torch import _prims_common, _refs
+
+from torch._prims_common import (
+ ELEMENTWISE_TYPE_PROMOTION_KIND,
+ wrappers as _prims_common_wrappers,
+)
+from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
+from torch._refs.nn import functional as _functional_refs
+
+# Imported to resolve beartype issue when type checking node.Argument.
+from torch.fx.node import Node # noqa: F401
+
+from torch.onnx._internal import _beartype
+from torch.utils import _pytree
+
+logger = logging.getLogger(__name__)
+
+
+def _try_getclosurevars(func):
+ try:
+ return inspect.getclosurevars(func)
+    except TypeError:
+ return None
+
+
+@dataclasses.dataclass
+class TypePromotionSnapshot:
+ """Type promotion snapshot for a fx node and its inputs.
+
+ Contains the promoted dtype for args and kwargs that needs promoting.
+ Contains the expected node output dtype.
+ """
+
+ args_dtypes: Mapping[int, torch.dtype]
+ """Mapping from arg position to dtype to promote to."""
+
+ kwargs_dtypes: Mapping[str, torch.dtype]
+ """Mapping from kwarg name to dtype to promote to."""
+
+ out_dtype: torch.dtype
+ """Expected output dtype of the node."""
+
+
+class TypePromotionRule:
+ """Defines how to perform type promotion for 'torch.ops.{namespace}.{op_name}'."""
+
+ def __init__(
+ self,
+ namespace: str,
+ op_name: str,
+ promote_args_positions: Sequence[int],
+ promote_kwargs_names: Sequence[str],
+ promotion_kind: _prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND,
+ ):
+ self.namespace = namespace
+ self.op_name = op_name
+ self.promote_args_positions = promote_args_positions
+ self.promote_kwargs_names = promote_kwargs_names
+ self.promotion_kind = promotion_kind
+
+ def __hash__(self) -> int:
+ return f"{self.namespace}.{self.op_name}".__hash__()
+
+ def __repr__(self):
+ return (
+ f"TypePromotionRule('{self.namespace}', '{self.op_name}', {self.promote_args_positions}, "
+ f"{self.promote_kwargs_names}, {self.promotion_kind})"
+ )
+
+ def __eq__(self, __value: object) -> bool:
+ if not isinstance(__value, TypePromotionRule):
+ return False
+ return (
+ self.namespace == __value.namespace
+ and self.op_name == __value.op_name
+ and self.promote_args_positions == __value.promote_args_positions
+ and self.promote_kwargs_names == __value.promote_kwargs_names
+ and self.promotion_kind == __value.promotion_kind
+ )
+
+ def is_valid(self) -> bool:
+ """Check if the rule is valid."""
+ # This always returns a module. If the module does not exist it will be created.
+ module = getattr(torch.ops, self.namespace)
+ py_op = getattr(module, self.op_name, None)
+ if py_op is None:
+ logger.warning(
+ "Cannot find op: %s in module: %s", self.op_name, self.namespace
+ )
+ return False
+ if not isinstance(py_op, torch._ops.OpOverloadPacket):
+ logger.warning(
+ "Op: torch.ops.%s.%s is not an OpOverloadPacket, got: %s",
+ self.namespace,
+ self.op_name,
+ type(py_op),
+ )
+ return False
+
+ return True
+
+ def preview_type_promotion(
+ self, args: tuple, kwargs: dict
+ ) -> TypePromotionSnapshot:
+ """Preview type promotion results for the given fx node.
+
+ Returns a TypePromotionSnapshot object that contains the promoted dtypes for
+ the arguments and the expected output dtype of the given fx node.
+ """
+
+ candidate_args = {
+ i: args[i]
+ for i in self.promote_args_positions
+ if i < len(args) and args[i] is not None
+ }
+ candidate_kwargs = {
+ name: kwargs[name]
+ for name in self.promote_kwargs_names
+ if name in kwargs and kwargs[name] is not None
+ }
+
+ computed_dtype, result_dtype = _prims_common.elementwise_dtypes(
+ *_pytree.tree_flatten(candidate_args)[0],
+ *_pytree.tree_flatten(candidate_kwargs)[0],
+ type_promotion_kind=self.promotion_kind,
+ )
+
+ return TypePromotionSnapshot(
+ {i: computed_dtype for i in candidate_args.keys()},
+ {name: computed_dtype for name in candidate_kwargs.keys()},
+ result_dtype,
+ )
+
+
+# NOTE: BELOW TABLE IS GENERATED FROM `TypePromotionRuleSetGenerator.generate_from_torch_refs`.
+# DO NOT EDIT MANUALLY !!!
+# For missing rules or discrepancies, please
+# 1. Run `pytest test/onnx/test_fx_type_promotion.py` to check whether the generated rule set is up to date.
+#    If it is not, update it with the newly generated set.
+# 2. If discrepancies still exist, consider debugging torch._refs or reporting a bug.
+# 3. If rules are still missing, add them to `_EXTRA_TYPE_PROMOTION_RULE_SET` or report a bug.
+_GENERATED_ATEN_TYPE_PROMOTION_RULE_SET = {
+ TypePromotionRule(
+ "aten", "abs", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "abs_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "acos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "acos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "acosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "acosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "add", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "add_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "addcdiv", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "addcdiv_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "addcmul", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "addcmul_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "addr", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "asin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "asin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "asinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "asinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "atan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "atan2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "atan2_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "atan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "atanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "atanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten",
+ "bitwise_left_shift",
+ [0, 1],
+ [],
+ ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ ),
+ TypePromotionRule(
+ "aten",
+ "bitwise_left_shift_",
+ [0, 1],
+ [],
+ ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten",
+ "bitwise_right_shift",
+ [0, 1],
+ [],
+ ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ ),
+ TypePromotionRule(
+ "aten",
+ "bitwise_right_shift_",
+ [0, 1],
+ [],
+ ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "bitwise_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "cat", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+ ),
+ TypePromotionRule(
+ "aten", "cauchy", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "cauchy_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule("aten", "ceil", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "ceil_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule("aten", "celu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "celu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "clamp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "clamp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "copysign", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "copysign_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "cos", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "cos_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "cosh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "cosh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "deg2rad", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "deg2rad_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "digamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "digamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule("aten", "elu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule("aten", "elu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "eq", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "eq_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "erf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "erf_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "erfc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "erfc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "erfinv", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "erfinv_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "exp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "exp2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "exp2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "exp_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "expm1", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "expm1_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "exponential", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "exponential_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "fill", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+ ),
+ TypePromotionRule(
+ "aten", "floor", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "floor_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "floor_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "floor_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "fmax", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "fmin", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "fmod", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "fmod_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule("aten", "frac", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "frac_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "gcd", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "gcd_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "ge", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "ge_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule("aten", "gelu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "geometric", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "geometric_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule("aten", "glu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "gt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "gt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "hardtanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "heaviside", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "heaviside_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "huber_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "hypot", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "hypot_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "i0", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "i0_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "igamma", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "igamma_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "igammac", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "igammac_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "isfinite", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "isinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "isnan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "isneginf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "isposinf", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "isreal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "l1_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "lcm", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "lcm_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "le", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "le_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "leaky_relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "lerp", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "lerp_", [0, 1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "lgamma", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "lgamma_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log10", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log10_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log1p", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log1p_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log2", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log2_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "log_normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "log_normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "logaddexp", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "logaddexp2", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "logical_and", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logical_and_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logical_not", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logical_not_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logical_or", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logical_or_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logical_xor", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logical_xor_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "logit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "logsumexp", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "lt", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "lt_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "maximum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "minimum", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule("aten", "mish", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "mish_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "mse_loss", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "mul", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "mul_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "ne", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "ne_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule("aten", "neg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule("aten", "neg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "nextafter", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+ ),
+ TypePromotionRule(
+ "aten", "nextafter_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+ ),
+ TypePromotionRule(
+ "aten", "nll_loss", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "normal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "normal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "pdist", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten",
+ "poisson_nll_loss",
+ [0, 1],
+ [],
+ ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+ ),
+ TypePromotionRule(
+ "aten", "pow", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
+ ),
+ TypePromotionRule(
+ "aten", "pow_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
+ ),
+ TypePromotionRule(
+ "aten", "prelu", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "rad2deg", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "rad2deg_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "reciprocal", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "reciprocal_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule("aten", "relu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "remainder", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "remainder_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "round", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "rsqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "rsqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule("aten", "selu", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "selu_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule("aten", "sgn", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule("aten", "sgn_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "sigmoid", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "sigmoid_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule("aten", "sign", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT),
+ TypePromotionRule(
+ "aten", "sign_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "signbit", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL
+ ),
+ TypePromotionRule(
+ "aten", "sin", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "sin_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "sinc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "sinc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "sinh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "sinh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten",
+ "smooth_l1_loss",
+ [0, 1],
+ [],
+ ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
+ ),
+ TypePromotionRule(
+ "aten", "softplus", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "sqrt", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "sqrt_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "square", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
+ ),
+ TypePromotionRule(
+ "aten", "square_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG
+ ),
+ TypePromotionRule(
+ "aten", "sub", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "sub_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "tan", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "tan_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "tanh", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "tanh_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "threshold", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "threshold_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "true_divide", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "true_divide_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "trunc", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "trunc_", [0], [], ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
+ ),
+ TypePromotionRule(
+ "aten", "where", [1, 2], [], ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
+ ),
+ TypePromotionRule(
+ "aten", "xlogy", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+ TypePromotionRule(
+ "aten", "xlogy_", [0, 1], [], ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
+ ),
+}
+
+# Manually curated extra type promotion rules.
+_EXTRA_TYPE_PROMOTION_RULE_SET = {
+ # torch._refs skips type promotion decoration for `clamp_min` and `clamp_max` since
+ # the call is routed to the decorated `aten.clamp` op.
+ TypePromotionRule(
+ "aten",
+ "clamp_max",
+ promote_args_positions=(0, 1),
+ promote_kwargs_names=(),
+ promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ ),
+ TypePromotionRule(
+ "aten",
+ "clamp_min",
+ promote_args_positions=(0, 1),
+ promote_kwargs_names=(),
+ promotion_kind=_prims_common.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+ ),
+ # TODO: torch.ops.aten.div.Tensor_mode applies different type promotion rules
+ # depending on the value of the `mode` argument.
+}
+
+
+class TypePromotionRuleSetGenerator:
+ """Hackly distilling info from reference ops decorated with type promotion rule.
+
+ The goal is to retrieve the decorator
+
+ ```python
+ @elementwise_type_promotion_wrapper(
+ type_promoting_args=("a", "b"),
+ type_promotion_kind=type_promotion_kind,
+ )
+ ```
+
+    from the reference ops. It records which arguments are promoted
+    and what kind of promotion is applied.
+ """
+
+ @classmethod
+ def generate_from_torch_refs(cls) -> Set[TypePromotionRule]:
+ """Parse type promotion rules from reference ops under torch._C._refs."""
+ rule_set = set()
+ rule_set.update(cls._parse_torch_refs(_refs))
+ rule_set.update(cls._parse_torch_refs(_nn_refs))
+ rule_set.update(cls._parse_torch_refs(_linalg_refs))
+ rule_set.update(cls._parse_torch_refs(_special_refs))
+ rule_set.update(cls._parse_torch_refs(_functional_refs))
+ return rule_set
+
+ @classmethod
+ def _parse_torch_refs(cls, ref_module: ModuleType) -> Set[TypePromotionRule]:
+ logger.info("Processing module: %s", ref_module.__name__)
+ rule_set = set()
+ for name in ref_module.__all__:
+ decorated_op = getattr(ref_module, name)
+ rule = cls._parse_type_promotion_rule_from_refs_op(decorated_op)
+ if rule is not None and rule.is_valid():
+ rule_set.add(rule)
+
+ return rule_set
+
+ @classmethod
+ def _parse_type_promotion_rule_from_refs_op(
+ cls,
+ decorated_op: Callable,
+ ) -> Optional[TypePromotionRule]:
+ """Retrieve and parse type promotion decorator from op under torch._refs."""
+ fn = decorated_op
+ type_promo_wrapper = None
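+        # Walk the chain of decorator closures: each wrapper exposes the inner
+        # callable as 'fn'; the elementwise_type_promotion_wrapper instance, if
+        # present, is bound as 'self' in its closure.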
+ while fn_closure_vars := _try_getclosurevars(fn):
+ if "fn" not in fn_closure_vars.nonlocals:
+ break
+ if "self" in fn_closure_vars.nonlocals and isinstance(
+ fn_closure_vars.nonlocals["self"],
+ _prims_common_wrappers.elementwise_type_promotion_wrapper,
+ ):
+ type_promo_wrapper = fn_closure_vars.nonlocals["self"]
+ break
+ fn = fn_closure_vars.nonlocals["fn"]
+
+ if type_promo_wrapper is not None:
+ signature = inspect.signature(decorated_op)
+
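+            # Map the wrapper's type-promoting argument names onto positional
+            # indices and keyword-only names using the reference op's signature.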
+ pos = 0
+ promote_args_positions = []
+ promote_kwargs_names = []
+
+ if type_promo_wrapper.type_promoting_arg_names is not None:
+ for name, param in signature.parameters.items():
+ if name in type_promo_wrapper.type_promoting_arg_names:
+ if param.kind in (
+ param.POSITIONAL_OR_KEYWORD,
+ param.POSITIONAL_ONLY,
+ ):
+ promote_args_positions.append(pos)
+ elif param.kind == param.KEYWORD_ONLY:
+ promote_kwargs_names.append(name)
+ pos += 1
+
+ return TypePromotionRule(
+ "aten",
+ decorated_op.__name__,
+ promote_args_positions=promote_args_positions,
+ promote_kwargs_names=promote_kwargs_names,
+ promotion_kind=type_promo_wrapper.type_promotion_kind,
+ )
+
+ logger.warning(
+ "Cannot find type promotion rule for: %s.%s",
+ decorated_op.__module__,
+ decorated_op.__name__,
+ )
+ return None
+
+
+class TypePromotionTable:
+ """Type promotion table for torch.ops."""
+
+ def __init__(self):
+ self._rule_table = {}
+ for rule in _GENERATED_ATEN_TYPE_PROMOTION_RULE_SET:
+ self.add_rule(rule)
+ for rule in _EXTRA_TYPE_PROMOTION_RULE_SET:
+ self.add_rule(rule)
+
+ @_beartype.beartype
+ def add_rule(self, rule: TypePromotionRule) -> None:
+ """Add a type promotion rule for a python op in a torch.ops module.
+
+ Args:
+ rule: Type promotion rule.
+
+ Raises:
+ ValueError: If the rule is invalid.
+ """
+ if not rule.is_valid():
+            raise ValueError(f"Invalid type promotion rule: {rule}")
+ self._rule_table[f"{rule.namespace}.{rule.op_name}"] = rule
+
+ @_beartype.beartype
+ def get_rule(
+ self, py_op: torch._ops.OpOverloadPacket
+ ) -> Optional[TypePromotionRule]:
+ """Get type promotion rule for a python op under 'torch.ops.<namespace>'."""
+ return self._rule_table.get(str(py_op), None)