[inductor] Add type hints to functions in decomposition.py (#131780)

Summary: As the title says: add type hints to the function signatures in
torch/_inductor/decomposition.py and drop the module-level
"# mypy: allow-untyped-defs" exemption, so mypy now checks these defs.
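
For context, the annotation pattern applied throughout the file looks like the
standalone sketch below. It mirrors the clamp decomposition from the diff, with
the register_decomposition / pw_cast_for_opmath decorators omitted so it runs
outside of inductor:

    from typing import Optional

    import torch

    def clamp(
        x: torch.Tensor,
        min: Optional[torch.types.Number] = None,
        max: Optional[torch.types.Number] = None,
    ) -> torch.Tensor:
        # Same shape as the inductor decomposition: clamp from below, then above.
        if min is not None:
            x = x.clamp_min(min)
        if max is not None:
            x = x.clamp_max(max)
        return x

    # With the annotations in place, mypy checks both arguments and return type:
    y: torch.Tensor = clamp(torch.randn(4), min=0.0)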

Test Plan: lintrunner
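
A typical local invocation, assuming lintrunner has already been set up via
"lintrunner init" (the exact flags depend on the local configuration):

    lintrunner torch/_inductor/decomposition.py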

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131780
Approved by: https://github.com/eellison
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 6ef6032..4568386 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -1,11 +1,10 @@
 # mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import functools
 import logging
 import math
 import sys
 import typing
-from typing import Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch._decomp as decomp
@@ -109,8 +108,10 @@
 remove_decompositions(decompositions, decomps_to_exclude)
 
 
-def register_decomposition(ops):
-    for op in [ops] if callable(ops) else ops:
+def register_decomposition(
+    ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
+) -> Callable[..., Any]:
+    for op in [ops] if callable(ops) else ops:  # type: ignore[attr-defined]
         if op in decompositions:
             log.warning("duplicate decomp: %s", ops)
     return decomp.register_decomposition(ops, decompositions)
@@ -119,24 +120,33 @@
 # TODO: for now, inductor doesn't handle asserts
 # because the condition is symbol -> tensor in the graph.
 @register_decomposition([aten._assert_async.msg])
-def assert_async_msg_decomp(tensor, msg):
+def assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
     return
 
 
 # Following `assert_async_msg_decomp` and implement as non-op.
 @register_decomposition([aten._functional_assert_async.msg])
-def functional_assert_async_msg_decomp(tensor, msg):
+def functional_assert_async_msg_decomp(tensor: torch.Tensor, msg: str) -> None:
     return
 
 
 @register_decomposition([aten.sym_constrain_range_for_size.default])
-def sym_constrain_range_for_size(symbol, *, min=None, max=None):
+def sym_constrain_range_for_size(
+    symbol: torch.SymInt,
+    *,
+    min: Optional[torch.types.Number] = None,
+    max: Optional[torch.types.Number] = None,
+) -> None:
     return
 
 
 @register_decomposition([aten.clamp])
 @pw_cast_for_opmath
-def clamp(x, min=None, max=None):
+def clamp(
+    x: torch.Tensor,
+    min: Optional[torch.types.Number] = None,
+    max: Optional[torch.types.Number] = None,
+) -> torch.Tensor:
     if min is not None:
         x = x.clamp_min(min)
     if max is not None:
@@ -145,7 +155,11 @@
 
 
 @register_decomposition([aten.full])
-def full(size, fill_value, **kwargs):
+def full(
+    size: List[Union[int, torch.SymInt]],
+    fill_value: torch.types.Number,
+    **kwargs: Any,
+) -> torch.Tensor:
     dtype = kwargs.get("dtype")
     if dtype is None:
         kwargs["dtype"] = type_to_dtype(type(fill_value))
@@ -158,7 +172,11 @@
 # to decompose to empty_strided (but inductor is OK with it, because we are
 # cool with strides and everything goes to empty_strided)
 @register_decomposition([aten.empty_permuted.default])
-def empty_permuted(size, physical_layout, **kwargs):
+def empty_permuted(
+    size: List[Union[int, torch.SymInt]],
+    physical_layout: List[int],
+    **kwargs: Any,
+) -> torch.Tensor:
     perm = [0] * len(size)
     for p, l in enumerate(physical_layout):
         perm[l] = p
@@ -167,18 +185,18 @@
 
 @register_decomposition([aten.convolution_backward])
 def convolution_backward(
-    grad_output,
-    input,
-    weight,
-    bias_sizes,
-    stride,
-    padding,
-    dilation,
-    transposed,
-    output_padding,
-    groups,
-    output_mask,
-):
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias_sizes: List[int],
+    stride: Union[int, List[int]],
+    padding: Union[int, List[int]],
+    dilation: Union[int, List[int]],
+    transposed: bool,
+    output_padding: List[int],
+    groups: int,
+    output_mask: List[bool],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     if not output_mask[2] or not is_gpu(grad_output.device.type):
         return NotImplemented
     grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
@@ -199,14 +217,17 @@
 
 
 @register_decomposition([aten.round.decimals])
-def round_dec(x, decimals=0):
+def round_dec(x: torch.Tensor, decimals: int = 0) -> torch.Tensor:
     ten_pow_decimals = 10.0**decimals
     return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
 
 
 @register_decomposition([aten.bmm])
 @pw_cast_for_opmath
-def bmm(self, batch2):
+def bmm(
+    self: torch.Tensor,
+    batch2: torch.Tensor,
+) -> torch.Tensor:
     if config.coordinate_descent_tuning:
         if guard_size_oblivious(self.shape[1] == 1) or guard_size_oblivious(
             batch2.shape[2] == 1
@@ -226,7 +247,13 @@
 
 @register_decomposition([aten.addmm])
 @pw_cast_for_opmath
-def addmm(self, mat1, mat2, beta=1, alpha=1):
+def addmm(
+    self: torch.Tensor,
+    mat1: torch.Tensor,
+    mat2: torch.Tensor,
+    beta: torch.types.Number = 1,
+    alpha: torch.types.Number = 1,
+) -> torch.Tensor:
     if self.device.type == "cpu":
         if guard_size_oblivious(mat1.size(0) == 1) and guard_size_oblivious(
             mat2.size(-1) == 1
@@ -249,7 +276,10 @@
 
 @register_decomposition([aten.mm])
 @pw_cast_for_opmath
-def mm(self, input2):
+def mm(
+    self: torch.Tensor,
+    input2: torch.Tensor,
+) -> torch.Tensor:
     # Our matrix vector multiplies only achieve peak bandwidth with coordinate descent tuning.
     # todo: Look into why and fix it (hopefully)
     if config.coordinate_descent_tuning:
@@ -282,10 +312,13 @@
 # - Normalize cat calls, so that legacy empty 1-D tensors are removed (NB: we
 #   don't remove ALL empty tensors, only the naughty ones)
 @register_decomposition([aten.cat.default])
-def cat(tensors, dim=0):
+def cat(
+    tensors: List[torch.Tensor],
+    dim: int = 0,
+) -> torch.Tensor:
     from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 
-    def non_empty_tensor(x):
+    def non_empty_tensor(x: torch.Tensor) -> bool:
         # For better or worse, this is a valid cat:
         #
         #   torch.cat([torch.randn(2, 2, 4), torch.randn(0), torch.randn(3, 2, 4)])
@@ -322,7 +355,7 @@
 
 
 @register_decomposition([aten.angle])
-def angle(x):
+def angle(x: torch.Tensor) -> torch.Tensor:
     if x.is_complex():
         return torch.where(
             torch.isnan(x.real), float("nan"), torch.atan2(x.imag, x.real)
@@ -342,7 +375,12 @@
 
 
 @register_decomposition([aten.add])
-def add(x, y, *, alpha=None):
+def add(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    *,
+    alpha: Optional[torch.types.Number] = None,
+) -> torch.Tensor:
     # Require both x and y to be complex tensors.
     x_is_complex_tensor = torch.is_tensor(x) and x.is_complex()
     y_is_complex_tensor = torch.is_tensor(y) and y.is_complex()
@@ -355,7 +393,7 @@
 
     # For complex typed `x`, `x.view(x.real.dtype)` doubles the last dimension and can cause problem
     # when broadcasting the add.
-    def reshape_tensor_complex(tensor):
+    def reshape_tensor_complex(tensor: torch.Tensor) -> torch.Tensor:
         """Reshape tensor from [*initial_dims, last_dim] to *initial_dims, last_dim/2, 2]"""
         # Get the current shape of the tensor
         *initial_dims, last_dim = tensor.shape
@@ -379,68 +417,97 @@
 
 
 @register_decomposition([aten.conj_physical])
-def conj_physical(self):
+def conj_physical(self: torch.Tensor) -> torch.Tensor:
     assert not self.is_complex(), "TODO: implement this"
     return self
 
 
 @register_decomposition([aten.lift, aten.detach_])
-def lift(self):
+def lift(self: torch.Tensor) -> torch.Tensor:
     return self
 
 
 @register_decomposition([aten.bernoulli.default])
-def bernoulli(self, *, generator=None):
+def bernoulli(
+    self: torch.Tensor,
+    *,
+    generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
     assert generator is None
     return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
 
 
 @register_decomposition([aten.fmin, prims.fmin])
-def fmin(self, other):
+def fmin(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
     return torch.where(torch.isnan(other) | (other > self), self, other)
 
 
 @register_decomposition([aten.fmax, prims.fmax])
-def fmax(self, other):
+def fmax(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
     return torch.where(torch.isnan(other) | (other < self), self, other)
 
 
 @register_decomposition(aten.amax)
-def amax(self, dim=None, keepdim=False):
+def amax(
+    self: torch.Tensor,
+    dim: Optional[int] = None,
+    keepdim: bool = False,
+) -> torch.Tensor:
     if self.dtype == torch.bool:
         return torch.any(self, dim=dim, keepdim=keepdim)
     return NotImplemented
 
 
 @register_decomposition(aten.amin)
-def amin(self, dim=None, keepdim=False):
+def amin(
+    self: torch.Tensor,
+    dim: Optional[int] = None,
+    keepdim: bool = False,
+) -> torch.Tensor:
     if self.dtype == torch.bool:
         return torch.all(self, dim=dim, keepdim=keepdim)
     return NotImplemented
 
 
 @register_decomposition([aten.narrow_copy])
-def narrow_copy(self, dim, start, length):
+def narrow_copy(
+    self: torch.Tensor,
+    dim: int,
+    start: int,
+    length: int,
+) -> torch.Tensor:
     return torch.narrow(self, dim, start, length).clone()
 
 
 @register_decomposition([aten.expand_copy])
-def expand_copy(self, size, *, implicit=False):
+def expand_copy(
+    self: torch.Tensor,
+    size: List[Union[int, torch.SymInt]],
+    *,
+    implicit: bool = False,
+) -> torch.Tensor:
     return aten.expand(self, size, implicit=implicit).clone()
 
 
 @register_decomposition([aten.view_copy.default])
-def view_copy_default(self, size):
+def view_copy_default(
+    self: torch.Tensor,
+    size: List[Union[int, torch.SymInt]],
+) -> torch.Tensor:
     return aten.view(self, size).clone()
 
 
 @register_decomposition([aten.view_copy.dtype])
-def view_copy_dtype(self, dtype):
+def view_copy_dtype(
+    self: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
     return self.to(dtype).clone()
 
 
 def get_like_layout(
-    tensor: torch.Tensor, memory_format: Optional[torch.memory_format]
+    tensor: torch.Tensor,
+    memory_format: Optional[torch.memory_format] = None,
 ) -> torch.memory_format:
     # TODO: _to_copy tensor to stride permutation
     if memory_format is torch.preserve_format or memory_format is None:
@@ -450,7 +517,14 @@
 
 
 @register_decomposition(aten.rand_like)
-def rand_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
+def rand_like(
+    self: torch.Tensor,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    memory_format: Optional[torch.memory_format] = None,
+    **kwargs: Any,
+) -> torch.Tensor:
     return torch.rand(
         [*self.size()],
         dtype=dtype or self.dtype,
@@ -460,7 +534,14 @@
 
 
 @register_decomposition(aten.randn_like)
-def randn_like(self, *, dtype=None, device=None, memory_format=None, **kwargs):
+def randn_like(
+    self: torch.Tensor,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    memory_format: Optional[torch.memory_format] = None,
+    **kwargs: Any,
+) -> torch.Tensor:
     return torch.randn(
         [*self.size()],
         dtype=dtype or self.dtype,
@@ -471,16 +552,16 @@
 
 @register_decomposition(aten.full_like)
 def full_like(
-    self,
-    fill_value,
+    self: torch.Tensor,
+    fill_value: Union[int, float],
     *,
-    dtype=None,
-    layout=None,
-    device=None,
-    pin_memory=False,
-    requires_grad=False,
-    memory_format=torch.preserve_format,
-):
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+    requires_grad: bool = False,
+    memory_format: torch.memory_format = torch.preserve_format,
+) -> torch.Tensor:
     return torch.full(
         [*self.size()],
         fill_value,
@@ -492,7 +573,15 @@
 
 
 @register_decomposition(aten.randint_like.default)
-def randint_like(self, high, *, dtype=None, device=None, memory_format=None, **kwargs):
+def randint_like(
+    self: torch.Tensor,
+    high: int,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    memory_format: Optional[torch.memory_format] = None,
+    **kwargs: Any,
+) -> torch.Tensor:
     return aten.randint.low(
         0,
         high,
@@ -505,8 +594,15 @@
 
 @register_decomposition(aten.randint_like.low_dtype)
 def randint_like_low(
-    self, low, high, *, dtype=None, device=None, memory_format=None, **kwargs
-):
+    self: torch.Tensor,
+    low: int,
+    high: int,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    memory_format: Optional[torch.memory_format] = None,
+    **kwargs: Any,
+) -> torch.Tensor:
     return aten.randint.low(
         low,
         high,
@@ -518,13 +614,19 @@
 
 
 @register_decomposition(aten.randint.default)
-def randint(high, size, **kwargs):
+def randint(
+    high: int,
+    size: List[Union[int, torch.SymInt]],
+    **kwargs: Any,
+) -> torch.Tensor:
     return aten.randint.low(0, high, size, **kwargs)
 
 
 @register_decomposition(quantized.linear_dynamic_fp16_unpacked_weight.default)
 def linear_dynamic_fp16_unpacked_weight(
-    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
 ) -> torch.Tensor:
     packed_weight = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(weight)
     return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
@@ -533,8 +635,8 @@
 
 
 @register_decomposition(torch.ops.quantized.embedding_bag_byte_unpack)
-def q_embedding_bag_byte_unpack_decomp(packed):
-    def bitcast_u8_to_f32(u8):
+def q_embedding_bag_byte_unpack_decomp(packed: torch.Tensor) -> torch.Tensor:
+    def bitcast_u8_to_f32(u8: torch.Tensor) -> torch.Tensor:
         x, y, z, w = (u8[..., n].to(torch.int32) for n in (0, 1, 2, 3))
         if sys.byteorder == "little":
             return (x + (y << 8) + (z << 16) + (w << 24)).view(torch.float32)[..., None]
@@ -578,21 +680,35 @@
 
 
 @register_decomposition(aten._foreach_addcmul.Scalar)
-def _foreach_addcmul_scalar(self, left_tensors, right_tensors, scalar=1):
+def _foreach_addcmul_scalar(
+    self: List[torch.Tensor],
+    left_tensors: List[torch.Tensor],
+    right_tensors: List[torch.Tensor],
+    scalar: float = 1,
+) -> List[torch.Tensor]:
     return aten._foreach_add.List(
         self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
     )
 
 
 @register_decomposition(aten._foreach_addcdiv.Scalar)
-def _foreach_addcdiv_scalar(self, left_tensors, right_tensors, scalar=1):
+def _foreach_addcdiv_scalar(
+    self: List[torch.Tensor],
+    left_tensors: List[torch.Tensor],
+    right_tensors: List[torch.Tensor],
+    scalar: float = 1,
+) -> List[torch.Tensor]:
     return aten._foreach_add.List(
         self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
     )
 
 
 @register_decomposition(aten._foreach_lerp.Scalar)
-def _foreach_lerp_scalar(start_tensors, end_tensors, weight):
+def _foreach_lerp_scalar(
+    start_tensors: List[torch.Tensor],
+    end_tensors: List[torch.Tensor],
+    weight: torch.types.Number,
+) -> List[torch.Tensor]:
     return aten._foreach_add.List(
         start_tensors,
         aten._foreach_mul.Scalar(
@@ -612,7 +728,7 @@
     training: bool,
     exponential_average_factor: float,
     epsilon: float,
-):
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     a, b, c = aten.native_batch_norm(
         input,
         weight,
@@ -634,11 +750,13 @@
 
 
 @functools.lru_cache(None)
-def fast_random_decomps():
+def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
     return {**decompositions, **extra_random_decomps}
 
 
-def select_decomp_table():
+# TODO(aakhundov): replace this (and the above) Any by more
+# specific type and fix all the cascading mypy errors
+def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
     """decomps can change based on config"""
     if config.fallback_random:
         return decompositions
@@ -646,7 +764,11 @@
 
 
 @register_decomposition(aten.masked_scatter)
-def masked_scatter(self, mask, source):
+def masked_scatter(
+    self: torch.Tensor,
+    mask: torch.Tensor,
+    source: torch.Tensor,
+) -> torch.Tensor:
     from .codegen.common import BackendFeature, has_backend_feature
 
     if has_backend_feature(self.device, BackendFeature.MASKED_SCATTER_WITH_INDEX):
@@ -662,8 +784,12 @@
 
 @register_decomposition(quantized_decomposed.choose_qparams.tensor)
 def choose_qparams_tensor(
-    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
-):
+    input: torch.Tensor,
+    quant_min: int,
+    quant_max: int,
+    eps: float,
+    dtype: torch.dtype,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     min_val, max_val = torch.aminmax(input)
     scale = (max_val - min_val) / float(quant_max - quant_min)
     scale = torch.max(scale, torch.Tensor([eps]))
@@ -673,7 +799,12 @@
 
 
 @register_decomposition(aten.put)
-def put(self, index, source, accumulate=False):
+def put(
+    self: torch.Tensor,
+    index: torch.Tensor,
+    source: torch.Tensor,
+    accumulate: bool = False,
+) -> torch.Tensor:
     flattened = self.flatten()
     flattened = torch.index_put(
         flattened, [index], source.reshape(index.shape), accumulate
@@ -682,14 +813,24 @@
 
 
 @register_decomposition(aten.put_)
-def put_(self, index, source, accumulate=False):
+def put_(
+    self: torch.Tensor,
+    index: torch.Tensor,
+    source: torch.Tensor,
+    accumulate: bool = False,
+) -> torch.Tensor:
     out = aten.put(self, index, source, accumulate=accumulate)
     return self.copy_(out)
 
 
 @register_decomposition(aten._softmax_backward_data.default)
 @pw_cast_for_opmath
-def _softmax_backward_data(grad_output, output, dim, input_dtype):
+def _softmax_backward_data(
+    grad_output: torch.Tensor,
+    output: torch.Tensor,
+    dim: int,
+    input_dtype: torch.dtype,
+) -> torch.Tensor:
     new_grad_output = grad_output * output
     sum_new_grad = torch.sum(new_grad_output, dim=dim, keepdim=True)
     # grad_input = new_grad_output - output * sum_new_grad
@@ -706,8 +847,14 @@
 
 @register_decomposition(aten.index_reduce)
 def index_reduce(
-    self, dim: int, index, src, reduction_type: str, *, include_self: bool = True
-):
+    self: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    src: torch.Tensor,
+    reduction_type: str,
+    *,
+    include_self: bool = True,
+) -> torch.Tensor:
     if reduction_type == "mean" and not needs_fallback_due_to_atomic_add_limitations(
         self.dtype
     ):
@@ -753,8 +900,13 @@
 
 @register_decomposition(aten.max_pool2d_with_indices)
 def max_pool2d_with_indices(
-    x, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
-):
+    x: torch.Tensor,
+    kernel_size: List[int],
+    stride: Optional[Union[int, List[int]]] = None,
+    padding: Union[int, List[int]] = 0,
+    dilation: Union[int, List[int]] = 1,
+    ceil_mode: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     if dilation == 1:
         dilation = [1, 1]