| import torch | 
 | from torch import Tensor | 
 |  | 
 | aten = torch.ops.aten | 
 | import inspect | 
 | import warnings | 
 | from typing import Dict, List, Optional, Set | 
 |  | 
 | from torch.types import Number | 
 |  | 
 | decomposition_table: Dict[str, torch.jit.ScriptFunction] = {} | 
 | function_name_set: Set[str] = set() | 
 |  | 
 |  | 
 | def check_decomposition_has_type_annotations(f): | 
 |     inspect_empty = inspect._empty  # type: ignore[attr-defined] | 
 |     sig = inspect.signature(f) | 
 |     for param in sig.parameters.values(): | 
 |         assert ( | 
 |             param.annotation != inspect_empty | 
 |         ), f"No signature on param {param.name} for function {f.name}" | 
 |  | 
 |     assert ( | 
 |         sig.return_annotation != inspect_empty | 
 |     ), f"No return annotation for function {f.name}" | 
 |  | 
 |  | 
 | def signatures_match(decomposition_sig, torch_op_sig): | 
 |     decomp_params = decomposition_sig.parameters | 
 |     op_params = torch_op_sig.parameters | 
 |  | 
 |     if len(decomp_params) != len(op_params): | 
 |         return False | 
 |  | 
 |     for decomp_param, op_param in zip(decomp_params.values(), op_params.values()): | 
 |         # can't check full equality yet because not all fields are correcly deduced | 
 |         # in the torch_op_sig - like default value | 
 |         # can't check 'kind' bc | 
 |         # kwarg-only values with defaults not yet supported in TS | 
 |         inspect_empty = inspect._empty  # type: ignore[attr-defined] | 
 |         for field in ["name", "annotation"]: | 
 |             if field == "name" and decomp_param.name == "self": | 
 |                 warnings.warn("PyTorch uses 'input' instead of 'self' on public api") | 
 |  | 
 |             if getattr(decomp_param, field) != getattr(op_param, field): | 
 |                 return False | 
 |  | 
 |         decomp_default = decomp_param.default | 
 |         op_default = op_param.default | 
 |         # default value not always correctly inferred as being present on torch schema, | 
 |         # but if specified on both they should be equal | 
 |         if decomp_default != inspect_empty and op_default != inspect_empty: | 
 |             if decomp_default != op_default: | 
 |                 return False | 
 |  | 
 |     return decomposition_sig.return_annotation == torch_op_sig.return_annotation | 
 |  | 
 |  | 
 | def register_decomposition(aten_op, registry=None): | 
 |     def decomposition_decorator(f): | 
 |         nonlocal registry | 
 |         if registry is None: | 
 |             registry = decomposition_table | 
 |  | 
 |         assert isinstance(aten_op, torch._ops.OpOverload) | 
 |  | 
 |         # Need unique name for jit function serialization | 
 |         assert ( | 
 |             f.__name__ not in function_name_set | 
 |         ), f"Duplicated function name {f.__name__}" | 
 |         function_name_set.add(f.__name__) | 
 |  | 
 |         scripted_func = torch.jit.script(f) | 
 |         torch._C._jit_pass_inline(scripted_func.graph) | 
 |  | 
 |         for _ in range(2): | 
 |             torch._C._jit_pass_peephole(scripted_func.graph) | 
 |             torch._C._jit_pass_constant_propagation(scripted_func.graph) | 
 |  | 
 |         registry[str(aten_op._schema)] = scripted_func | 
 |         return f | 
 |  | 
 |     return decomposition_decorator | 
 |  | 
 |  | 
 | # TODO: replace torch.sigmoid -> aten.sigmoid | 
 |  | 
 |  | 
 | @register_decomposition(aten.var.correction) | 
 | def var_decomposition( | 
 |     input: Tensor, | 
 |     dim: Optional[List[int]] = None, | 
 |     correction: Optional[Number] = None, | 
 |     keepdim: bool = False, | 
 | ) -> Tensor: | 
 |     if dim is None: | 
 |         dim_i: List[int] = [] | 
 |         dim = dim_i | 
 |  | 
 |     if isinstance(dim, (tuple, list)) and len(dim) == 0: | 
 |         n = input.numel() | 
 |     else: | 
 |         n = 1 | 
 |         for dim_i in dim:  # type: ignore[assignment] | 
 |             n *= input.shape[dim_i]  # type: ignore[call-overload] | 
 |  | 
 |     mean = aten.mean(input, dim, True) | 
 |     sub = input - mean | 
 |     sq = sub * sub | 
 |     sum = aten.sum(sq, dim, keepdim) | 
 |  | 
 |     if correction is None: | 
 |         denom = float(n - 1) | 
 |     else: | 
 |         if isinstance(correction, int): | 
 |             denom = float(n - correction) | 
 |         elif isinstance(correction, float): | 
 |             denom = float(n) - correction | 
 |         else: | 
 |             raise RuntimeError("correction must be int or float") | 
 |  | 
 |     return sum / max(0, denom) | 
 |  | 
 |  | 
 | @register_decomposition(aten.var.default) | 
 | def var(input: Tensor, unbiased: bool = True) -> Tensor: | 
 |     return var_decomposition(input, correction=(1 if unbiased else 0)) |