Fix type annotation errors in torch.functional (#43446)

Summary:
Closes gh-42968
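
The bulk of the change gates mypy-only signatures behind `typing.TYPE_CHECKING`, so the TorchScript compiler (which, per the comment in the diff, doesn't understand `Union` in these positions) keeps seeing plain untyped definitions at runtime while mypy sees the annotated ones. A minimal standalone sketch of that pattern; the `my_meshgrid` / `_my_meshgrid` names are illustrative only and not part of this change:

```python
from typing import TYPE_CHECKING, List, Tuple, Union

import torch
from torch import Tensor

if TYPE_CHECKING:
    # Only mypy sees this annotated signature; the JIT never parses it.
    def my_meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]:
        return _my_meshgrid(*tensors)
else:
    # TorchScript sees this untyped wrapper at runtime.
    def my_meshgrid(*tensors):
        return _my_meshgrid(*tensors)


def _my_meshgrid(*tensors):
    # Shared implementation; delegates to torch.meshgrid purely for illustration.
    return torch.meshgrid(*tensors)
```

Beyond that, calls into `_VF` (`torch._C._VariableFunctions`) that mypy flags are silenced with `# type: ignore`, and the remaining Python 2 style `# type:` comments are converted to inline annotations where TorchScript tolerates them.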

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43446

Reviewed By: albanD

Differential Revision: D23280962

Pulled By: malfet

fbshipit-source-id: de5386a95a20ecc814c39cbec3e4252112340b3a
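
Note on the torch/jit/annotations.py hunk below: with `# type: ignore` comments now added throughout torch/functional.py, the pattern used by the TorchScript type-comment checker to report "probably invalid" annotation prefixes gains a negative lookahead, so a bare `# type: ignore` comment no longer matches it. A small standalone illustration of the new pattern's behavior, using made-up test strings:

```python
import re

# The regex introduced in torch/jit/annotations.py by this patch.
type_pattern = re.compile('#[\t ]*type[\t ]*(?!: ignore$):')

assert type_pattern.search("x = f()  # type : List[int]") is not None  # odd spacing still matches
assert type_pattern.search("x = f()  # type: List[int]") is not None   # ordinary prefix still matches
assert type_pattern.search("y = g()  # type: ignore") is None          # bare ignore comment is skipped
```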
diff --git a/mypy.ini b/mypy.ini
index b9c87ac..ce6bd2a 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -54,9 +54,6 @@
 [mypy-torch.distributed.*]
 ignore_errors = True
 
-[mypy-torch.functional.*]
-ignore_errors = True
-
 [mypy-torch.testing._internal.*]
 ignore_errors = True
 
diff --git a/test/onnx/expect/TestOperators.test_meshgrid.expect b/test/onnx/expect/TestOperators.test_meshgrid.expect
index 465b780..4107cc5 100644
--- a/test/onnx/expect/TestOperators.test_meshgrid.expect
+++ b/test/onnx/expect/TestOperators.test_meshgrid.expect
@@ -17,7 +17,7 @@
     }
   }
   node {
-    input: "x"
+    input: "0"
     input: "3"
     output: "4"
     name: "Reshape_1"
@@ -38,7 +38,7 @@
     }
   }
   node {
-    input: "y"
+    input: "1"
     input: "5"
     output: "6"
     name: "Reshape_3"
@@ -59,7 +59,7 @@
     }
   }
   node {
-    input: "z"
+    input: "2"
     input: "7"
     output: "8"
     name: "Reshape_5"
@@ -221,7 +221,7 @@
   }
   name: "torch-jit-export"
   input {
-    name: "x"
+    name: "0"
     type {
       tensor_type {
         elem_type: 1
@@ -234,7 +234,7 @@
     }
   }
   input {
-    name: "y"
+    name: "1"
     type {
       tensor_type {
         elem_type: 1
@@ -247,7 +247,7 @@
     }
   }
   input {
-    name: "z"
+    name: "2"
     type {
       tensor_type {
         elem_type: 1
diff --git a/torch/functional.py b/torch/functional.py
index b8c1d3c..8202940 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1,7 +1,10 @@
-from typing import Tuple, Optional
+from typing import (
+    Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
+)
 
 import torch
 import torch.nn.functional as F
+from torch.types import _size
 from ._lowrank import svd_lowrank, pca_lowrank
 from .overrides import has_torch_function, handle_torch_function
 from ._jit_internal import boolean_dispatch, List
@@ -65,7 +68,7 @@
     if not torch.jit.is_scripting():
         if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
             return handle_torch_function(broadcast_tensors, tensors, *tensors)
-    return _VF.broadcast_tensors(tensors)
+    return _VF.broadcast_tensors(tensors)  # type: ignore
 
 
 def split(tensor, split_size_or_sections, dim=0):
@@ -117,9 +120,15 @@
     # call here.
     return tensor.split(split_size_or_sections, dim)
 
+
+if TYPE_CHECKING:
+    _Indices = _size
+else:
+    _Indices = List[int]
+
+
 # equivalent to itertools.product(indices)
-def _indices_product(indices):
-    # type: (List[int]) -> (List[List[int]])
+def _indices_product(indices: _Indices) -> List[List[int]]:
     empty_list = torch.jit.annotate(List[int], [])
     result = [empty_list]
     for idx in indices:
@@ -130,6 +139,7 @@
         result = result_temp
     return result
 
+
 def _index_tensor_with_indices_list(tensor, indices):
     # type: (Tensor, List[int]) -> Tensor
     out = tensor
@@ -137,6 +147,7 @@
         out = out[index]
     return out
 
+
 def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     # type: (Tensor, Tensor, bool, bool) ->  (Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]])
     r"""Unpacks the data and pivots from a LU factorization of a tensor.
@@ -204,10 +215,12 @@
     m, n = shape[-2:]
     k = min(m, n)
     if unpack_data:
-        U = LU_data.triu()
+        U: Optional[Tensor] = LU_data.triu()
+        assert U is not None
         if m != k:
             U = U.narrow(-2, 0, k)
-        L = LU_data.tril()
+        L: Optional[Tensor] = LU_data.tril()
+        assert L is not None
         if k != n:
             L = L.narrow(-1, 0, k)
         L.diagonal(dim1=-2, dim2=-1).fill_(1)
@@ -217,9 +230,11 @@
     if unpack_pivots:
         LU_pivots_zero_idx = LU_pivots - 1
         if LU_data.dim() > 2:
-            P = torch.eye(m, device=LU_data.device, dtype=LU_data.dtype) \
-                     .expand(shape[:-1] + (m,)) \
-                     .clone(memory_format=torch.contiguous_format)
+            P: Optional[Tensor] = torch.eye(m, device=LU_data.device,
+                                            dtype=LU_data.dtype) \
+                .expand(shape[:-1] + (m,)) \
+                .clone(memory_format=torch.contiguous_format)
+            assert P is not None
 
             # TODO: rewrite when TorchScript supports product and map as
             # product(*map(lambda x: list(range(x)), shape[:-2])) when issue 33781 is fixed
@@ -321,15 +336,24 @@
             return handle_torch_function(einsum, operands, equation, *operands)
     if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
         # the old interface of passing the operands as one list argument
-        operands = operands[0]
+        _operands = operands[0]
         # recurse incase operands contains value that has torch function
         # in the original implementation this line is omitted
-        return einsum(equation, *operands)
+        return einsum(equation, *_operands)
 
-    return _VF.einsum(equation, operands)
+    return _VF.einsum(equation, operands)  # type: ignore
 
 
-def meshgrid(*tensors):
+if TYPE_CHECKING:
+    # The JIT doesn't understand Union, so only add type annotation for mypy
+    def meshgrid(*tensors: Union[Tensor, List[Tensor]]) -> Tuple[Tensor, ...]:
+        return _meshgrid(*tensors)
+else:
+    def meshgrid(*tensors):
+        return _meshgrid(*tensors)
+
+
+def _meshgrid(*tensors):
     r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional
 vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by
 expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.
@@ -363,8 +387,8 @@
             return handle_torch_function(meshgrid, tensors, *tensors)
     if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
         # the old interface of passing the operands as one list argument
-        tensors = tensors[0]
-    return _VF.meshgrid(tensors)
+        tensors = tensors[0]  # type: ignore
+    return _VF.meshgrid(tensors)  # type: ignore
 
 
 def stft(input, n_fft, hop_length=None, win_length=None, window=None,
@@ -524,15 +548,24 @@
                 window=window, center=center, normalized=normalized, onesided=onesided,
                 length=length)
 
-    return _VF.istft(
-        input, n_fft, hop_length, win_length, window, center, normalized, onesided, length)
+    return _VF.istft(input, n_fft, hop_length, win_length, window, center, normalized, onesided, length)  # type: ignore
 
 
 del torch.unique_dim
 
 
-def _unique_impl(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
-    # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]
+if TYPE_CHECKING:
+    # These _impl functions return a variable number of tensors as output with
+    # __torch_function__; tuple unpacking is done already rather than being
+    # done by the caller of the _impl function
+    _unique_impl_out = Any
+else:
+    _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
+
+
+def _unique_impl(input: Tensor, sorted: bool = True,
+                 return_inverse: bool = False, return_counts: bool = False,
+                 dim: Optional[int] = None) -> _unique_impl_out:
     r"""Returns the unique elements of the input tensor.
 
     .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
@@ -598,7 +631,7 @@
                 return_counts=return_counts, dim=dim)
 
     if dim is not None:
-        output, inverse_indices, counts = _VF.unique_dim(
+        output, inverse_indices, counts = _VF.unique_dim(  # type: ignore
             input,
             dim,
             sorted=sorted,
@@ -615,8 +648,9 @@
     return output, inverse_indices, counts
 
 
-def _unique_consecutive_impl(input, return_inverse=False, return_counts=False, dim=None):
-    # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor, Tensor]
+def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
+                             return_counts: bool = False,
+                             dim: Optional[int] = None) -> _unique_impl_out:
     r"""Eliminates all but the first element from every consecutive group of equivalent elements.
 
     .. note:: This function is different from :func:`torch.unique` in the sense that this function
@@ -671,7 +705,7 @@
             return handle_torch_function(
                 unique_consecutive, (input,), input, return_inverse=return_inverse,
                 return_counts=return_counts, dim=dim)
-    output, inverse_indices, counts = _VF.unique_consecutive(
+    output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore
         input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
     return output, inverse_indices, counts
 
@@ -686,6 +720,7 @@
     output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
     return output, counts
 
+
 def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
 
@@ -696,6 +731,7 @@
     output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
     return output
 
+
 def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
@@ -706,6 +742,7 @@
     output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
     return output, inverse_indices
 
+
 _return_inverse_false = boolean_dispatch(
     arg_name='return_counts',
     arg_index=3,
@@ -748,6 +785,7 @@
     output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
     return output, counts
 
+
 def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
     # type: (Tensor, bool, bool, Optional[int]) -> Tensor
 
@@ -758,6 +796,7 @@
     output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
     return output
 
+
 def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
@@ -768,6 +807,7 @@
     output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
     return output, inverse_indices
 
+
 _consecutive_return_inverse_false = boolean_dispatch(
     arg_name='return_counts',
     arg_index=1,
@@ -857,7 +897,7 @@
             raise RuntimeError("tensordot expects dims >= 0, but got dims={}".format(dims))
         dims_a = list(range(-dims, 0))
         dims_b = list(range(dims))
-    return _VF.tensordot(a, b, dims_a, dims_b)
+    return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore
 
 def cartesian_prod(*tensors):
     """Do cartesian product of the given sequence of tensors. The behavior is similar to
@@ -890,7 +930,7 @@
     if not torch.jit.is_scripting():
         if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
             return handle_torch_function(cartesian_prod, tensors, *tensors)
-    return _VF.cartesian_prod(tensors)
+    return _VF.cartesian_prod(tensors)  # type: ignore
 
 def block_diag(*tensors):
     """Create a block diagonal matrix from provided tensors.
@@ -924,7 +964,7 @@
     """
     if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors):
         return handle_torch_function(block_diag, tensors, *tensors)
-    return torch._C._VariableFunctions.block_diag(tensors)
+    return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore
 
 
 def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
@@ -974,11 +1014,11 @@
             return handle_torch_function(
                 cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
     if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
-        return _VF.cdist(x1, x2, p, None)
+        return _VF.cdist(x1, x2, p, None)  # type: ignore
     elif compute_mode == 'use_mm_for_euclid_dist':
-        return _VF.cdist(x1, x2, p, 1)
+        return _VF.cdist(x1, x2, p, 1)  # type: ignore
     elif compute_mode == 'donot_use_mm_for_euclid_dist':
-        return _VF.cdist(x1, x2, p, 2)
+        return _VF.cdist(x1, x2, p, 2)  # type: ignore
     else:
         raise ValueError("{} is not a valid value for compute_mode".format(compute_mode))
 
@@ -1014,7 +1054,7 @@
             return handle_torch_function(atleast_1d, tensors, *tensors)
     if len(tensors) == 1:
         tensors = tensors[0]
-    return _VF.atleast_1d(tensors)
+    return _VF.atleast_1d(tensors)  # type: ignore
 
 def atleast_2d(*tensors):
     r"""
@@ -1049,7 +1089,7 @@
             return handle_torch_function(atleast_2d, tensors, *tensors)
     if len(tensors) == 1:
         tensors = tensors[0]
-    return _VF.atleast_2d(tensors)
+    return _VF.atleast_2d(tensors)  # type: ignore
 
 def atleast_3d(*tensors):
     r"""
@@ -1093,28 +1133,44 @@
             return handle_torch_function(atleast_3d, tensors, *tensors)
     if len(tensors) == 1:
         tensors = tensors[0]
-    return _VF.atleast_3d(tensors)
+    return _VF.atleast_3d(tensors)  # type: ignore
 
-# TODO: type dim as BroadcastingList when https://github.com/pytorch/pytorch/issues/33782 is fixed
-@overload  # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
-    # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
-    pass
 
-@overload  # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
-    # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+if TYPE_CHECKING:
     pass
+    # There's no good way to use this type annotation; cannot rename norm() to
+    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
+    # for mypy for now.
+    #    def norm(input: Tensor,
+    #             p: Optional[Union[str, Number]] = "fro",
+    #             dim: Optional[Union[int, List[int]]] = None,
+    #             keepdim: bool = False,
+    #             out: Optional[Tensor] = None,
+    #             dtype: _dtype = None) -> Tensor:
+    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
+else:
+    # TODO: type dim as BroadcastingList when
+    # https://github.com/pytorch/pytorch/issues/33782 is fixed
+    @overload  # noqa: 749
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
+        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
 
-@overload  # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
-    # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
-    pass
+    @overload  # noqa: 749
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
+        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
 
-@overload  # noqa: 749
-def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
-    # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
-    pass
+    @overload  # noqa: 749
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
+        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
+
+    @overload  # noqa: 749
+    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
+        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
+        pass
+
 
 def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: 749
     r"""Returns the matrix norm or vector norm of a given tensor.
@@ -1183,15 +1239,14 @@
 
     ndim = input.dim()
 
-
     # catch default case
     if dim is None and out is None and dtype is None and p is not None:
         if isinstance(p, str):
             if p == "fro":
-                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
+                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)  # type: ignore
         if not isinstance(p, str):
             _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
-            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)
+            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore
 
     # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
     # remove the overloads where dim is an int and replace with BraodcastingList1
@@ -1202,7 +1257,7 @@
         else:
             _dim = dim
     else:
-        _dim = None
+        _dim = None  # type: ignore
 
     if isinstance(p, str):
         if p == "fro":
@@ -1212,22 +1267,22 @@
             if _dim is None:
                 _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
             if out is None:
-                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)
+                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore
             else:
-                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)
+                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore
         elif p == "nuc":
             if dtype is not None:
                 raise ValueError("dtype argument is not supported in nuclear norm")
             if _dim is None:
                 if out is None:
-                    return _VF.nuclear_norm(input, keepdim=keepdim)
+                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore
                 else:
-                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)
+                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore
             else:
                 if out is None:
-                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)
+                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore
                 else:
-                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)
+                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore
         raise RuntimeError("only valid string values are 'fro' and 'nuc', found {}".format(p))
     else:
         if _dim is None:
@@ -1235,14 +1290,14 @@
 
         if out is None:
             if dtype is None:
-                return _VF.norm(input, p, _dim, keepdim=keepdim)
+                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore
             else:
-                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)
+                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore
         else:
             if dtype is None:
-                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)
+                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore
             else:
-                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)
+                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore
 
 def chain_matmul(*matrices):
     r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
@@ -1276,7 +1331,7 @@
     if not torch.jit.is_scripting():
         if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices):
             return handle_torch_function(chain_matmul, matrices, *matrices)
-    return _VF.chain_matmul(matrices)
+    return _VF.chain_matmul(matrices)  # type: ignore
 
 
 def _lu_impl(A, pivot=True, get_infos=False, out=None):
@@ -1353,12 +1408,17 @@
     # If get_infos is True, then we don't need to check for errors and vice versa
     return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
 
-def _check_list_size(out_len, get_infos, out):
-    # type: (int, bool, List[Tensor]) -> None
+
+if TYPE_CHECKING:
+    _ListOrSeq = Sequence[Tensor]
+else:
+    _ListOrSeq = List[Tensor]
+
+def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
     get_infos_int = 1 if get_infos else 0
     if out_len - get_infos_int != 2:
         raise TypeError("expected tuple of {} elements but got {}"
-                        .format(2 + int(get_infos), len(out_len)))
+                        .format(2 + int(get_infos), out_len))
     if not isinstance(out, (tuple, list)):
         raise TypeError("argument 'out' must be tuple of Tensors, not {}"
                         .format(type(out).__name__))
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index 73dd008..18dc666 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -175,7 +175,7 @@
     lines_with_type = list(filter(lambda line: 'type' in line[1], lines))
 
     if len(type_lines) == 0:
-        type_pattern = re.compile('#[\t ]*type[\t ]*:')
+        type_pattern = re.compile('#[\t ]*type[\t ]*(?!: ignore$):')
         wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines))
         if len(wrong_type_lines) > 0:
             raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in
index 40bf9e8..a3db0d1 100644
--- a/torch/nn/functional.pyi.in
+++ b/torch/nn/functional.pyi.in
@@ -1,6 +1,6 @@
 from torch import Tensor
 from torch.types import _size
-from typing import Any, Optional, Tuple, Dict, List, Callable
+from typing import Any, Optional, Tuple, Dict, List, Callable, Sequence
 from .common_types import _ratio_any_t
 
 # 'TypedDict' is a new accepted type that represents a dictionary with a fixed set of allowed keys.
@@ -300,7 +300,7 @@
 def affine_grid(theta: Tensor, size: List[int], align_corners: Optional[Any] = ...) -> Tensor: ...
 
 
-def pad(input: Tensor, pad: List[int], mode: str = ..., value: float = ...) -> Tensor: ...
+def pad(input: Tensor, pad: Sequence[int], mode: str = ..., value: float = ...) -> Tensor: ...
 
 
 def pairwise_distance(x1: Tensor, x2: Tensor, p: float = ..., eps: float = ..., keepdim: bool = ...) -> Tensor: ...
diff --git a/torch/overrides.py b/torch/overrides.py
index a0a44ee..3d2dadd 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -180,6 +180,7 @@
         Tensor.unflatten,
     }
 
+
 @functools.lru_cache(None)
 def get_testing_overrides() -> Dict[Callable, Callable]:
     """Return a dict containing dummy overrides for all overridable functions
diff --git a/torch/quantization/_equalize.py b/torch/quantization/_equalize.py
index a2ed616..51cbe07 100644
--- a/torch/quantization/_equalize.py
+++ b/torch/quantization/_equalize.py
@@ -127,10 +127,10 @@
     if curr_modules.keys() != prev_modules.keys():
         raise ValueError("The keys to the given mappings must have the same set of names of modules")
 
-    summed_norms = 0
+    summed_norms = torch.tensor(0.)
     if None in prev_modules.values():
         return False
     for name in curr_modules.keys():
         difference = curr_modules[name].weight.sub(prev_modules[name].weight)
         summed_norms += torch.norm(difference)
-    return summed_norms < threshold
+    return bool(summed_norms < threshold)