[BE] enable UFMT for top-level files `torch/*.py` (#127707)

Part of the UFMT rollout tracked in:

- #123062

This change removes `torch/functional.py`, `torch/hub.py`, `torch/library.py`, `torch/overrides.py`, `torch/quasirandom.py`, `torch/random.py`, `torch/return_types.py`, and `torch/serialization.py` from the UFMT exclusion list in `.lintrunner.toml` and reformats them accordingly (sorted absolute imports, double-quoted strings, black-style line wrapping). A handful of other files (`torch/_guards.py`, `torch/_lobpcg.py`, `torch/_lowrank.py`, `torch/_tensor.py`) are touched only for import-style cleanups, replacing relative imports with absolute `torch.` imports plus minor whitespace adjustments.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707
Approved by: https://github.com/ezyang
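
UFMT (the µfmt wrapper around usort + black) runs in PyTorch via lintrunner, so once the files above leave the exclusion list they are checked and auto-formatted like any other covered file. As a rough illustration, the sketch below shows one way to re-run just the UFMT linter on these files locally; the `lintrunner` flags used here are an assumption based on its usual CLI and are not documented by this diff.

```python
# Minimal sketch (not part of this PR): re-run only the UFMT linter on the files
# that this change removes from the exclusion list. Assumes `lintrunner` has been
# set up as described in CONTRIBUTING.md; the exact flags are an assumption.
import subprocess

NEWLY_COVERED = [
    "torch/functional.py",
    "torch/hub.py",
    "torch/library.py",
    "torch/overrides.py",
    "torch/quasirandom.py",
    "torch/random.py",
    "torch/return_types.py",
    "torch/serialization.py",
]

# `--take UFMT` restricts the run to the UFMT linter; `--apply-patches` writes the
# suggested formatting back to the files instead of only reporting diffs.
subprocess.run(
    ["lintrunner", "--take", "UFMT", "--apply-patches", *NEWLY_COVERED],
    check=True,
)
```

Running plain `lintrunner -a` on your changed files, as CONTRIBUTING.md recommends, has the same effect for day-to-day development; the explicit file list above just mirrors the scope of this PR.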
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 92a7fc0..c28399b 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -1556,7 +1556,6 @@
     'torch/distributed/tensor/parallel/style.py',
     'torch/fft/__init__.py',
     'torch/func/__init__.py',
-    'torch/functional.py',
     'torch/futures/__init__.py',
     'torch/fx/__init__.py',
     'torch/fx/_compatibility.py',
@@ -1642,8 +1641,6 @@
     'torch/fx/subgraph_rewriter.py',
     'torch/fx/tensor_type.py',
     'torch/fx/traceback.py',
-    'torch/hub.py',
-    'torch/library.py',
     'torch/linalg/__init__.py',
     'torch/monitor/__init__.py',
     'torch/nested/__init__.py',
@@ -1767,11 +1764,6 @@
     'torch/nn/utils/rnn.py',
     'torch/nn/utils/spectral_norm.py',
     'torch/nn/utils/weight_norm.py',
-    'torch/overrides.py',
-    'torch/quasirandom.py',
-    'torch/random.py',
-    'torch/return_types.py',
-    'torch/serialization.py',
     'torch/signal/__init__.py',
     'torch/signal/windows/__init__.py',
     'torch/signal/windows/windows.py',
diff --git a/torch/_guards.py b/torch/_guards.py
index 9204170..917f5dc 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -2,7 +2,6 @@
 from __future__ import annotations
 
 import contextlib
-
 import dataclasses
 import enum
 import functools
@@ -31,6 +30,7 @@
 from torch.utils._traceback import CapturedTraceback
 from torch.utils.weak import WeakTensorKeyDictionary
 
+
 log = logging.getLogger(__name__)
 
 
@@ -40,7 +40,6 @@
     # Import the following modules during type checking to enable code intelligence features,
     # such as auto-completion in tools like pylance, even when these modules are not explicitly
     # imported in user code.
-
     import torch
 
 
@@ -176,7 +175,7 @@
     def sort_key(self):
         # Put the duplicate input guards at the end. The duplicate guards have
         # two sources while guard.name only considers one source.
-        from ._dynamo.guards import GuardBuilder
+        from torch._dynamo.guards import GuardBuilder
 
         is_duplicate_input = (
             isinstance(self.create_fn, functools.partial)
diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py
index 3f7bdf4..6a14883 100644
--- a/torch/_lobpcg.py
+++ b/torch/_lobpcg.py
@@ -7,9 +7,8 @@
 from typing import Dict, Optional, Tuple
 
 import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
 
 
 __all__ = ["lobpcg"]
diff --git a/torch/_lowrank.py b/torch/_lowrank.py
index 4641c4c..bbe01ed 100644
--- a/torch/_lowrank.py
+++ b/torch/_lowrank.py
@@ -6,9 +6,8 @@
 from typing import Optional, Tuple
 
 import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
 
 
 def get_approximate_basis(
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 5ea2985..36df3d6 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -761,22 +761,22 @@
         return torch.norm(self, p, dim, keepdim, dtype=dtype)
 
     def solve(self, other):
-        from ._linalg_utils import solve
+        from torch._linalg_utils import solve
 
         return solve(self, other)
 
     def lstsq(self, other):
-        from ._linalg_utils import lstsq
+        from torch._linalg_utils import lstsq
 
         return lstsq(self, other)
 
     def eig(self, eigenvectors=False):
-        from ._linalg_utils import eig
+        from torch._linalg_utils import eig
 
         return eig(self, eigenvectors=eigenvectors)
 
     def symeig(self, eigenvectors=False):
-        from ._linalg_utils import _symeig
+        from torch._linalg_utils import _symeig
 
         return _symeig(self, eigenvectors=eigenvectors)
 
diff --git a/torch/functional.py b/torch/functional.py
index a836c06..20e1cf1 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1,47 +1,46 @@
 # mypy: allow-untyped-defs
-from typing import (
-    List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
-)
-import operator
 import itertools
+import operator
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
-from torch._C import _add_docstr
 import torch.nn.functional as F
-from ._lowrank import svd_lowrank, pca_lowrank
-from .overrides import (
-    has_torch_function, has_torch_function_unary, has_torch_function_variadic,
-    handle_torch_function)
-from ._jit_internal import boolean_dispatch
-from ._jit_internal import _overload as overload
+from torch import _VF, Tensor
+from torch._C import _add_docstr
+from torch._jit_internal import _overload as overload, boolean_dispatch
+from torch._lowrank import pca_lowrank, svd_lowrank
+from torch.overrides import (
+    handle_torch_function,
+    has_torch_function,
+    has_torch_function_unary,
+    has_torch_function_variadic,
+)
 
-Tensor = torch.Tensor
-from torch import _VF
 
 __all__ = [
-    'atleast_1d',
-    'atleast_2d',
-    'atleast_3d',
-    'align_tensors',
-    'broadcast_shapes',
-    'broadcast_tensors',
-    'cartesian_prod',
-    'block_diag',
-    'cdist',
-    'chain_matmul',
-    'einsum',
-    'istft',
-    'lu',
-    'norm',
-    'meshgrid',
-    'pca_lowrank',
-    'split',
-    'stft',
-    'svd_lowrank',
-    'tensordot',
-    'unique',
-    'unique_consecutive',
-    'unravel_index',
+    "atleast_1d",
+    "atleast_2d",
+    "atleast_3d",
+    "align_tensors",
+    "broadcast_shapes",
+    "broadcast_tensors",
+    "cartesian_prod",
+    "block_diag",
+    "cdist",
+    "chain_matmul",
+    "einsum",
+    "istft",
+    "lu",
+    "norm",
+    "meshgrid",
+    "pca_lowrank",
+    "split",
+    "stft",
+    "svd_lowrank",
+    "tensordot",
+    "unique",
+    "unique_consecutive",
+    "unravel_index",
 ]
 
 
@@ -124,16 +123,25 @@
             if isinstance(shape, (tuple, list)):
                 for i in range(-1, -1 - len(shape), -1):
                     if shape[i] < 0:
-                        raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
+                        raise RuntimeError(
+                            f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
+                        )
                     # NB: result is initialized to 1 so this is effectively an
                     # equals one test
-                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
+                    if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
+                        shape[i] == result[i]
+                    ):
                         continue
                     if result[i] != 1:
-                        raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
+                        raise RuntimeError(
+                            "Shape mismatch: objects cannot be broadcast to a single shape"
+                        )
                     result[i] = shape[i]
             else:
-                raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
+                raise RuntimeError(
+                    "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
+                    shape,
+                )
         return torch.Size(result)
     else:
         # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
@@ -188,7 +196,8 @@
     """
     if has_torch_function_unary(tensor):
         return handle_torch_function(
-            split, (tensor,), tensor, split_size_or_sections, dim=dim)
+            split, (tensor,), tensor, split_size_or_sections, dim=dim
+        )
     # Overwriting reason:
     # This dispatches to two ATen functions depending on the type of
     # split_size_or_sections. The branching code is in _tensor.py, which we
@@ -335,10 +344,13 @@
                 [ 0.3311,  5.5201, -3.0356]])
     """
     import torch.backends.opt_einsum as opt_einsum
+
     # This wrapper exists to support variadic args.
     if len(args) < 2:
-        raise ValueError('einsum(): must specify the equation string and at least one operand, '
-                         'or at least one operand and its subscripts list')
+        raise ValueError(
+            "einsum(): must specify the equation string and at least one operand, "
+            "or at least one operand and its subscripts list"
+        )
 
     equation = None
     operands = None
@@ -350,19 +362,21 @@
         # input operands into a tensorlist (List[Tensor]).
         def parse_subscript(n: int) -> str:
             if n == Ellipsis:
-                return '...'
+                return "..."
             if n >= 0 and n < 26:
-                return chr(ord('A') + n)
+                return chr(ord("A") + n)
             if n >= 26 and n < 52:
-                return chr(ord('a') + n - 26)
-            raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
+                return chr(ord("a") + n - 26)
+            raise ValueError(
+                "einsum(): subscript in subscript list is not within the valid range [0, 52)"
+            )
 
         # Parse subscripts for input operands
-        equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
+        equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
 
         # Parse optional output subscripts (provided when the number of arguments is odd)
         if len(args) % 2 == 1:
-            equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
+            equation += "->" + "".join(parse_subscript(s) for s in args[-1])
             operands = args[:-1:2]
         else:
             operands = args[::2]
@@ -388,7 +402,9 @@
     path = None
     if opt_einsum.is_available():
         _opt_einsum = opt_einsum.get_opt_einsum()
-        tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
+        tupled_path = _opt_einsum.contract_path(
+            equation, *operands, optimize=opt_einsum.strategy
+        )[0]
         # flatten path for dispatching to C++
         path = [item for pair in tupled_path for item in pair]
     return _VF.einsum(equation, operands, path=path)  # type: ignore[attr-defined]
@@ -397,10 +413,13 @@
 # This wrapper exists to support variadic args.
 if TYPE_CHECKING:
     # The JIT doesn't understand Union, so only add type annotation for mypy
-    def meshgrid(*tensors: Union[Tensor, List[Tensor]],
-                 indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
+    def meshgrid(
+        *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
+    ) -> Tuple[Tensor, ...]:
         return _meshgrid(*tensors, indexing=indexing)
+
 else:
+
     def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
         r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
 
@@ -509,15 +528,22 @@
     # kwarg for forward compatibility reasons.
     #
     # Remove this two weeks after landing.
-    kwargs = {} if indexing is None else {'indexing': indexing}
+    kwargs = {} if indexing is None else {"indexing": indexing}
     return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 
 
-def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
-         win_length: Optional[int] = None, window: Optional[Tensor] = None,
-         center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
-         onesided: Optional[bool] = None,
-         return_complex: Optional[bool] = None) -> Tensor:
+def stft(
+    input: Tensor,
+    n_fft: int,
+    hop_length: Optional[int] = None,
+    win_length: Optional[int] = None,
+    window: Optional[Tensor] = None,
+    center: bool = True,
+    pad_mode: str = "reflect",
+    normalized: bool = False,
+    onesided: Optional[bool] = None,
+    return_complex: Optional[bool] = None,
+) -> Tensor:
     r"""Short-time Fourier transform (STFT).
 
     .. warning::
@@ -652,9 +678,19 @@
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
-            window=window, center=center, pad_mode=pad_mode, normalized=normalized,
-            onesided=onesided, return_complex=return_complex)
+            stft,
+            (input,),
+            input,
+            n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            window=window,
+            center=center,
+            pad_mode=pad_mode,
+            normalized=normalized,
+            onesided=onesided,
+            return_complex=return_complex,
+        )
     # NOTE: Do not edit. This code will be removed once the forward-compatibility
     #       period is over for PR #73432
     if center:
@@ -663,8 +699,16 @@
         pad = int(n_fft // 2)
         input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
         input = input.view(input.shape[-signal_dim:])
-    return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]
-                    normalized, onesided, return_complex)
+    return _VF.stft(  # type: ignore[attr-defined]
+        input,
+        n_fft,
+        hop_length,
+        win_length,
+        window,
+        normalized,
+        onesided,
+        return_complex,
+    )
 
 
 istft = _add_docstr(
@@ -746,7 +790,8 @@
 Returns:
     Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
         `B?` is an optional batch dimension from the input tensor.
-""")
+""",
+)
 
 
 if TYPE_CHECKING:
@@ -758,9 +803,13 @@
     _unique_impl_out = Tuple[Tensor, Tensor, Tensor]
 
 
-def _unique_impl(input: Tensor, sorted: bool = True,
-                 return_inverse: bool = False, return_counts: bool = False,
-                 dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_impl(
+    input: Tensor,
+    sorted: bool = True,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    dim: Optional[int] = None,
+) -> _unique_impl_out:
     r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
 
     Returns the unique elements of the input tensor.
@@ -896,8 +945,14 @@
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
-            return_counts=return_counts, dim=dim)
+            unique,
+            (input,),
+            input,
+            sorted=sorted,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
 
     if dim is not None:
         output, inverse_indices, counts = _VF.unique_dim(
@@ -917,9 +972,12 @@
     return output, inverse_indices, counts
 
 
-def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
-                             return_counts: bool = False,
-                             dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_consecutive_impl(
+    input: Tensor,
+    return_inverse: bool = False,
+    return_counts: bool = False,
+    dim: Optional[int] = None,
+) -> _unique_impl_out:
     r"""Eliminates all but the first element from every consecutive group of equivalent elements.
 
     .. note:: This function is different from :func:`torch.unique` in the sense that this function
@@ -971,14 +1029,22 @@
     """
     if has_torch_function_unary(input):
         return handle_torch_function(
-            unique_consecutive, (input,), input, return_inverse=return_inverse,
-            return_counts=return_counts, dim=dim)
+            unique_consecutive,
+            (input,),
+            input,
+            return_inverse=return_inverse,
+            return_counts=return_counts,
+            dim=dim,
+        )
     output, inverse_indices, counts = _VF.unique_consecutive(  # type: ignore[attr-defined]
-        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+        input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
+    )
     return output, inverse_indices, counts
 
 
-def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_counts(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
@@ -988,7 +1054,9 @@
     return output, counts
 
 
-def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_output(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
 
     if has_torch_function_unary(input):
@@ -998,59 +1066,72 @@
     return output
 
 
-def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_inverse(
+    input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_impl(input, sorted, return_inverse, return_counts, dim)
 
-    output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
+    output, inverse_indices, _ = _unique_impl(
+        input, sorted, return_inverse, return_counts, dim
+    )
     return output, inverse_indices
 
 
 _return_inverse_false = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=3,
     default=False,
     if_true=_return_counts,
     if_false=_return_output,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 
 _return_inverse_true = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=3,
     default=False,
     if_true=_unique_impl,
     if_false=_return_inverse,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 
 # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
 # resolve the output type in TorchScript we need to statically know the value of both parameters
 
 unique = boolean_dispatch(
-    arg_name='return_inverse',
+    arg_name="return_inverse",
     arg_index=2,
     default=False,
     if_true=_return_inverse_true,
     if_false=_return_inverse_false,
     module_name=__name__,
-    func_name='unique')
+    func_name="unique",
+)
 unique.__doc__ = _unique_impl.__doc__
 
 
-def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_counts(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
-    output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    output, _, counts = _unique_consecutive_impl(
+        input, return_inverse, return_counts, dim
+    )
     return output, counts
 
 
-def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_output(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tensor
 
     if has_torch_function_unary(input):
@@ -1060,45 +1141,52 @@
     return output
 
 
-def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_inverse(
+    input, return_inverse=False, return_counts=False, dim=None
+):
     # type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
 
     if has_torch_function_unary(input):
         return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
 
-    output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+    output, inverse_indices, _ = _unique_consecutive_impl(
+        input, return_inverse, return_counts, dim
+    )
     return output, inverse_indices
 
 
 _consecutive_return_inverse_false = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=1,
     default=False,
     if_true=_consecutive_return_counts,
     if_false=_consecutive_return_output,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 
 _consecutive_return_inverse_true = boolean_dispatch(
-    arg_name='return_counts',
+    arg_name="return_counts",
     arg_index=1,
     default=False,
     if_true=_unique_consecutive_impl,
     if_false=_consecutive_return_inverse,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 
 # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
 # resolve the output type in TorchScript we need to statically know the value of both parameters
 
 unique_consecutive = boolean_dispatch(
-    arg_name='return_inverse',
+    arg_name="return_inverse",
     arg_index=2,
     default=False,
     if_true=_consecutive_return_inverse_true,
     if_false=_consecutive_return_inverse_false,
     module_name=__name__,
-    func_name='unique_consecutive')
+    func_name="unique_consecutive",
+)
 unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
 
 if TYPE_CHECKING:
@@ -1106,24 +1194,50 @@
     # There's no good way to use this type annotation without breaking JIT
     # overloads. So leave untyped for mypy for now.
 else:
+
     @overload
-    def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
+    def tensordot(
+        a,
+        b,
+        dims: int = 2,
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: Tuple[List[int], List[int]],
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: List[List[int]],
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
-    @overload  # noqa: F811
-    def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None):  # noqa: F811
+    @overload
+    def tensordot(  # noqa: F811
+        a,
+        b,
+        dims: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+    ):
         pass
 
 
-def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None):  # noqa: F811
+def tensordot(  # noqa: F811
+    a,
+    b,
+    dims=2,
+    out: Optional[torch.Tensor] = None,
+):
     r"""Returns a contraction of a and b over multiple dimensions.
 
     :attr:`tensordot` implements a generalized matrix product.
@@ -1178,10 +1292,12 @@
         return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
 
     if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
-        raise RuntimeError("tensordot expects dims to be int or "
-                           + "Tuple[List[int], List[int]] or "
-                           + "List[List[int]] containing two lists, but got "
-                           + f"dims={dims}")
+        raise RuntimeError(
+            "tensordot expects dims to be int or "
+            + "Tuple[List[int], List[int]] or "
+            + "List[List[int]] containing two lists, but got "
+            + f"dims={dims}"
+        )
 
     dims_a: List[int] = []
     dims_b: List[int] = []
@@ -1206,7 +1322,9 @@
         if dims < 0:
             raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
         if dims > min(a.dim(), b.dim()):
-            raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
+            raise RuntimeError(
+                f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
+            )
         dims_a = list(range(-dims, 0))
         dims_b = list(range(dims))
 
@@ -1287,7 +1405,7 @@
     return torch._C._VariableFunctions.block_diag(tensors)  # type: ignore[attr-defined]
 
 
-def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
+def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
     # type: (Tensor, Tensor, float, str) -> (Tensor)
     r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
 
@@ -1331,12 +1449,13 @@
     """
     if has_torch_function_variadic(x1, x2):
         return handle_torch_function(
-            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
-    if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
+            cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
+        )
+    if compute_mode == "use_mm_for_euclid_dist_if_necessary":
         return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
-    elif compute_mode == 'use_mm_for_euclid_dist':
+    elif compute_mode == "use_mm_for_euclid_dist":
         return _VF.cdist(x1, x2, p, 1)  # type: ignore[attr-defined]
-    elif compute_mode == 'donot_use_mm_for_euclid_dist':
+    elif compute_mode == "donot_use_mm_for_euclid_dist":
         return _VF.cdist(x1, x2, p, 2)  # type: ignore[attr-defined]
     else:
         raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
@@ -1478,27 +1597,62 @@
     # TODO: type dim as BroadcastingList when
     # https://github.com/pytorch/pytorch/issues/33782 is fixed
     @overload
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
+    def norm(
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
-    @overload  # noqa: F811
-    def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+    @overload
+    def norm(  # noqa: F811
+        input,
+        p="fro",
+        dim=None,
+        keepdim=False,
+        out=None,
+        dtype=None,
+    ):
         # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
         pass
 
 
-def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None):  # noqa: F811
+def norm(  # noqa: F811
+    input,
+    p: Optional[Union[float, str]] = "fro",
+    dim=None,
+    keepdim=False,
+    out=None,
+    dtype=None,
+):
     r"""Returns the matrix norm or vector norm of a given tensor.
 
     .. warning::
@@ -1594,14 +1748,19 @@
 
     if has_torch_function_unary(input):
         return handle_torch_function(
-            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
+            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
+        )
 
     # NB. All the repeated code and weird python is to please TorchScript.
     #     For a more compact implementation see the relevant function in `_refs/__init__.py`
 
     # We don't do this for MPS or sparse tensors
-    if input.layout == torch.strided and input.device.type in \
-            ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
+    if input.layout == torch.strided and input.device.type in (
+        "cpu",
+        "cuda",
+        "meta",
+        torch.utils.backend_registration._privateuse1_backend_name,
+    ):
         if dim is not None:
             if isinstance(dim, (int, torch.SymInt)):
                 _dim = [dim]
@@ -1611,11 +1770,17 @@
             _dim = None  # type: ignore[assignment]
 
         if isinstance(p, str):
-            if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
+            if p == "fro" and (
+                dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
+            ):
                 if out is None:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
+                    return torch.linalg.vector_norm(
+                        input, 2, _dim, keepdim, dtype=dtype
+                    )
                 else:
-                    return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
+                    return torch.linalg.vector_norm(
+                        input, 2, _dim, keepdim, dtype=dtype, out=out
+                    )
 
             # Here we either call the nuclear norm, or we call matrix_norm with some arguments
             # that will throw an error
@@ -1624,14 +1789,18 @@
             if out is None:
                 return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
             else:
-                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
+                return torch.linalg.matrix_norm(
+                    input, p, _dim, keepdim, dtype=dtype, out=out
+                )
         else:
             # NB. p should be Union[str, number], not Optional!
             _p = 2.0 if p is None else p
             if out is None:
                 return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
             else:
-                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
+                return torch.linalg.vector_norm(
+                    input, _p, _dim, keepdim, dtype=dtype, out=out
+                )
 
     ndim = input.dim()
 
@@ -1641,7 +1810,7 @@
             if p == "fro":
                 return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
         if not isinstance(p, str):
-            _dim = [i for i in range(ndim)]  # noqa: C416 TODO: rewrite as list(range(m))
+            _dim = list(range(ndim))
             return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]
 
     # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
@@ -1695,7 +1864,10 @@
             else:
                 return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
 
-def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
+
+def unravel_index(
+    indices: Tensor, shape: Union[int, Sequence[int], torch.Size]
+) -> Tuple[Tensor, ...]:
     r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
     index into an arbitrary tensor of the specified shape.
 
@@ -1745,19 +1917,23 @@
          tensor([[34], [78]]))
     """
     if has_torch_function_unary(indices):
-        return handle_torch_function(
-            unravel_index, (indices,), indices, shape=shape)
+        return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
     res_tensor = _unravel_index(indices, shape)
     return res_tensor.unbind(-1)
 
+
 def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
     torch._check_type(
-        not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
-        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
+        not indices.is_complex()
+        and not indices.is_floating_point()
+        and not indices.dtype == torch.bool,
+        lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
+    )
 
     torch._check_type(
         isinstance(shape, (int, torch.SymInt, Sequence)),
-        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
+        lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
+    )
 
     if isinstance(shape, (int, torch.SymInt)):
         shape = torch.Size([shape])
@@ -1765,18 +1941,29 @@
         for dim in shape:
             torch._check_type(
                 isinstance(dim, (int, torch.SymInt)),
-                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
+                lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
+            )
         shape = torch.Size(shape)
 
     torch._check_value(
         all(dim >= 0 for dim in shape),
-        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
+        lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
+    )
 
-    coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
+    coefs = list(
+        reversed(
+            list(
+                itertools.accumulate(
+                    reversed(shape[1:] + torch.Size([1])), func=operator.mul
+                )
+            )
+        )
+    )
     return indices.unsqueeze(-1).floor_divide(
         torch.tensor(coefs, device=indices.device, dtype=torch.int64)
     ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
 
+
 def chain_matmul(*matrices, out=None):
     r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
     using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
@@ -1923,6 +2110,7 @@
     # If get_infos is True, then we don't need to check for errors and vice versa
     return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
 
+
 if TYPE_CHECKING:
     _ListOrSeq = Sequence[Tensor]
 else:
@@ -1932,16 +2120,21 @@
 def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
     get_infos_int = 1 if get_infos else 0
     if out_len - get_infos_int != 2:
-        raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
+        raise TypeError(
+            f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
+        )
     if not isinstance(out, (tuple, list)):
-        raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
+        raise TypeError(
+            f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
+        )
 
 
 def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
     # type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
     if has_torch_function_unary(A):
         return handle_torch_function(
-            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+        )
     result = _lu_impl(A, pivot, get_infos, out)
     if out is not None:
         _check_list_size(len(out), get_infos, out)
@@ -1957,7 +2150,8 @@
     # need to check for torch_function here so that we exit if
     if has_torch_function_unary(A):
         return handle_torch_function(
-            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+            lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+        )
     result = _lu_impl(A, pivot, get_infos, out)
     if out is not None:
         _check_list_size(len(out), get_infos, out)
@@ -1967,18 +2161,20 @@
     else:
         return result[0], result[1]  # A_LU, pivots
 
+
 # The return type of lu depends on `get_infos`, so in order to resolve the output type
 # of lu in TorchScript we need to statically know the value of `get_infos`
 lu = boolean_dispatch(
-    arg_name='get_infos',
+    arg_name="get_infos",
     arg_index=2,
     default=False,
     if_true=_lu_with_infos,
     if_false=_lu_no_infos,
     module_name=__name__,
-    func_name='lu')
+    func_name="lu",
+)
 lu.__doc__ = _lu_impl.__doc__
 
 
 def align_tensors(*tensors):
-    raise RuntimeError('`align_tensors` not yet implemented.')
+    raise RuntimeError("`align_tensors` not yet implemented.")
diff --git a/torch/hub.py b/torch/hub.py
index 213a129..57a07e2 100644
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -8,22 +8,22 @@
 import shutil
 import sys
 import tempfile
-import torch
 import uuid
 import warnings
 import zipfile
 from pathlib import Path
-from typing import Dict, Optional, Any
+from typing import Any, Dict, Optional
 from typing_extensions import deprecated
 from urllib.error import HTTPError, URLError
-from urllib.request import urlopen, Request
 from urllib.parse import urlparse  # noqa: F401
+from urllib.request import Request, urlopen
+
+import torch
 from torch.serialization import MAP_LOCATION
 
-class _Faketqdm:  # type: ignore[no-redef]
 
-    def __init__(self, total=None, disable=False,
-                 unit=None, *args, **kwargs):
+class _Faketqdm:  # type: ignore[no-redef]
+    def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
         self.total = total
         self.disable = disable
         self.n = 0
@@ -57,7 +57,8 @@
         if self.disable:
             return
 
-        sys.stderr.write('\n')
+        sys.stderr.write("\n")
+
 
 try:
     from tqdm import tqdm  # If tqdm is installed use it, otherwise use the fake wrapper
@@ -65,25 +66,30 @@
     tqdm = _Faketqdm
 
 __all__ = [
-    'download_url_to_file',
-    'get_dir',
-    'help',
-    'list',
-    'load',
-    'load_state_dict_from_url',
-    'set_dir',
+    "download_url_to_file",
+    "get_dir",
+    "help",
+    "list",
+    "load",
+    "load_state_dict_from_url",
+    "set_dir",
 ]
 
 # matches bfd8deac from resnet18-bfd8deac.pth
-HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
 
-_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
-ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
-ENV_TORCH_HOME = 'TORCH_HOME'
-ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
-DEFAULT_CACHE_DIR = '~/.cache'
-VAR_DEPENDENCY = 'dependencies'
-MODULE_HUBCONF = 'hubconf.py'
+_TRUSTED_REPO_OWNERS = (
+    "facebookresearch",
+    "facebookincubator",
+    "pytorch",
+    "fairinternal",
+)
+ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
+ENV_TORCH_HOME = "TORCH_HOME"
+ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
+DEFAULT_CACHE_DIR = "~/.cache"
+VAR_DEPENDENCY = "dependencies"
+MODULE_HUBCONF = "hubconf.py"
 READ_DATA_CHUNK = 128 * 1024
 _hub_dir: Optional[str] = None
 
@@ -101,6 +107,7 @@
 def _import_module(name, path):
     import importlib.util
     from importlib.abc import Loader
+
     spec = importlib.util.spec_from_file_location(name, path)
     assert spec is not None
     module = importlib.util.module_from_spec(spec)
@@ -131,18 +138,20 @@
 
 def _get_torch_home():
     torch_home = os.path.expanduser(
-        os.getenv(ENV_TORCH_HOME,
-                  os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
-                                         DEFAULT_CACHE_DIR), 'torch')))
+        os.getenv(
+            ENV_TORCH_HOME,
+            os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
+        )
+    )
     return torch_home
 
 
 def _parse_repo_info(github):
-    if ':' in github:
-        repo_info, ref = github.split(':')
+    if ":" in github:
+        repo_info, ref = github.split(":")
     else:
         repo_info, ref = github, None
-    repo_owner, repo_name = repo_info.split('/')
+    repo_owner, repo_name = repo_info.split("/")
 
     if ref is None:
         # The ref wasn't specified by the user, so we need to figure out the
@@ -150,16 +159,18 @@
         # then it's the default branch, otherwise it's master.
         try:
             with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
-                ref = 'main'
+                ref = "main"
         except HTTPError as e:
             if e.code == 404:
-                ref = 'master'
+                ref = "master"
             else:
                 raise
         except URLError as e:
             # No internet connection, need to check for cache as last resort
             for possible_ref in ("main", "master"):
-                if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
+                if os.path.exists(
+                    f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
+                ):
                     ref = possible_ref
                     break
             if ref is None:
@@ -172,35 +183,40 @@
 
 def _read_url(url):
     with urlopen(url) as r:
-        return r.read().decode(r.headers.get_content_charset('utf-8'))
+        return r.read().decode(r.headers.get_content_charset("utf-8"))
 
 
 def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
     # Use urlopen to avoid depending on local git.
-    headers = {'Accept': 'application/vnd.github.v3+json'}
+    headers = {"Accept": "application/vnd.github.v3+json"}
     token = os.environ.get(ENV_GITHUB_TOKEN)
     if token is not None:
-        headers['Authorization'] = f'token {token}'
+        headers["Authorization"] = f"token {token}"
     for url_prefix in (
-            f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
-            f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
+        f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
+        f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
+    ):
         page = 0
         while True:
             page += 1
-            url = f'{url_prefix}?per_page=100&page={page}'
+            url = f"{url_prefix}?per_page=100&page={page}"
             response = json.loads(_read_url(Request(url, headers=headers)))
             # Empty response means no more data to process
             if not response:
                 break
             for br in response:
-                if br['name'] == ref or br['commit']['sha'].startswith(ref):
+                if br["name"] == ref or br["commit"]["sha"].startswith(ref):
                     return
 
-    raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
-                     'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
+    raise ValueError(
+        f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
+        "If it's a commit from a forked repo, please call hub.load() with forked repo directly."
+    )
 
 
-def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
+def _get_cache_or_reload(
+    github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False
+):
     # Setup hub_dir to save downloaded files
     hub_dir = get_dir()
     os.makedirs(hub_dir, exist_ok=True)
@@ -210,27 +226,33 @@
     # this causes confusion with path on both Linux and Windows.
     # Backslash is not allowed in Github branch name so no need to
     # to worry about it.
-    normalized_br = ref.replace('/', '_')
+    normalized_br = ref.replace("/", "_")
     # Github renames folder repo-v1.x.x to repo-1.x.x
     # We don't know the repo name before downloading the zip file
     # and inspect name from it.
     # To check if cached repo exists, we need to normalize folder names.
-    owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
+    owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
     repo_dir = os.path.join(hub_dir, owner_name_branch)
     # Check that the repo is in the trusted list
-    _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)
+    _check_repo_is_trusted(
+        repo_owner,
+        repo_name,
+        owner_name_branch,
+        trust_repo=trust_repo,
+        calling_fn=calling_fn,
+    )
 
     use_cache = (not force_reload) and os.path.exists(repo_dir)
 
     if use_cache:
         if verbose:
-            sys.stderr.write(f'Using cache found in {repo_dir}\n')
+            sys.stderr.write(f"Using cache found in {repo_dir}\n")
     else:
         # Validate the tag/branch is from the original repo instead of a forked repo
         if not skip_validation:
             _validate_not_a_forked_repo(repo_owner, repo_name, ref)
 
-        cached_file = os.path.join(hub_dir, normalized_br + '.zip')
+        cached_file = os.path.join(hub_dir, normalized_br + ".zip")
         _remove_if_exists(cached_file)
 
         try:
@@ -250,7 +272,9 @@
                     "refs/tags/tag_name as the ref. That might require using skip_validation=True."
                 )
                 disambiguated_branch_ref = f"refs/heads/{ref}"
-                url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
+                url = _git_archive_link(
+                    repo_owner, repo_name, ref=disambiguated_branch_ref
+                )
                 download_url_to_file(url, cached_file, progress=False)
             else:
                 raise
@@ -269,7 +293,9 @@
     return repo_dir
 
 
-def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
+def _check_repo_is_trusted(
+    repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"
+):
     hub_dir = get_dir()
     filepath = os.path.join(hub_dir, "trusted_list")
 
@@ -282,7 +308,7 @@
     # if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
     trusted_repos_legacy = next(os.walk(hub_dir))[1]
 
-    owner_name = '_'.join([repo_owner, repo_name])
+    owner_name = "_".join([repo_owner, repo_name])
     is_trusted = (
         owner_name in trusted_repos
         or owner_name_branch in trusted_repos_legacy
@@ -298,13 +324,15 @@
                 "trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
                 f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
                 f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
-                f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
+                f"confirmation if the repo is not already trusted. This will eventually be the default behaviour"
+            )
         return
 
     if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
         response = input(
             f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
-            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
+            "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
+        )
         if response.lower() in ("y", "yes"):
             if is_trusted:
                 print("The repository is already trusted.")
@@ -321,6 +349,7 @@
 
 def _check_module_exists(name):
     import importlib.util
+
     return importlib.util.find_spec(name) is not None
 
 
@@ -335,7 +364,7 @@
 
 def _load_entry_from_hubconf(m, model):
     if not isinstance(model, str):
-        raise ValueError('Invalid input: model should be a string of function name')
+        raise ValueError("Invalid input: model should be a string of function name")
 
     # Note that if a missing dependency is imported at top level of hubconf, it will
     # throw before this function. It's a chicken and egg situation where we have to
@@ -346,7 +375,7 @@
     func = _load_attr_from_module(m, model)
 
     if func is None or not callable(func):
-        raise RuntimeError(f'Cannot find callable {model} in hubconf')
+        raise RuntimeError(f"Cannot find callable {model} in hubconf")
 
     return func
 
@@ -362,12 +391,12 @@
     variable is not set.
     """
     # Issue warning to move data if old env is set
-    if os.getenv('TORCH_HUB'):
-        warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
+    if os.getenv("TORCH_HUB"):
+        warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")
 
     if _hub_dir is not None:
         return _hub_dir
-    return os.path.join(_get_torch_home(), 'hub')
+    return os.path.join(_get_torch_home(), "hub")
 
 
 def set_dir(d):
@@ -381,7 +410,9 @@
     _hub_dir = os.path.expanduser(d)
 
 
-def list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True):
+def list(
+    github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True
+):
     r"""
     List all callable entrypoints available in the repo specified by ``github``.
 
@@ -424,15 +455,25 @@
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
         >>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
     """
-    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=verbose,
-                                    skip_validation=skip_validation)
+    repo_dir = _get_cache_or_reload(
+        github,
+        force_reload,
+        trust_repo,
+        "list",
+        verbose=verbose,
+        skip_validation=skip_validation,
+    )
 
     with _add_to_sys_path(repo_dir):
         hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
         hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
 
     # We take functions starts with '_' as internal helper functions
-    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
+    entrypoints = [
+        f
+        for f in dir(hub_module)
+        if callable(getattr(hub_module, f)) and not f.startswith("_")
+    ]
 
     return entrypoints
 
@@ -474,8 +515,14 @@
         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
         >>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
     """
-    repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
-                                    skip_validation=skip_validation)
+    repo_dir = _get_cache_or_reload(
+        github,
+        force_reload,
+        trust_repo,
+        "help",
+        verbose=True,
+        skip_validation=skip_validation,
+    )
 
     with _add_to_sys_path(repo_dir):
         hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
@@ -486,9 +533,17 @@
     return entry.__doc__
 
 
-def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
-         skip_validation=False,
-         **kwargs):
+def load(
+    repo_or_dir,
+    model,
+    *args,
+    source="github",
+    trust_repo=None,
+    force_reload=False,
+    verbose=True,
+    skip_validation=False,
+    **kwargs,
+):
     r"""
     Load a model from a github repo or a local directory.
 
@@ -559,13 +614,20 @@
     """
     source = source.lower()
 
-    if source not in ('github', 'local'):
+    if source not in ("github", "local"):
         raise ValueError(
-            f'Unknown source: "{source}". Allowed values: "github" | "local".')
+            f'Unknown source: "{source}". Allowed values: "github" | "local".'
+        )
 
-    if source == 'github':
-        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
-                                           verbose=verbose, skip_validation=skip_validation)
+    if source == "github":
+        repo_or_dir = _get_cache_or_reload(
+            repo_or_dir,
+            force_reload,
+            trust_repo,
+            "load",
+            verbose=verbose,
+            skip_validation=skip_validation,
+        )
 
     model = _load_local(repo_or_dir, model, *args, **kwargs)
     return model
@@ -601,8 +663,9 @@
     return model
 
 
-def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
-                         progress: bool = True) -> None:
+def download_url_to_file(
+    url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True
+) -> None:
     r"""Download object at the given URL to a local path.
 
     Args:
@@ -623,7 +686,7 @@
     req = Request(url, headers={"User-Agent": "torch.hub"})
     u = urlopen(req)
     meta = u.info()
-    if hasattr(meta, 'getheaders'):
+    if hasattr(meta, "getheaders"):
         content_length = meta.getheaders("Content-Length")
     else:
         content_length = meta.get_all("Content-Length")
@@ -637,20 +700,25 @@
     # file permissions being applied to the downloaded file.
     dst = os.path.expanduser(dst)
     for seq in range(tempfile.TMP_MAX):
-        tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
+        tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
         try:
-            f = open(tmp_dst, 'w+b')
+            f = open(tmp_dst, "w+b")
         except FileExistsError:
             continue
         break
     else:
-        raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
+        raise FileExistsError(errno.EEXIST, "No usable temporary file name found")
 
     try:
         if hash_prefix is not None:
             sha256 = hashlib.sha256()
-        with tqdm(total=file_size, disable=not progress,
-                  unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+        with tqdm(
+            total=file_size,
+            disable=not progress,
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as pbar:
             while True:
                 buffer = u.read(READ_DATA_CHUNK)
                 if len(buffer) == 0:
@@ -663,8 +731,10 @@
         f.close()
         if hash_prefix is not None:
             digest = sha256.hexdigest()  # type: ignore[possibly-undefined]
-            if digest[:len(hash_prefix)] != hash_prefix:
-                raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
+            if digest[: len(hash_prefix)] != hash_prefix:
+                raise RuntimeError(
+                    f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
+                )
         shutil.move(f.name, dst)
     finally:
         f.close()
@@ -683,23 +753,30 @@
 
 
 @deprecated(
-    'Falling back to the old format < 1.6. This support will be '
-    'deprecated in favor of default zipfile format introduced in 1.6. '
-    'Please redo torch.save() to save it in the new zipfile format.',
+    "Falling back to the old format < 1.6. This support will be "
+    "deprecated in favor of default zipfile format introduced in 1.6. "
+    "Please redo torch.save() to save it in the new zipfile format.",
     category=FutureWarning,
 )
-def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]:
+def _legacy_zip_load(
+    filename: str,
+    model_dir: str,
+    map_location: MAP_LOCATION,
+    weights_only: bool,
+) -> Dict[str, Any]:
     # Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
     #       We deliberately don't handle tarfile here since our legacy serialization format was in tar.
     #       E.g. resnet18-5c106cde.pth which is widely used.
     with zipfile.ZipFile(filename) as f:
         members = f.infolist()
         if len(members) != 1:
-            raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
+            raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
         f.extractall(model_dir)
         extracted_name = members[0].filename
         extracted_file = os.path.join(model_dir, extracted_name)
-    return torch.load(extracted_file, map_location=map_location, weights_only=weights_only)
+    return torch.load(
+        extracted_file, map_location=map_location, weights_only=weights_only
+    )
 
 
 def load_state_dict_from_url(
@@ -742,12 +819,14 @@
 
     """
     # Issue warning to move data if old env is set
-    if os.getenv('TORCH_MODEL_ZOO'):
-        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+    if os.getenv("TORCH_MODEL_ZOO"):
+        warnings.warn(
+            "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
+        )
 
     if model_dir is None:
         hub_dir = get_dir()
-        model_dir = os.path.join(hub_dir, 'checkpoints')
+        model_dir = os.path.join(hub_dir, "checkpoints")
 
     os.makedirs(model_dir, exist_ok=True)
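
Similarly, `load_state_dict_from_url` above caches downloads under the `checkpoints` directory shown in this hunk. A brief sketch, assuming the default `TORCH_HOME` location and an illustrative checkpoint URL:

```python
from torch.hub import get_dir, load_state_dict_from_url

# Without model_dir, files are cached under get_dir() + "/checkpoints"
# (typically ~/.cache/torch/hub/checkpoints; override via TORCH_HOME).
print("hub cache:", get_dir())

state_dict = load_state_dict_from_url(
    "https://download.pytorch.org/models/resnet18-f37072fd.pth",  # illustrative
    map_location="cpu",
    progress=True,
    check_hash=True,  # verify the hash prefix embedded in the file name
)
print(f"loaded {len(state_dict)} tensors")
```
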
 
diff --git a/torch/library.py b/torch/library.py
index d0a4cf2..bda34c2 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -1,28 +1,34 @@
 # mypy: allow-untyped-defs
-from ._ops import OpOverload
-from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence
-from typing_extensions import deprecated
-import traceback
-import torch
-import weakref
+import contextlib
 import functools
 import inspect
 import re
-import contextlib
 import sys
-from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef
+import traceback
+import weakref
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing_extensions import deprecated
+
+import torch
 import torch._library as _library
+from torch._library.custom_ops import (
+    _maybe_get_opdef,
+    custom_op,
+    CustomOpDef,
+    device_types_t,
+)
+from torch._ops import OpOverload
 
 
 __all__ = [
-    'Library',
-    'impl',
-    'define',
-    'fallthrough_kernel',
-    'impl_abstract',
-    'register_fake',
-    'get_ctx',
-    'custom_op',
+    "Library",
+    "impl",
+    "define",
+    "fallthrough_kernel",
+    "impl_abstract",
+    "register_fake",
+    "get_ctx",
+    "custom_op",
 ]
 
 # Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
@@ -33,7 +39,8 @@
 _defs: Set[str] = set()
 
 # prim is reserved by TorchScript interpreter
-_reserved_namespaces = ['prim']
+_reserved_namespaces = ["prim"]
+
 
 def fallthrough_kernel():
     """
@@ -41,6 +48,7 @@
     """
     raise NotImplementedError("fallthrough_kernel() should never be called.")
 
+
 class Library:
     """
     A class to create libraries that can be used to register new operators or
@@ -59,16 +67,22 @@
         kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
         dispatch_key: PyTorch dispatch key (default: "")
     """
+
     def __init__(self, ns, kind, dispatch_key=""):
-        if kind not in ('IMPL', 'DEF', 'FRAGMENT'):
+        if kind not in ("IMPL", "DEF", "FRAGMENT"):
             raise ValueError("Unsupported kind: ", kind)
 
-        if ns in _reserved_namespaces and (kind == "DEF" or kind == 'FRAGMENT'):
-            raise ValueError(ns, " is a reserved namespace. Please try creating a library with another name.")
+        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
+            raise ValueError(
+                ns,
+                " is a reserved namespace. Please try creating a library with another name.",
+            )
 
         frame = traceback.extract_stack(limit=3)[0]
         filename, lineno = frame.filename, frame.lineno
-        self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
+        self.m: Optional[Any] = torch._C._dispatch_library(
+            kind, ns, dispatch_key, filename, lineno
+        )
         self.ns = ns
         self._op_defs: Set[str] = set()
         self._op_impls: Set[str] = set()
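
The `Library` docstring above lists the three kinds. A small sketch of how they differ in practice, using a throwaway `mylib` namespace and placeholder kernels:

```python
from torch.library import Library

# "DEF" claims the namespace and can declare new operator schemas.
def_lib = Library("mylib", "DEF")
def_lib.define("my_add(Tensor x, Tensor y) -> Tensor")

# "FRAGMENT" can also define(), for adding operators to a namespace whose
# "DEF" library lives elsewhere (for example in C++).
frag_lib = Library("mylib", "FRAGMENT")
frag_lib.define("my_mul(Tensor x, Tensor y) -> Tensor")

# "IMPL" (the default) only registers kernels for existing operators.
impl_lib = Library("mylib", "IMPL")
impl_lib.impl("my_add", lambda x, y: x + y, "CompositeExplicitAutograd")
impl_lib.impl("my_mul", lambda x, y: x * y, "CompositeExplicitAutograd")
```
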
@@ -79,13 +93,21 @@
         # Python __del__ can lead to weird things (globals and locals may already
         # be gone when __del__ actually gets called!). finalizers help the
         # situation because it lets us capture references and keeps them alive
-        weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles)
+        weakref.finalize(
+            self,
+            _del_library,
+            _impls,
+            self._op_impls,
+            _defs,
+            self._op_defs,
+            self._registration_handles,
+        )
 
     def __repr__(self):
         return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
 
     def define(self, schema, alias_analysis="", *, tags=()):
-        r'''Defines a new operator and its semantics in the ns namespace.
+        r"""Defines a new operator and its semantics in the ns namespace.
 
         Args:
             schema: function schema to define a new operator.
@@ -102,7 +124,7 @@
         Example::
             >>> my_lib = Library("mylib", "DEF")
             >>> my_lib.define("sum(Tensor self) -> Tensor")
-        '''
+        """
         # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
         # AliasAnalysis type in C++
         if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
@@ -113,7 +135,9 @@
 
         name = schema.split("(")[0]
         packet_name = name.split(".")[0] if "." in name else name
-        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(getattr(torch.ops, self.ns), packet_name)
+        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
+            getattr(torch.ops, self.ns), packet_name
+        )
 
         result = self.m.define(schema, alias_analysis, tuple(tags))
         name = schema.split("(")[0]
@@ -131,7 +155,7 @@
         return result
 
     def _register_fake(self, op_name, fn, _stacklevel=1):
-        r'''Registers the fake impl for an operator defined in the library.'''
+        r"""Registers the fake impl for an operator defined in the library."""
         source = torch._library.utils.get_source(_stacklevel + 1)
         frame = sys._getframe(_stacklevel)
         caller_module = inspect.getmodule(frame)
@@ -141,7 +165,9 @@
 
         # TODO(rzou): We're gonna need to stage this change with torchvision,
         # since torchvision is github first.
-        if caller_module_name is not None and caller_module_name.startswith("torchvision."):
+        if caller_module_name is not None and caller_module_name.startswith(
+            "torchvision."
+        ):
             caller_module_name = None
 
         qualname = f"{self.ns}::{op_name}"
@@ -154,8 +180,8 @@
         handle = entry.abstract_impl.register(func_to_register, source)
         self._registration_handles.append(handle)
 
-    def _impl_with_aoti_compile(self, op_name, dispatch_key=''):
-        r'''Register the operator to use the AOTI-compiled implementation.
+    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
+        r"""Register the operator to use the AOTI-compiled implementation.
 
         Args:
             op_name: operator name (along with the overload) or OpOverload object.
@@ -165,8 +191,8 @@
         Example::
             >>> my_lib = Library("aten", "IMPL")
             >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
-        '''
-        if dispatch_key == '':
+        """
+        if dispatch_key == "":
             dispatch_key = self.dispatch_key
         assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
 
@@ -175,19 +201,24 @@
         elif isinstance(op_name, OpOverload):
             name = op_name._schema.name
             overload_name = op_name._schema.overload_name
-            if overload_name != '':
-                name = name + '.' + overload_name
+            if overload_name != "":
+                name = name + "." + overload_name
         else:
-            raise RuntimeError("_impl_with_aoti_compile should be passed either a name or an OpOverload object "
-                               "as the first argument")
+            raise RuntimeError(
+                "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
+                "as the first argument"
+            )
 
         key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
         if key in _impls:
             # TODO: in future, add more info about where the existing function is registered (this info is
             # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
-            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
-                               "'s behavior for {} dispatch key and {} namespace.".
-                               format(name.split("::")[-1], dispatch_key, self.ns))
+            raise RuntimeError(
+                "This is not allowed since there's already a kernel registered from python overriding {}"
+                "'s behavior for {} dispatch key and {} namespace.".format(
+                    name.split("::")[-1], dispatch_key, self.ns
+                )
+            )
 
         assert self.m is not None
         impl_fn: Callable = self.m.impl_with_aoti_compile
@@ -196,8 +227,8 @@
         _impls.add(key)
         self._op_impls.add(key)
 
-    def impl(self, op_name, fn, dispatch_key='', *, with_keyset=False):
-        r'''Registers the function implementation for an operator defined in the library.
+    def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
+        r"""Registers the function implementation for an operator defined in the library.
 
         Args:
             op_name: operator name (along with the overload) or OpOverload object.
@@ -211,10 +242,12 @@
             >>> def div_cpu(self, other):
             >>>     return self * (1 / other)
             >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
-        '''
+        """
         if not callable(fn):
-            raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
-        if dispatch_key == '':
+            raise TypeError(
+                f"Input function is required to be a callable but found type {type(fn)}"
+            )
+        if dispatch_key == "":
             dispatch_key = self.dispatch_key
 
         if isinstance(op_name, str):
@@ -222,37 +255,50 @@
         elif isinstance(op_name, OpOverload):
             name = op_name._schema.name
             overload_name = op_name._schema.overload_name
-            if overload_name != '':
-                name = name + '.' + overload_name
+            if overload_name != "":
+                name = name + "." + overload_name
         else:
-            raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")
+            raise RuntimeError(
+                "impl should be passed either a name or an OpOverload object as the first argument"
+            )
 
         key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
         if key in _impls:
             # TODO: in future, add more info about where the existing function is registered (this info is
             # today already returned by the C++ warning when impl is called but we error out before that)
-            raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
-                               "'s behavior for {} dispatch key and {} namespace.".
-                               format(name.split("::")[-1], dispatch_key, self.ns))
+            raise RuntimeError(
+                "This is not allowed since there's already a kernel registered from python overriding {}"
+                "'s behavior for {} dispatch key and {} namespace.".format(
+                    name.split("::")[-1], dispatch_key, self.ns
+                )
+            )
 
         if dispatch_key == "Meta":
             dispatcher_op_name = name
-            if '::' not in dispatcher_op_name:
-                dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'
+            if "::" not in dispatcher_op_name:
+                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"
 
             # Internally, we shouldn't be registering meta kernels for any operators that
             # have CompositeImplicitAutograd kernels.
             # Instead, we should be letting those decompositions run, and writing meta kernels
             # only for the base operators.
-            if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
+            if torch._C._dispatch_has_kernel_for_dispatch_key(
+                dispatcher_op_name, "CompositeImplicitAutograd"
+            ):
                 raise RuntimeError(
                     f"We should not register a meta kernel directly to the operator '{name}',"
                     " because it has a CompositeImplicitAutograd kernel in core."
                     " Instead we should let the operator decompose, and ensure that we have meta kernels"
-                    " for the base ops that it decomposes into.")
+                    " for the base ops that it decomposes into."
+                )
 
         assert self.m is not None
-        self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn, with_keyset)
+        self.m.impl(
+            name,
+            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
+            fn,
+            with_keyset,
+        )
 
         _impls.add(key)
         self._op_impls.add(key)
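
The comment in this hunk explains why `impl` rejects Meta kernels for operators that already have a CompositeImplicitAutograd kernel. A sketch of the intended pattern, registering CPU and Meta kernels for a hypothetical base operator:

```python
import torch
from torch.library import Library

lib = Library("mylib", "FRAGMENT")
lib.define("my_base_op(Tensor x) -> Tensor")

def my_base_op_cpu(x):
    return x.sin()

def my_base_op_meta(x):
    # Under the Meta key only metadata (shape, dtype, device) is computed.
    return torch.empty_like(x)

lib.impl("my_base_op", my_base_op_cpu, "CPU")
lib.impl("my_base_op", my_base_op_meta, "Meta")

# Composite operators (ones with a CompositeImplicitAutograd kernel) should be
# left to decompose into base ops instead; impl(..., "Meta") errors for them.
out = torch.ops.mylib.my_base_op(torch.randn(3, device="meta"))
print(out.shape, out.device)  # torch.Size([3]) meta
```
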
@@ -283,7 +329,9 @@
             delattr(namespace, name)
 
 
-def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
+def _del_library(
+    captured_impls, op_impls, captured_defs, op_defs, registration_handles
+):
     captured_impls -= op_impls
     captured_defs -= op_defs
     for handle in registration_handles:
@@ -357,7 +405,8 @@
     if not isinstance(qualname, str):
         raise ValueError(
             f"define(qualname, schema): expected qualname "
-            f"to be instance of str, got {type(qualname)}")
+            f"to be instance of str, got {type(qualname)}"
+        )
     namespace, name = torch._library.utils.parse_namespace(qualname)
     if lib is None:
         lib = Library(namespace, "FRAGMENT")
@@ -366,7 +415,8 @@
         raise ValueError(
             f"define(qualname, schema, ...): expected schema "
             f'to look like e.g. "(Tensor x) -> Tensor" but '
-            f'got "{schema}"')
+            f'got "{schema}"'
+        )
     lib.define(name + schema, alias_analysis="", tags=tags)
 
 
@@ -375,10 +425,12 @@
     """The old torch.library.define.
     We're keeping this around for BC reasons
     """
+
     def wrap(f):
         name = lib.define(schema, alias_analysis)
         lib.impl(name, f)
         return f
+
     return wrap
 
 
@@ -460,9 +512,11 @@
 @impl.register
 def _(lib: Library, name, dispatch_key=""):
     """Legacy torch.library.impl API. Kept around for BC"""
+
     def wrap(f):
         lib.impl(name, f, dispatch_key)
         return f
+
     return wrap
 
 
@@ -480,16 +534,19 @@
     return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)
 
 
-_op_identifier = Union[str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"]
+_op_identifier = Union[
+    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
+]
 
 
 def register_kernel(
-        op: _op_identifier,
-        device_types: device_types_t,
-        func: Optional[Callable] = None,
-        /,
-        *,
-        lib: Optional[Library] = None):
+    op: _op_identifier,
+    device_types: device_types_t,
+    func: Optional[Callable] = None,
+    /,
+    *,
+    lib: Optional[Library] = None,
+):
     """Register an implementation for a device type for this operator.
 
     Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
@@ -530,7 +587,9 @@
 
     """
 
-    if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
+    if not isinstance(
+        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+    ):
         raise ValueError("register_kernel(op): got unexpected type for op: {type(op)}")
     if isinstance(op, torch._ops.OpOverload):
         op = op._name
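
`register_kernel` above, combined with the module-level `define` reformatted earlier, is the functional counterpart to the `Library` methods. A minimal end-to-end sketch for a hypothetical `mylib::my_sin` operator:

```python
import torch
from torch.library import define, register_kernel

# Hypothetical operator; with no lib argument, define() creates a FRAGMENT
# library for the "mylib" namespace automatically.
define("mylib::my_sin", "(Tensor x) -> Tensor")

@register_kernel("mylib::my_sin", "cpu")
def _(x):
    return torch.sin(x)

x = torch.randn(4)
print(torch.allclose(torch.ops.mylib.my_sin(x), torch.sin(x)))  # True
```
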
@@ -544,12 +603,13 @@
 
 
 def register_fake(
-        op: _op_identifier,
-        func: Optional[Callable] = None,
-        /,
-        *,
-        lib: Optional[Library] = None,
-        _stacklevel: int = 1):
+    op: _op_identifier,
+    func: Optional[Callable] = None,
+    /,
+    *,
+    lib: Optional[Library] = None,
+    _stacklevel: int = 1,
+):
     r"""Register a FakeTensor implementation ("fake impl") for this operator.
 
     Also sometimes known as a "meta kernel", "abstract impl".
@@ -630,7 +690,9 @@
         >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
 
     """
-    if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
+    if not isinstance(
+        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+    ):
         raise ValueError("register_fake(op): got unexpected type for op: {type(op)}")
     if isinstance(op, torch._ops.OpOverload):
         op = op._name
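
`register_fake` supplies the shape/dtype rule used when the real kernel cannot run (FakeTensor, `torch.compile`). Continuing the hypothetical `mylib::my_sin` sketch; the `torch.compile` check at the end assumes a working Dynamo setup and uses the `eager` backend to keep the sketch light:

```python
import torch
from torch.library import register_fake

@register_fake("mylib::my_sin")
def _(x):
    # Runs under FakeTensor: produce an output with the right metadata only.
    return torch.empty_like(x)

def f(x):
    return torch.ops.mylib.my_sin(x) * 2

# Tracing goes through the fake impl rather than the real CPU kernel.
compiled = torch.compile(f, backend="eager", fullgraph=True)
print(compiled(torch.randn(4)).shape)  # torch.Size([4])
```
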
@@ -661,7 +723,14 @@
         return register(func)
 
 
-def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_context: Optional[Callable] = None, lib=None) -> None:
+def register_autograd(
+    op: _op_identifier,
+    backward: Callable,
+    /,
+    *,
+    setup_context: Optional[Callable] = None,
+    lib=None,
+) -> None:
     r"""Register a backward formula for this custom op.
 
     In order for an operator to work with autograd, you need to register
@@ -737,8 +806,12 @@
         >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
 
     """
-    if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
-        raise ValueError(f"register_autograd(op): got unexpected type for op: {type(op)}")
+    if not isinstance(
+        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+    ):
+        raise ValueError(
+            f"register_autograd(op): got unexpected type for op: {type(op)}"
+        )
     if isinstance(op, torch._ops.OpOverload):
         op = op._name
     opdef = _maybe_get_opdef(op)
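
`register_autograd` attaches a backward formula to a custom op through the `setup_context`/`backward` pair described in its docstring. Continuing the same hypothetical `mylib::my_sin` sketch:

```python
import torch
from torch.library import register_autograd

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * x.cos()  # d/dx sin(x) = cos(x)

register_autograd("mylib::my_sin", backward, setup_context=setup_context)

x = torch.randn(4, requires_grad=True)
torch.ops.mylib.my_sin(x).sum().backward()
print(torch.allclose(x.grad, x.cos()))  # True
```
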
@@ -760,7 +833,8 @@
         raise NotImplementedError(
             f"register_autograd with kwarg-only Tensor args. In the original "
             f"definition of the op, please make your tensors not kwarg-only. "
-            f"Got: {schema}")
+            f"Got: {schema}"
+        )
 
     info = _library.autograd.Info(backward, setup_context)
     autograd_kernel = _library.autograd.make_autograd_impl(op, info)
@@ -788,8 +862,8 @@
             return func(*args, **kwargs)
 
         maybe_pystub = torch._C._dispatch_pystub(
-            op._schema.name,
-            op._schema.overload_name)
+            op._schema.name, op._schema.overload_name
+        )
         if maybe_pystub is None:
             if torch._library.utils.requires_set_python_module():
                 namespace = op.namespace
@@ -800,7 +874,8 @@
                     f'companion C++ `m.set_python_module("{actual_module_name}")` '
                     f"call, but we could not find one. Please add that to "
                     f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
-                    f"operator was registered in ({cpp_filename})")
+                    f"operator was registered in ({cpp_filename})"
+                )
         else:
             pystub_module = maybe_pystub[0]
             if actual_module_name != pystub_module:
@@ -809,9 +884,11 @@
                     f"Operator '{qualname}' specified that its python fake impl "
                     f"is in the Python module '{pystub_module}' but it was actually found "
                     f"in '{actual_module_name}'. Please either move the fake impl "
-                    f"or correct the m.set_python_module call ({cpp_filename})")
+                    f"or correct the m.set_python_module call ({cpp_filename})"
+                )
         checked = True
         return func(*args, **kwargs)
+
     return inner
 
 
@@ -929,4 +1006,7 @@
 
     """
     import torch.testing._internal.optests as optests
-    return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
+
+    return optests.opcheck(
+        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+    )
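
`opcheck` above wraps the standard operator test utilities. A sketch that checks the hypothetical `mylib::my_sin` from the earlier sketches, assuming its CPU, fake, and autograd registrations are in place:

```python
import torch
from torch.library import opcheck

x = torch.randn(4, requires_grad=True)
# Runs the schema / autograd / fake-tensor test utilities against the op and
# returns a dict mapping each test name to "SUCCESS" (raise_exception defaults
# to True, so failures surface as exceptions).
print(opcheck(torch.ops.mylib.my_sin.default, (x,), {}))
```
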
diff --git a/torch/overrides.py b/torch/overrides.py
index 5095689..651912b 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -23,19 +23,26 @@
 import __future__  # noqa: F404
 
 import collections
+import contextlib
 import functools
 import types
 import warnings
-from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
 from functools import wraps
-import contextlib
+from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type
 
 import torch
 from torch._C import (
-    _has_torch_function, _has_torch_function_unary,
-    _has_torch_function_variadic, _add_docstr,
-    _push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
-    _is_torch_function_mode_enabled)
+    _add_docstr,
+    _get_function_stack_at,
+    _has_torch_function,
+    _has_torch_function_unary,
+    _has_torch_function_variadic,
+    _is_torch_function_mode_enabled,
+    _len_torch_function_stack,
+    _pop_torch_function_stack,
+    _push_on_torch_function_stack,
+)
+
 
 __all__ = [
     "get_ignored_functions",
@@ -52,7 +59,8 @@
 
 
 def _disable_user_warnings(
-        func: Callable, regex: str = '.*is deprecated, please use.*', module: str = 'torch') -> Callable:
+    func: Callable, regex: str = ".*is deprecated, please use.*", module: str = "torch"
+) -> Callable:
     """
     Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
     given ``regex`` pattern.
@@ -75,8 +83,11 @@
     @wraps(func)
     def wrapper(*args, **kwargs):
         with warnings.catch_warnings():
-            warnings.filterwarnings("ignore", category=UserWarning, message=regex, module=module)
+            warnings.filterwarnings(
+                "ignore", category=UserWarning, message=regex, module=module
+            )
             return func(*args, **kwargs)
+
     return wrapper
 
 
@@ -470,8 +481,9 @@
         torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
         torch.bernoulli: lambda input, generator=None, out=None: -1,
         torch.bilinear: lambda input1, input2, weight, bias: -1,
-        torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
-                                                 reduction='mean', pos_weight=None: -1),
+        torch.binary_cross_entropy_with_logits: (
+            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
+        ),
         torch.bincount: lambda input, weights=None, minlength=0: -1,
         torch.binomial: lambda count, prob, generator=None: -1,
         torch.bitwise_and: lambda input, other, out=None: -1,
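
The large dictionary edited in the hunks above and below is the table returned by `torch.overrides.get_testing_overrides()`: each overridable callable maps to a dummy lambda whose signature mirrors the real API, so tests can exercise `__torch_function__` dispatch without running kernels. A small sketch of how that table is consumed:

```python
import inspect

import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()

# Each entry is a lambda whose signature mirrors the torch API it stands for.
dummy = overrides[torch.bincount]
print(inspect.signature(dummy))  # (input, weights=None, minlength=0)

# The dummies never touch real kernels; they just return -1.
print(dummy(torch.arange(5)))  # -1
```
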
@@ -489,11 +501,11 @@
         torch.cat: lambda tensors, dim=0, out=None: -1,
         torch.concat: lambda tensors, dim=0, out=None: -1,  # alias for torch.cat
         torch.concatenate: lambda tensors, dim=0, out=None: -1,  # alias for torch.concatenate
-        torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
+        torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
         torch.ceil: lambda input, out=None: -1,
-        torch.celu: lambda input, alpha=1., inplace=False: -1,
+        torch.celu: lambda input, alpha=1.0, inplace=False: -1,
         torch.chain_matmul: lambda *matrices, out=None: -1,
-        torch.channel_shuffle: lambda input, groups : -1,
+        torch.channel_shuffle: lambda input, groups: -1,
         torch.cholesky: lambda input, upper=False, out=None: -1,
         torch.linalg.cholesky: lambda input, out=None: -1,
         torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
@@ -528,14 +540,15 @@
         torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
         torch.corrcoef: lambda input: -1,
         torch.cos: lambda input, out=None: -1,
-        torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
+        torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
         torch.cosh: lambda input, out=None: -1,
         torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
         torch.count_nonzero: lambda input: -1,
         torch.cross: lambda input, other, dim=None, out=None: -1,
         torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
-        torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
-                         zero_infinity=False: -1),
+        torch.ctc_loss: (
+            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
+        ),
         torch.cummax: lambda input, dim, out=None: -1,
         torch.cummin: lambda input, dim, out=None: -1,
         torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
@@ -570,10 +583,12 @@
         torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
         torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
         torch.einsum: lambda equation, *operands: -1,
-        torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
-                          sparse=False: -1),
-        torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
-                              mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
+        torch.embedding: (
+            lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
+        ),
+        torch.embedding_bag: (
+            lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1  # noqa: B950
+        ),
         torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
         torch.eq: lambda input, other, out=None: -1,
         torch.equal: lambda input, other: -1,
@@ -585,14 +600,15 @@
         torch.expm1: lambda input, out=None: -1,
         torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
         torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
-        torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
-                                                running_max, scale, zero_point, quant_min, quant_max, ch_axis,
-                                                per_row_fake_quant=False, symmetric_quant=False: -1),
+        torch.fused_moving_avg_obs_fake_quant: (
+            lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1  # noqa: B950
+        ),
         torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
         torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
-        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
-        torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
-                                                          weight_zero_point, bias: -1),
+        torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,  # noqa: B950
+        torch.fbgemm_linear_int8_weight_fp32_activation: (
+            lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
+        ),
         torch.fbgemm_linear_quantize_weight: lambda input: -1,
         torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
         torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
@@ -630,7 +646,7 @@
         torch.fmod: lambda input, other, out=None: -1,
         torch.frac: lambda input, out=None: -1,
         torch.frexp: lambda input, out=None: -1,
-        torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
+        torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,  # noqa: B950
         torch._functional_assert_async: lambda input, msg, dep_token: -1,
         torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
         torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
@@ -653,7 +669,7 @@
         torch.greater: lambda input, other, out=None: -1,
         torch.hardshrink: lambda input, lambd=0.5: -1,
         torch.heaviside: lambda input, values, out=None: -1,
-        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
+        torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
         torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
         torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
         torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
@@ -677,8 +693,9 @@
         torch.isreal: lambda tensor: -1,
         torch.isposinf: lambda input, out=None: -1,
         torch.isneginf: lambda input, out=None: -1,
-        torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
-                              cudnn_enabled: -1),
+        torch.instance_norm: (
+            lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
+        ),
         torch.int_repr: lambda input: -1,
         torch.inverse: lambda input, out=None: -1,
         torch.linalg.inv: lambda input, out=None: -1,
@@ -694,9 +711,10 @@
         torch.is_signed: lambda input: -1,
         torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
         torch.isnan: lambda input: -1,
-        torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
-                      normalized=False, onesided=None, length=None, return_complex=False: -1),
-        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
+        torch.istft: (
+            lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1  # noqa: B950
+        ),
+        torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
         torch.kron: lambda input, other: -1,
         torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
         torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
@@ -709,8 +727,7 @@
         torch.less_equal: lambda input, other, out=None: -1,
         torch.lerp: lambda input, end, weight, out=None: -1,
         torch.lgamma: lambda input, out=None: -1,
-        torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
-        tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
+        torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,  # noqa: B950
         torch.log: lambda input, out=None: -1,
         torch.log_softmax: lambda input, dim, dtype=None: -1,
         torch.log10: lambda input, out=None: -1,
@@ -732,7 +749,7 @@
         torch.less: lambda input, other, out=None: -1,
         torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
         torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
-        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,  # type: ignore[attr-defined]  # noqa: B950
+        torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,  # type: ignore[attr-defined]  # noqa: B950
         torch.masked_fill: lambda input, mask, value: -1,
         torch.masked_scatter: lambda input, mask, source: -1,
         torch.masked_select: lambda input, mask, out=None: -1,
@@ -754,8 +771,9 @@
         torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
         torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
         torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
-        torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                        return_indices=False, ceil_mode=False: -1),
+        torch.max_pool1d_with_indices: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+        ),
         torch.mean: lambda input, dim=None: -1,
         torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
         torch.median: lambda input, dim=None: -1,
@@ -764,17 +782,21 @@
         torch.min: lambda input, out=None: -1,
         torch.minimum: lambda input, other, out=None: -1,
         torch.fmin: lambda input, other, out=None: -1,
-        torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
-                                  exponential_average_factor, epsilon: -1),
-        torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
+        torch.miopen_batch_norm: (
+            lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
+        ),
+        torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,  # noqa: B950
         torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
         torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
-        torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
-                                             groups, benchmark, deterministic: -1),
-        torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
-                                             deterministic: -1),
-        torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
-                           dropout, train, bidirectional, batch_sizes, dropout_state: -1),
+        torch.miopen_convolution_transpose: (
+            lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
+        ),
+        torch.miopen_depthwise_convolution: (
+            lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
+        ),
+        torch.miopen_rnn: (
+            lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1  # noqa: B950
+        ),
         torch.mm: lambda input, mat2, out=None: -1,
         torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
         torch.movedim: lambda input, source, destination: -1,
@@ -793,7 +815,7 @@
         torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
         torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
         torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
-        torch.native_channel_shuffle: lambda input, groups : -1,
+        torch.native_channel_shuffle: lambda input, groups: -1,
         torch.ne: lambda input, other, out=None: -1,
         torch.not_equal: lambda input, other, out=None: -1,
         torch.neg: lambda input, out=None: -1,
@@ -809,62 +831,76 @@
         torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
         torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
         torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
-        torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
-                                         count_include_pad=True, divisor_override=None: -1),
-        torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
-                                         count_include_pad=True, divisor_override=None: -1),
-        torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
-                                         momentum=0.1, eps=1e-05: -1),
+        torch.nn.functional.avg_pool2d: (
+            lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.avg_pool3d: (
+            lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.batch_norm: (
+            lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
+        ),
         torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
-        torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
-                                                   reduction="mean": -1),
-        torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
-                                                               reduce=None, reduction="mean", pos_weight=None: -1),
+        torch.nn.functional.binary_cross_entropy: (
+            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.binary_cross_entropy_with_logits: (
+            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
+        ),
         torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
-        torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
-                                                    reduce=None, reduction='mean': -1),
-        torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
-                                            reduce=None, reduction="mean", label_smoothing=0.0: -1),
-        torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
-                                       reduction='mean', zero_infinity=False: -1),
+        torch.nn.functional.cosine_embedding_loss: (
+            lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.cross_entropy: (
+            lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1  # noqa: B950
+        ),
+        torch.nn.functional.ctc_loss: (
+            lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
+        ),
         torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
         torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
-        torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
-                                        scale_grad_by_freq=False, sparse=False: -1),
-        torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
-                                            scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
-                                            include_last_offset=False, padding_idx=None: -1),
+        torch.nn.functional.embedding: (
+            lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1  # noqa: B950
+        ),
+        torch.nn.functional.embedding_bag: (
+            lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1  # noqa: B950
+        ),
         torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
         torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
-        torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
-                                                    return_indices=False, _random_samples=None: -1),
+        torch.nn.functional.fractional_max_pool2d: (
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
         torch.nn.functional.fractional_max_pool2d_with_indices: (
-            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
-            _random_samples=None: -1),
-        torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
-                                                    return_indices=False, _random_samples=None: -1),
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.fractional_max_pool3d: (
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
         torch.nn.functional.fractional_max_pool3d_with_indices: (
-            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
-            _random_samples=None: -1),
-        torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
-        torch.nn.functional.gelu: lambda input, approximate='none': -1,
+            lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1  # noqa: B950
+        ),
+        torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
+        torch.nn.functional.gelu: lambda input, approximate="none": -1,
         torch.nn.functional.glu: lambda input, dim=-1: -1,
-        torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
+        torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1,  # noqa: B950
         torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
         torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
         torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
-        torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
-        torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
-                                                   reduction='mean': -1),
-        torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
-                                            use_input_stats=True, momentum=0.1, eps=1e-05: -1),
-        torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
-                                          recompute_scale_factor=None, antialias=False: -1),
-        torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
-        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
+        torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
+        torch.nn.functional.hinge_embedding_loss: (
+            lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.instance_norm: (
+            lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1  # noqa: B950
+        ),
+        torch.nn.functional.interpolate: (
+            lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1  # noqa: B950
+        ),
+        torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,  # noqa: B950
+        torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
         torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
         torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
         torch.nn.functional.linear: lambda input, weight, bias=None: -1,
@@ -874,55 +910,65 @@
         torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
         torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
         torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
-        torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
-                                                  reduce=None, reduction='mean': -1),
-        torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                         ceil_mode=False, return_indices=False: -1),
-        torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                                      return_indices=False, ceil_mode=False: -1),
-        torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                         ceil_mode=False, return_indices=False: -1),
-        torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                                      return_indices=False, ceil_mode=False: -1),
-        torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                         return_indices=False, ceil_mode=False: -1),
-        torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
-                                                      return_indices=False, ceil_mode=False: -1),
-        torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
-        torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
-        torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
-        torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
+        torch.nn.functional.margin_ranking_loss: (
+            lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.max_pool1d: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
+        ),
+        torch.nn.functional.max_pool1d_with_indices: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+        ),
+        torch.nn.functional.max_pool2d: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
+        ),
+        torch.nn.functional.max_pool2d_with_indices: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+        ),
+        torch.nn.functional.max_pool3d: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+        ),
+        torch.nn.functional.max_pool3d_with_indices: (
+            lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+        ),
+        torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
+        torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
+        torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,  # noqa: B950
+        torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
         torch.nn.functional.multi_head_attention_forward: (
-            lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
-            add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
-            need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
-            v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
-        torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
-                                                reduce=None, reduction='mean': -1),
-        torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
-                                                     reduction='mean': -1),
-        torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
-                                                          reduce=None, reduction='mean': -1),
-        torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
-                                       reduce=None, reduction='mean': -1),
+            lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1  # noqa: B950
+        ),
+        torch.nn.functional.multi_margin_loss: (
+            lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.multilabel_margin_loss: (
+            lambda input, target, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.multilabel_soft_margin_loss: (
+            lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
+        ),
+        torch.nn.functional.nll_loss: (
+            lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
+        ),
         torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
         torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
-        torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
+        torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
         torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
-        torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
-                                               eps=1e-08, reduce=None, reduction='mean': -1),
+        torch.nn.functional.poisson_nll_loss: (
+            lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1  # noqa: B950
+        ),
         torch.nn.functional.prelu: lambda input, weight: -1,
         torch.nn.functional.relu: lambda input, inplace=False: -1,
         torch.nn.functional.relu6: lambda input, inplace=False: -1,
         torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
-        torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
+        torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,  # noqa: B950
         torch.nn.functional.selu: lambda input, inplace=False: -1,
         torch.nn.functional.silu: lambda input, inplace=False: -1,
         torch.nn.functional.mish: lambda input, inplace=False: -1,
         torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
-        torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
-        torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
-        torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
+        torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1,  # noqa: B950
+        torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
+        torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,  # noqa: B950
         torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
         torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
         torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
@@ -930,25 +976,29 @@
         torch.nn.functional.softsign: lambda input: -1,
         torch.nn.functional.tanhshrink: lambda input: -1,
         torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
-        torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
-                                                  swap=False, size_average=None, reduce=None, reduction='mean': -1),
-        torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
-                                                                distance_function=None, margin=1.0,
-                                                                swap=False, reduction='mean': -1),
+        torch.nn.functional.triplet_margin_loss: (
+            lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1  # noqa: B950
+        ),
+        torch.nn.functional.triplet_margin_with_distance_loss: (
+            lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
+        ),
         torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
-        torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
-        torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
+        torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
+        torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
         torch.nn.init.constant_: lambda tensor, val: -1,
-        torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
+        torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1,  # noqa: B950
         torch.nonzero: lambda input, as_tuple=False: -1,
         torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
         torch.argwhere: lambda input: -1,
-        torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
+        torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
         torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
         torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
-        torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
+        torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
+            -2,
+            -1,
+        ), keepdim=False, out=None, dtype=None: -1,
         torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
-        torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
+        torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
         torch.numel: lambda input: -1,
         torch.orgqr: lambda input, tau: -1,
         torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
@@ -975,28 +1025,43 @@
         torch.q_scale: lambda input: -1,
         torch.q_zero_point: lambda input: -1,
         torch.qr: lambda input, some=True, out=None: -1,
-        torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
-        torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
-        torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
+        torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
+        torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
+        torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
         torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
         torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
         torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
         torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
-        torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
-                                   col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
-
-        torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
-                                    col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
-        torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
-                                     dilation=(1,), ceil_mode=False: -1),
-        torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
-                                     dilation=(1, 1), ceil_mode=False: -1),
-        torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
-                                     dilation=(1, 1, 1), ceil_mode=False: -1),
-        torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
-                                        col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
-        torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
-                                        col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
+        torch.quantized_gru_cell: (
+            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
+        ),
+        torch.quantized_lstm_cell: (
+            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
+        ),
+        torch.quantized_max_pool1d: (
+            lambda input, kernel_size, stride=tuple(), padding=(0,), dilation=(
+                1,
+            ), ceil_mode=False: -1
+        ),
+        torch.quantized_max_pool2d: (
+            lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(
+                1,
+                1,
+            ), ceil_mode=False: -1
+        ),
+        torch.quantized_max_pool3d: (
+            lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0), dilation=(
+                1,
+                1,
+                1,
+            ), ceil_mode=False: -1
+        ),
+        torch.quantized_rnn_relu_cell: (
+            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
+        ),
+        torch.quantized_rnn_tanh_cell: (
+            lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1  # noqa: B950
+        ),
         torch.rad2deg: lambda input, out=None: -1,
         torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
         torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
@@ -1014,16 +1079,16 @@
         torch.repeat_interleave: lambda input, dim=None: -1,
         torch.reshape: lambda input, shape: -1,
         torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
-        torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
+        torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,  # noqa: B950
         torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
-        torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
+        torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,  # noqa: B950
         torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
         torch.roll: lambda input, shifts, dims=None: -1,
         torch.rot90: lambda input, k=1, dims=(0, 1): -1,
         torch.round: lambda input, out=None: -1,
         torch.row_stack: lambda tensors, out=None: -1,  # alias for torch.vstack
         torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
-        torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
+        torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
         torch.rsqrt: lambda input, out=None: -1,
         torch.rsub: lambda input, other, alpha=1: -1,
         torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
@@ -1031,7 +1096,7 @@
         torch.scatter_add: lambda input, dim, index, src: -1,
         torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
         torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
-        torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
+        torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,  # noqa: B950
         torch.select: lambda input, dim, index: -1,
         torch.select_scatter: lambda input, src, dim, index: -1,
         torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
@@ -1061,8 +1126,9 @@
         torch.stack: lambda tensors, dim=0, out=None: -1,
         torch.std: lambda input, dim=None: -1,
         torch.std_mean: lambda input, dim=None: -1,
-        torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
-                     pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
+        torch.stft: (
+            lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1  # noqa: B950
+        ),
         torch.sub: lambda input, other, out=None: -1,
         torch.subtract: lambda input, other, out=None: -1,
         torch.sum: lambda input, dim=None: -1,
@@ -1164,9 +1230,9 @@
         torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
         torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
         torch.tril: lambda input, diagonal=0, out=None: -1,
-        torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
-
-                                    size_average=None, reduce=None, reduction='mean': -1),
+        torch.triplet_margin_loss: (
+            lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1  # noqa: B950
+        ),
         torch.triu: lambda input, diagonal=0, out=None: -1,
         torch.true_divide: lambda input, other: -1,
         torch.trunc: lambda input, out=None: -1,
@@ -1436,10 +1502,16 @@
         torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
     }
 
-    privateuse1_backend_name = torch.utils.backend_registration._privateuse1_backend_name
+    privateuse1_backend_name = (
+        torch.utils.backend_registration._privateuse1_backend_name
+    )
     if hasattr(Tensor, privateuse1_backend_name):
-        ret[getattr(Tensor, privateuse1_backend_name)] = lambda self, device=None, non_blocking=False, **kwargs: -1
-        ret[getattr(Tensor, f'is_{privateuse1_backend_name}').__get__] = lambda self: -1  # noqa: B009
+        ret[
+            getattr(Tensor, privateuse1_backend_name)
+        ] = lambda self, device=None, non_blocking=False, **kwargs: -1
+        ret[
+            getattr(Tensor, f"is_{privateuse1_backend_name}").__get__
+        ] = lambda self: -1  # noqa: B009
 
     ret2 = {}
     ignored = get_ignored_functions()
@@ -1457,12 +1529,10 @@
         if k.__name__.startswith("bitwise_"):
             # bitwise_<op> have dunder methods of the form __<op>__
             # And so on.
-            subname = k.__name__[len("bitwise_"):]
-            names.extend([
-                "__" + subname + "__",
-                "__i" + subname + "__",
-                "__r" + subname + "__"
-            ])
+            subname = k.__name__[len("bitwise_") :]
+            names.extend(
+                ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
+            )
 
         for name in names:
             func = getattr(Tensor, name, None)
@@ -1472,6 +1542,7 @@
     ret.update(ret2)
     return ret
 
+
 def wrap_torch_function(dispatcher: Callable):
     """Wraps a given function with ``__torch_function__`` -related functionality.
 
@@ -1495,6 +1566,7 @@
     >>> def func(a): # This will make func dispatchable by __torch_function__
     ...     return a + 0
     """
+
     def inner(func):
         @functools.wraps(func)
         def wrapped(*args, **kwargs):
@@ -1508,7 +1580,10 @@
 
     return inner
 
-def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None) -> List[Any]:
+
+def _get_overloaded_args(
+    relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None
+) -> List[Any]:
     """Returns a list of arguments on which to call __torch_function__.
 
     Checks arguments in relevant_args for __torch_function__ implementations,
@@ -1559,8 +1634,11 @@
         #
         # NB: Important to exclude _disabled_torch_function_impl, otherwise
         # https://github.com/pytorch/pytorch/issues/64687
-        if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
-                arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
+        if (
+            arg_type not in overloaded_types
+            and hasattr(arg_type, "__torch_function__")
+            and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
+        ):
             # Create lists explicitly for the first type (usually the only one
             # done) to avoid setting up the iterator for overloaded_args.
             if overloaded_types:
@@ -1581,7 +1659,8 @@
 
 
 def handle_torch_function(
-        public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
+    public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs
+) -> Any:
     """Implement a function with checks for ``__torch_function__`` overrides.
 
     See torch::autograd::handle_torch_function for the equivalent of this
@@ -1636,11 +1715,16 @@
         # This call needs to become a classmethod call in the future.
         # See https://github.com/pytorch/pytorch/issues/63767
         torch_func_method = overloaded_arg.__torch_function__
-        if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
-                torch_func_method is not torch._C._disabled_torch_function_impl:
-            warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
-                          "will be an error in future, please define it as a classmethod.",
-                          DeprecationWarning)
+        if (
+            hasattr(torch_func_method, "__self__")
+            and torch_func_method.__self__ is overloaded_arg
+            and torch_func_method is not torch._C._disabled_torch_function_impl
+        ):
+            warnings.warn(
+                "Defining your `__torch_function__` as a plain method is deprecated and "
+                "will be an error in future, please define it as a classmethod.",
+                DeprecationWarning,
+            )
 
         # Use `public_api` instead of `implementation` so __torch_function__
         # implementations can do equality/identity comparisons.
@@ -1649,15 +1733,16 @@
         if result is not NotImplemented:
             return result
 
-    func_name = f'{public_api.__module__}.{public_api.__name__}'
+    func_name = f"{public_api.__module__}.{public_api.__name__}"
     msg = (
         f"no implementation found for '{func_name}' on types that implement "
-        f'__torch_function__: {[type(arg) for arg in overloaded_args]}'
+        f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
     )
     if _is_torch_function_mode_enabled():
         msg += f" nor in mode {_get_current_function_mode()}"
     raise TypeError(msg)
 
+
 has_torch_function = _add_docstr(
     _has_torch_function,
     r"""Check for __torch_function__ implementations in the elements of an iterable
@@ -1678,7 +1763,7 @@
     ________
     torch.is_tensor_like
         Checks if something is a Tensor-like, including an exact ``Tensor``.
-    """
+    """,
 )
 
 has_torch_function_unary = _add_docstr(
@@ -1689,7 +1774,7 @@
     call:
       `has_torch_function_unary(t)`
     which skips unnecessary packing and unpacking work.
-    """
+    """,
 )
 
 has_torch_function_variadic = _add_docstr(
@@ -1703,11 +1788,14 @@
     call:
       `has_torch_function_variadic(a, b)`
     which skips unnecessary packing and unpacking work.
-    """
+    """,
 )
 
+
 @functools.lru_cache(None)
-def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
+def _get_overridable_functions() -> (
+    Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
+):
     overridable_funcs = collections.defaultdict(list)
     index = {}
     tested_namespaces = [
@@ -1725,21 +1813,21 @@
             ignore = False
             # ignore private functions or functions that are deleted in torch.__init__
             if namespace is not torch.Tensor:
-                if func_name.startswith('__'):
+                if func_name.startswith("__"):
                     continue
-                elif func_name.startswith('_'):
+                elif func_name.startswith("_"):
                     ignore = True
-                elif func_name.endswith('_'):
+                elif func_name.endswith("_"):
                     ignore = True
                 elif not func_name[0].islower():
                     ignore = True
-                elif func_name == 'unique_dim':
+                elif func_name == "unique_dim":
                     continue
             else:
                 func = getattr(namespace, func_name)
                 if getattr(object, func_name, None) == func:
                     continue
-                if func_name == '__weakref__':
+                if func_name == "__weakref__":
                     continue
             func = getattr(namespace, func_name)
             if namespace is torch.Tensor and getattr(object, func_name, None) == func:
@@ -1757,9 +1845,13 @@
                 if ignore:
                     continue
                 if func.__get__ in get_ignored_functions():
-                    msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
-                           "but still has an explicit override")
-                    assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
+                    msg = (
+                        "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
+                        "but still has an explicit override"
+                    )
+                    assert func.__get__ not in get_testing_overrides(), msg.format(
+                        namespace, func.__name__
+                    )
                     continue
                 else:
                     overridable_funcs[func].append(func.__get__)
@@ -1775,13 +1867,18 @@
 
             # cannot be overriden by __torch_function__
             if func in get_ignored_functions():
-                msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
-                       "but still has an explicit override")
-                assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
+                msg = (
+                    "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
+                    "but still has an explicit override"
+                )
+                assert func not in get_testing_overrides(), msg.format(
+                    namespace, func.__name__
+                )
                 continue
             overridable_funcs[namespace].append(func)
     return overridable_funcs, index
 
+
 @_disable_user_warnings
 def get_overridable_functions() -> Dict[Any, List[Callable]]:
     """List functions that are overridable via __torch_function__
@@ -1794,6 +1891,7 @@
     """
     return _get_overridable_functions()[0]
 
+
 @_disable_user_warnings
 def resolve_name(f):
     """Get a human readable string name for a function passed to
@@ -1814,13 +1912,15 @@
         return str(f)
     return _get_overridable_functions()[1].get(f)
 
+
 @functools.lru_cache(None)
 def _get_tensor_methods() -> Set[Callable]:
-    """ Returns a set of the overridable methods on ``torch.Tensor`` """
+    """Returns a set of the overridable methods on ``torch.Tensor``"""
     overridable_funcs = get_overridable_functions()
     methods = set(overridable_funcs[torch.Tensor])
     return methods
 
+
 @_disable_user_warnings
 def is_tensor_method_or_property(func: Callable) -> bool:
     """
@@ -1846,6 +1946,7 @@
     """
     return func in _get_tensor_methods() or func.__name__ == "__get__"
 
+
 def is_tensor_like(inp):
     """
     Returns ``True`` if the passed-in input is a Tensor-like.
@@ -1882,6 +1983,7 @@
     """
     return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
 
+
 class TorchFunctionMode:
     """
     A ``TorchFunctionMode`` allows you to override the meaning of all
@@ -1912,6 +2014,7 @@
     ``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
     API self-referential (beware of infinite loops, in this case!)
     """
+
     inner: "TorchFunctionMode"
 
     # Force metaclass to generate constructor at the base of the hierarchy
@@ -1930,7 +2033,9 @@
 
     @classmethod
     def push(cls, *args, **kwargs):
-        warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
+        warnings.warn(
+            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
+        )
         instance = cls(*args, **kwargs)
         return instance
 
@@ -1944,6 +2049,7 @@
     stack_len = _len_torch_function_stack()
     return [_get_function_stack_at(i) for i in range(stack_len)]
 
+
 def _push_mode(mode):
     _push_on_torch_function_stack(mode)
 
@@ -1961,6 +2067,7 @@
     finally:
         _push_mode(old)
 
+
 class BaseTorchFunctionMode(TorchFunctionMode):
     def __torch_function__(self, func, types, args=(), kwargs=None):
         if kwargs is None:
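
For context, a minimal sketch of the classmethod-style ``__torch_function__`` protocol that the dispatch and deprecation-warning code reformatted above serves; the class below is hypothetical and only illustrates the calling convention, it is not part of the patch:

import torch

class LoggingTensorLike:
    # Hypothetical tensor-like wrapper used only to illustrate dispatch.
    def __init__(self, data):
        self.data = torch.as_tensor(data)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap wrapper arguments, report the intercepted call, then defer to the real op.
        unwrapped = [a.data if isinstance(a, LoggingTensorLike) else a for a in args]
        print("intercepted", torch.overrides.resolve_name(func))
        return func(*unwrapped, **kwargs)

torch.add(LoggingTensorLike([1.0, 2.0]), 3.0)  # prints: intercepted torch.add
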
diff --git a/torch/quasirandom.py b/torch/quasirandom.py
index a121801..509be4f 100644
--- a/torch/quasirandom.py
+++ b/torch/quasirandom.py
@@ -1,7 +1,8 @@
 # mypy: allow-untyped-defs
-import torch
 from typing import Optional
 
+import torch
+
 
 class SobolEngine:
     r"""
@@ -48,8 +49,10 @@
 
     def __init__(self, dimension, scramble=False, seed=None):
         if dimension > self.MAXDIM or dimension < 1:
-            raise ValueError("Supported range of dimensionality "
-                             f"for SobolEngine is [1, {self.MAXDIM}]")
+            raise ValueError(
+                "Supported range of dimensionality "
+                f"for SobolEngine is [1, {self.MAXDIM}]"
+            )
 
         self.seed = seed
         self.scramble = scramble
@@ -57,7 +60,9 @@
 
         cpu = torch.device("cpu")
 
-        self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
+        self.sobolstate = torch.zeros(
+            dimension, self.MAXBIT, device=cpu, dtype=torch.long
+        )
         torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
 
         if not self.scramble:
@@ -66,11 +71,15 @@
             self._scramble()
 
         self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
-        self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1)
+        self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
         self.num_generated = 0
 
-    def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
-             dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    def draw(
+        self,
+        n: int = 1,
+        out: Optional[torch.Tensor] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.Tensor:
         r"""
         Function to draw a sequence of :attr:`n` points from a Sobol sequence.
         Note that the samples are dependent on the previous samples. The size
@@ -92,12 +101,22 @@
                 result = self._first_point.to(dtype)
             else:
                 result, self.quasi = torch._sobol_engine_draw(
-                    self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
+                    self.quasi,
+                    n - 1,
+                    self.sobolstate,
+                    self.dimension,
+                    self.num_generated,
+                    dtype=dtype,
                 )
                 result = torch.cat((self._first_point.to(dtype), result), dim=-2)
         else:
             result, self.quasi = torch._sobol_engine_draw(
-                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
+                self.quasi,
+                n,
+                self.sobolstate,
+                self.dimension,
+                self.num_generated - 1,
+                dtype=dtype,
             )
 
         self.num_generated += n
@@ -108,8 +127,12 @@
 
         return result
 
-    def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
-                   dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    def draw_base2(
+        self,
+        m: int,
+        out: Optional[torch.Tensor] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.Tensor:
         r"""
         Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
         Note that the samples are dependent on the previous samples. The size
@@ -122,15 +145,16 @@
                                                     returned tensor.
                                                     Default: ``None``
         """
-        n = 2 ** m
+        n = 2**m
         total_n = self.num_generated + n
         if not (total_n & (total_n - 1) == 0):
-            raise ValueError("The balance properties of Sobol' points require "
-                             f"n to be a power of 2. {self.num_generated} points have been "
-                             f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
-                             "If you still want to do this, please use "
-                             "'SobolEngine.draw()' instead."
-                             )
+            raise ValueError(
+                "The balance properties of Sobol' points require "
+                f"n to be a power of 2. {self.num_generated} points have been "
+                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
+                "If you still want to do this, please use "
+                "'SobolEngine.draw()' instead."
+            )
         return self.draw(n=n, out=out, dtype=dtype)
 
     def reset(self):
@@ -151,9 +175,13 @@
             n (Int): The number of steps to fast-forward by.
         """
         if self.num_generated == 0:
-            torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
+            torch._sobol_engine_ff_(
+                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
+            )
         else:
-            torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
+            torch._sobol_engine_ff_(
+                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
+            )
         self.num_generated += n
         return self
 
@@ -166,8 +194,12 @@
         cpu = torch.device("cpu")
 
         # Generate shift vector
-        shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
-        self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
+        shift_ints = torch.randint(
+            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
+        )
+        self.shift = torch.mv(
+            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
+        )
 
         # Generate lower triangular matrices (stacked across dimensions)
         ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
@@ -176,9 +208,9 @@
         torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
 
     def __repr__(self):
-        fmt_string = [f'dimension={self.dimension}']
+        fmt_string = [f"dimension={self.dimension}"]
         if self.scramble:
-            fmt_string += ['scramble=True']
+            fmt_string += ["scramble=True"]
         if self.seed is not None:
-            fmt_string += [f'seed={self.seed}']
-        return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
+            fmt_string += [f"seed={self.seed}"]
+        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
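
The reformatted ``draw`` and ``draw_base2`` docstrings above describe the power-of-two balance constraint; a small illustrative usage sketch follows, with dimension and counts chosen arbitrarily:

import torch

engine = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=0)
first = engine.draw_base2(3)   # 2**3 = 8 points, shape (8, 3)
more = engine.draw_base2(3)    # another 8 points; the running total (16) stays a power of two
# engine.draw_base2(1) would now raise ValueError, since 16 + 2**1 = 18 is not a power of two
engine.reset()
sample = engine.draw(5)        # plain draw() has no power-of-two restriction
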
diff --git a/torch/random.py b/torch/random.py
index 0916fe1..7833311 100644
--- a/torch/random.py
+++ b/torch/random.py
@@ -1,10 +1,10 @@
 # mypy: allow-untyped-defs
 import contextlib
-from typing import Generator
 import warnings
+from typing import Generator
 
-from torch._C import default_generator
 import torch
+from torch._C import default_generator
 
 
 def set_rng_state(new_state: torch.Tensor) -> None:
@@ -46,10 +46,12 @@
         torch.cuda.manual_seed_all(seed)
 
     import torch.mps
+
     if not torch.mps._is_in_bad_fork():
         torch.mps.manual_seed(seed)
 
     import torch.xpu
+
     if not torch.xpu._is_in_bad_fork():
         torch.xpu.manual_seed_all(seed)
 
@@ -69,10 +71,12 @@
         torch.cuda.manual_seed_all(seed)
 
     import torch.mps
+
     if not torch.mps._is_in_bad_fork():
         torch.mps.manual_seed(seed)
 
     import torch.xpu
+
     if not torch.xpu._is_in_bad_fork():
         torch.xpu.manual_seed_all(seed)
 
@@ -95,7 +99,9 @@
         custom_device_mod = getattr(torch, custom_backend_name)
         _bad_fork_name = "_is_in_bad_fork"
         _seed_all_name = "manual_seed_all"
-        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
+        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
+            custom_device_mod, _seed_all_name
+        ):
             if not getattr(custom_device_mod, _bad_fork_name)():
                 getattr(custom_device_mod, _seed_all_name)(seed)
         else:
@@ -117,7 +123,13 @@
 
 
 @contextlib.contextmanager
-def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
+def fork_rng(
+    devices=None,
+    enabled=True,
+    _caller="fork_rng",
+    _devices_kw="devices",
+    device_type="cuda",
+) -> Generator:
     """
     Forks the RNG, so that when you return, the RNG is reset
     to the state that it was previously in.
@@ -138,8 +150,10 @@
     device_type = torch.device(device_type).type
     device_mod = getattr(torch, device_type, None)
     if device_mod is None:
-        raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
-                           "a module by `torch._register_device_module`.")
+        raise RuntimeError(
+            f"torch has no module of `{device_type}`, you should register "
+            + "a module by `torch._register_device_module`."
+        )
     global _fork_rng_warned_already
 
     # Internal arguments:
@@ -153,17 +167,19 @@
     if devices is None:
         num_devices = device_mod.device_count()
         if num_devices > 1 and not _fork_rng_warned_already:
-            message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
-                       f"you have used {_caller} without explicitly specifying which devices are being used. "
-                       f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
-                       f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
-                       f" making use of a few {device_type.upper()} devices, set the environment variable "
-                       f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
-                       "with the set of devices you are actually using. For example, if you are using CPU only, "
-                       "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
-                       f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0].  To initialize all devices "
-                       f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
-                       f"`range(torch.{device_type}.device_count())`.")
+            message = (
+                f"{device_type.upper()} reports that you have {num_devices} available devices, and "
+                f"you have used {_caller} without explicitly specifying which devices are being used. "
+                f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
+                f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
+                f" making use of a few {device_type.upper()} devices, set the environment variable "
+                f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
+                "with the set of devices you are actually using. For example, if you are using CPU only, "
+                "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
+                f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0].  To initialize all devices "
+                f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
+                f"`range(torch.{device_type}.device_count())`."
+            )
             warnings.warn(message)
             _fork_rng_warned_already = True
         devices = list(range(num_devices))
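
A brief illustrative use of the reformatted ``fork_rng`` context manager; passing ``devices=[]`` restricts the fork to the CPU generator, so the sketch runs even without accelerator devices:

import torch

torch.manual_seed(0)
with torch.random.fork_rng(devices=[]):
    torch.manual_seed(123)
    inside = torch.rand(2)     # drawn from the temporarily re-seeded generator
outside = torch.rand(2)        # CPU RNG state is restored on exit, unaffected by the block
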
diff --git a/torch/return_types.py b/torch/return_types.py
index a74150a..d456742 100644
--- a/torch/return_types.py
+++ b/torch/return_types.py
@@ -1,8 +1,9 @@
-import torch
 import inspect
 
+import torch
 from torch.utils._pytree import register_pytree_node, SequenceKey
 
+
 __all__ = ["pytree_register_structseq", "all_return_types"]
 
 all_return_types = []
@@ -10,6 +11,7 @@
 # error: Module has no attribute "_return_types"
 return_types = torch._C._return_types  # type: ignore[attr-defined]
 
+
 def pytree_register_structseq(cls):
     def structseq_flatten(structseq):
         return list(structseq), None
@@ -28,14 +30,15 @@
         flatten_with_keys_fn=structseq_flatten_with_keys,
     )
 
+
 for name in dir(return_types):
-    if name.startswith('__'):
+    if name.startswith("__"):
         continue
 
     _attr = getattr(return_types, name)
     globals()[name] = _attr
 
-    if not name.startswith('_'):
+    if not name.startswith("_"):
         __all__.append(name)
         all_return_types.append(_attr)
 
diff --git a/torch/serialization.py b/torch/serialization.py
index 311aac2..738af26 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -1,70 +1,87 @@
 # mypy: allow-untyped-defs
+import copyreg
 import difflib
 import functools
-import os
 import io
+import os
+import pickle
 import shutil
 import struct
 import sys
-import torch
 import tarfile
 import tempfile
 import warnings
 from contextlib import closing, contextmanager
 from enum import Enum
-from ._utils import _import_dotted_name
-from torch._sources import get_source_lines_and_file
-from torch.types import Storage
-from torch.storage import _get_dtype_from_pickle_storage_type
-from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
+from typing import (
+    Any,
+    BinaryIO,
+    Callable,
+    cast,
+    Dict,
+    IO,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 from typing_extensions import TypeAlias, TypeGuard  # Python 3.10+
-import copyreg
-import pickle
+
+import torch
 import torch._weights_only_unpickler as _weights_only_unpickler
+from torch._sources import get_source_lines_and_file
+from torch._utils import _import_dotted_name
+from torch.storage import _get_dtype_from_pickle_storage_type
+from torch.types import Storage
+
+
+__all__ = [
+    "SourceChangeWarning",
+    "mkdtemp",
+    "register_package",
+    "check_module_version_greater_or_equal",
+    "validate_cuda_device",
+    "validate_hpu_device",
+    "location_tag",
+    "default_restore_location",
+    "normalize_storage_type",
+    "storage_to_tensor_type",
+    "save",
+    "load",
+    "StorageType",
+    "LoadEndianness",
+    "get_default_load_endianness",
+    "set_default_load_endianness",
+    "clear_safe_globals",
+    "get_safe_globals",
+    "add_safe_globals",
+]
+
 
 DEFAULT_PROTOCOL = 2
 
-LONG_SIZE = struct.Struct('=l').size
-INT_SIZE = struct.Struct('=i').size
-SHORT_SIZE = struct.Struct('=h').size
+LONG_SIZE = struct.Struct("=l").size
+INT_SIZE = struct.Struct("=i").size
+SHORT_SIZE = struct.Struct("=h").size
 
-MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
+MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
 PROTOCOL_VERSION = 1001
-STORAGE_KEY_SEPARATOR = ','
+STORAGE_KEY_SEPARATOR = ","
 
 FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
-MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
+MAP_LOCATION: TypeAlias = Optional[
+    Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]
+]
 STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
 
 IS_WINDOWS = sys.platform == "win32"
 
 if not IS_WINDOWS:
-    from mmap import MAP_SHARED, MAP_PRIVATE
+    from mmap import MAP_PRIVATE, MAP_SHARED
 else:
     MAP_SHARED, MAP_PRIVATE = None, None  # type: ignore[assignment]
 
-__all__ = [
-    'SourceChangeWarning',
-    'mkdtemp',
-    'register_package',
-    'check_module_version_greater_or_equal',
-    'validate_cuda_device',
-    'validate_hpu_device',
-    'location_tag',
-    'default_restore_location',
-    'normalize_storage_type',
-    'storage_to_tensor_type',
-    'save',
-    'load',
-    'StorageType',
-    'LoadEndianness',
-    'get_default_load_endianness',
-    'set_default_load_endianness',
-    'clear_safe_globals',
-    'get_safe_globals',
-    'add_safe_globals',
-]
-
 
 class SourceChangeWarning(Warning):
     pass
@@ -79,17 +96,26 @@
         shutil.rmtree(path)
 
 
-_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []
+_package_registry: List[
+    Tuple[
+        int,
+        Callable[[STORAGE], Optional[str]],
+        Callable[[STORAGE, str], Optional[STORAGE]],
+    ]
+] = []
+
 
 class LoadEndianness(Enum):
     NATIVE = 1
     LITTLE = 2
     BIG = 3
 
+
 _default_load_endian: Optional[LoadEndianness] = None
 
+
 def get_default_load_endianness() -> Optional[LoadEndianness]:
-    '''
+    """
     Get fallback byte order for loading files
 
     If byteorder mark is not present in saved checkpoint,
@@ -98,11 +124,12 @@
 
     Returns:
         default_load_endian: Optional[LoadEndianness]
-    '''
+    """
     return _default_load_endian
 
+
 def set_default_load_endianness(endianness):
-    '''
+    """
     Set fallback byte order for loading files
 
     If byteorder mark is not present in saved checkpoint,
@@ -111,16 +138,18 @@
 
     Args:
         endianness: the new fallback byte order
-    '''
+    """
     global _default_load_endian
     if not isinstance(endianness, LoadEndianness) and endianness is not None:
         raise TypeError("Invalid argument type in function set_default_load_endianness")
     _default_load_endian = endianness
 
+
 _default_mmap_options: int = MAP_PRIVATE
 
+
 def get_default_mmap_options() -> int:
-    '''
+    """
     Get default mmap options for :func:`torch.load` with ``mmap=True``.
 
     Defaults to ``mmap.MAP_PRIVATE``.
@@ -128,11 +157,12 @@
 
     Returns:
         default_mmap_options: int
-    '''
+    """
     return _default_mmap_options
 
+
 def set_default_mmap_options(flags: int):
-    '''
+    """
     Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
 
     For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
@@ -143,36 +173,44 @@
 
     Args:
         flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
-    '''
+    """
     global _default_mmap_options
     if IS_WINDOWS:
-        raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
-    if (flags != MAP_PRIVATE and flags != MAP_SHARED):
-        raise ValueError("Invalid argument in function set_default_mmap_options, "
-                         f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
+        raise RuntimeError(
+            "Changing the default mmap options is currently not supported for Windows"
+        )
+    if flags != MAP_PRIVATE and flags != MAP_SHARED:
+        raise ValueError(
+            "Invalid argument in function set_default_mmap_options, "
+            f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
+        )
     _default_mmap_options = flags
 
+
 def clear_safe_globals() -> None:
-    '''
+    """
     Clears the list of globals that are safe for ``weights_only`` load.
-    '''
+    """
     _weights_only_unpickler._clear_safe_globals()
 
+
 def get_safe_globals() -> List[Any]:
-    '''
+    """
     Returns the list of user-added globals that are safe for ``weights_only`` load.
-    '''
+    """
     return _weights_only_unpickler._get_safe_globals()
 
+
 def add_safe_globals(safe_globals: List[Any]) -> None:
-    '''
+    """
     Marks the given globals as safe for ``weights_only`` load.
 
     Args:
         safe_globals (List[Any]): list of globals to mark as safe
-    '''
+    """
     _weights_only_unpickler._add_safe_globals(safe_globals)
 
+
 def _is_zipfile(f) -> bool:
     # This is a stricter implementation than zipfile.is_zipfile().
     # zipfile.is_zipfile() is True if the magic number appears anywhere in the
@@ -183,7 +221,7 @@
 
     start = f.tell()
     # Read the first few bytes and match against the ZIP file signature
-    local_header_magic_number = b'PK\x03\x04'
+    local_header_magic_number = b"PK\x03\x04"
     read_bytes = f.read(len(local_header_magic_number))
     f.seek(start)
     return read_bytes == local_header_magic_number
@@ -192,9 +230,9 @@
 def register_package(
     priority: int,
     tagger: Callable[[STORAGE], Optional[str]],
-    deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
+    deserializer: Callable[[STORAGE, str], Optional[STORAGE]],
 ):
-    '''
+    """
     Registers callables for tagging and deserializing storage objects with an associated priority.
     Tagging associates a device with a storage object at save time while deserializing moves a
     storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
@@ -228,14 +266,16 @@
         >>>         assert torch.ipu.is_available(), "ipu is not available"
         >>>         return obj.ipu(location)
         >>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
-    '''
+    """
     queue_elem = (priority, tagger, deserializer)
     _package_registry.append(queue_elem)
     _package_registry.sort()
 
 
-def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
-    '''
+def check_module_version_greater_or_equal(
+    module, req_version_tuple, error_if_malformed=True
+):
+    """
     Check if a module's version satisfies requirements
 
     Usually, a module's version string will be like 'x.y.z', which would be represented
@@ -250,12 +290,13 @@
 
     Returns:
         requirement_is_met: bool
-    '''
+    """
     try:
-        version_strs = module.__version__.split('.')
+        version_strs = module.__version__.split(".")
         # Cast module version fields to match the types of the required version
         module_version = tuple(
-            type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
+            type(req_field)(version_strs[idx])
+            for idx, req_field in enumerate(req_version_tuple)
         )
         requirement_is_met = module_version >= req_version_tuple
 
@@ -267,54 +308,54 @@
         if error_if_malformed:
             raise RuntimeError(message) from e
         else:
-            warnings.warn(message + ', but continuing assuming that requirement is met')
+            warnings.warn(message + ", but continuing assuming that requirement is met")
             requirement_is_met = True
 
     return requirement_is_met
 
 
 def _cpu_tag(obj):
-    if obj.device.type == 'cpu':
-        return 'cpu'
+    if obj.device.type == "cpu":
+        return "cpu"
 
 
 def _mps_tag(obj):
-    if obj.device.type == 'mps':
-        return 'mps'
+    if obj.device.type == "mps":
+        return "mps"
 
 
 def _meta_tag(obj):
-    if obj.device.type == 'meta':
-        return 'meta'
+    if obj.device.type == "meta":
+        return "meta"
 
 
 def _backend_tag(backend_name, obj):
-    if backend_name == 'privateuse1':
+    if backend_name == "privateuse1":
         backend_name = torch._C._get_privateuse1_backend_name()
     if obj.device.type == backend_name:
         if obj.device.index is None:
             return backend_name
         else:
-            return backend_name + ':' + str(obj.device.index)
+            return backend_name + ":" + str(obj.device.index)
 
 
 def _cpu_deserialize(obj, location):
-    if location == 'cpu':
+    if location == "cpu":
         return obj
 
 
 def _mps_deserialize(obj, location):
-    if location.startswith('mps'):
+    if location.startswith("mps"):
         return obj.mps()
 
 
 def _meta_deserialize(obj, location):
-    if location == 'meta':
-        return torch.UntypedStorage(obj.nbytes(), device='meta')
+    if location == "meta":
+        return torch.UntypedStorage(obj.nbytes(), device="meta")
 
 
 def _validate_device(location, backend_name):
-    '''
+    """
     Check whether the device index of specified backend is valid
 
     In case of privateuse1 backend, your must first register a device_module for
@@ -328,45 +369,53 @@
 
     Returns:
         device_index: int
-    '''
+    """
     if not hasattr(torch, backend_name):
-        raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
-                           'If you are running on a CPU-only machine, '
-                           'please use torch.load with map_location=torch.device(\'cpu\') '
-                           'to map your storages to the CPU.')
+        raise RuntimeError(
+            f"The {backend_name.upper()} device module is not registered. "
+            "If you are running on a CPU-only machine, "
+            "please use torch.load with map_location=torch.device('cpu') "
+            "to map your storages to the CPU."
+        )
     device_module = getattr(torch, backend_name)
-    if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
+    if hasattr(device_module, "_utils") and hasattr(
+        device_module._utils, "_get_device_index"
+    ):
         device_index = device_module._utils._get_device_index(location, True)
         device = torch.device(backend_name, device_index)
     else:
         device = torch.device(location)
         device_index = device.index if device.index else 0
-    if hasattr(device_module, 'is_available') and not device_module.is_available():
-        raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
-                           f'device but torch.{backend_name}.is_available() is False. '
-                           'If you are running on a CPU-only machine, '
-                           'please use torch.load with map_location=torch.device(\'cpu\') '
-                           'to map your storages to the CPU.')
-    if hasattr(device_module, 'device_count'):
+    if hasattr(device_module, "is_available") and not device_module.is_available():
+        raise RuntimeError(
+            f"Attempting to deserialize object on a {backend_name.upper()} "
+            f"device but torch.{backend_name}.is_available() is False. "
+            "If you are running on a CPU-only machine, "
+            "please use torch.load with map_location=torch.device('cpu') "
+            "to map your storages to the CPU."
+        )
+    if hasattr(device_module, "device_count"):
         device_count = device_module.device_count()
         if device_index >= device_count:
-            raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
-                               f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
-                               'Please use torch.load with map_location to map your storages '
-                               'to an existing device.')
+            raise RuntimeError(
+                f"Attempting to deserialize object on {backend_name.upper()} device "
+                f"{device_index} but torch.{backend_name}.device_count() is {device_count}. "
+                "Please use torch.load with map_location to map your storages "
+                "to an existing device."
+            )
     return device
 
 
 def validate_cuda_device(location):
-    return _validate_device(location, 'cuda').index
+    return _validate_device(location, "cuda").index
 
 
 def validate_hpu_device(location):
-    return _validate_device(location, 'hpu').index
+    return _validate_device(location, "hpu").index
 
 
 def _deserialize(backend_name, obj, location):
-    if backend_name == 'privateuse1':
+    if backend_name == "privateuse1":
         backend_name = torch._C._get_privateuse1_backend_name()
     if location.startswith(backend_name):
         device = _validate_device(location, backend_name)
@@ -374,20 +423,34 @@
 
 
 register_package(10, _cpu_tag, _cpu_deserialize)
-register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
+register_package(
+    20, functools.partial(_backend_tag, "cuda"), functools.partial(_deserialize, "cuda")
+)
 register_package(21, _mps_tag, _mps_deserialize)
 register_package(22, _meta_tag, _meta_deserialize)
-register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
-register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
-register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))
+register_package(
+    23,
+    functools.partial(_backend_tag, "privateuse1"),
+    functools.partial(_deserialize, "privateuse1"),
+)
+register_package(
+    24, functools.partial(_backend_tag, "hpu"), functools.partial(_deserialize, "hpu")
+)
+register_package(
+    25, functools.partial(_backend_tag, "xpu"), functools.partial(_deserialize, "xpu")
+)
 
-def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
+
+def location_tag(
+    storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
+):
     for _, tagger, _ in _package_registry:
         location = tagger(storage)
         if location:
             return location
-    raise RuntimeError("don't know how to determine data location of "
-                       + torch.typename(storage))
+    raise RuntimeError(
+        "don't know how to determine data location of " + torch.typename(storage)
+    )
 
 
 def default_restore_location(storage, location):
@@ -414,9 +477,13 @@
         result = fn(storage, location)
         if result is not None:
             return result
-    raise RuntimeError("don't know how to restore data location of "
-                       + torch.typename(storage) + " (tagged with "
-                       + location + ")")
+    raise RuntimeError(
+        "don't know how to restore data location of "
+        + torch.typename(storage)
+        + " (tagged with "
+        + location
+        + ")"
+    )
 
 
 def normalize_storage_type(storage_type):
@@ -426,7 +493,7 @@
 def storage_to_tensor_type(storage):
     storage_type = type(storage)
     module = _import_dotted_name(storage_type.__module__)
-    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+    return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
 
 
 def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
@@ -467,9 +534,9 @@
     if _is_path(name_or_buffer):
         return _open_file(name_or_buffer, mode)
     else:
-        if 'w' in mode:
+        if "w" in mode:
             return _open_buffer_writer(name_or_buffer)
-        elif 'r' in mode:
+        elif "r" in mode:
             return _open_buffer_reader(name_or_buffer)
         else:
             raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
@@ -485,12 +552,12 @@
         self.file_stream = None
         self.name = str(name)
         try:
-            self.name.encode('ascii')
+            self.name.encode("ascii")
         except UnicodeEncodeError:
             # PyTorchFileWriter only supports ascii filename.
             # For filenames with non-ascii characters, we rely on Python
             # for writing out the file.
-            self.file_stream = io.FileIO(self.name, mode='w')
+            self.file_stream = io.FileIO(self.name, mode="w")
             super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
         else:
             super().__init__(torch._C.PyTorchFileWriter(self.name))
@@ -526,7 +593,7 @@
 
 
 def _is_compressed_file(f) -> bool:
-    compress_modules = ['gzip']
+    compress_modules = ["gzip"]
     try:
         return f.__module__ in compress_modules
     except AttributeError:
@@ -550,13 +617,15 @@
 
 
 def _check_seekable(f) -> bool:
-
     def raise_err_msg(patterns, e):
         for p in patterns:
             if p in str(e):
-                msg = (str(e) + ". You can only torch.load from a file that is seekable."
-                                + " Please pre-load the data into a buffer like io.BytesIO and"
-                                + " try to load from it instead.")
+                msg = (
+                    str(e)
+                    + ". You can only torch.load from a file that is seekable."
+                    + " Please pre-load the data into a buffer like io.BytesIO and"
+                    + " try to load from it instead."
+                )
                 raise type(e)(msg)
         raise e
 
@@ -569,30 +638,35 @@
 
 
 def _check_dill_version(pickle_module) -> None:
-    '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
+    """Checks if using dill as the pickle module, and if so, checks if it is the correct version.
     If dill version is lower than 0.3.1, a ValueError is raised.
 
     Args:
         pickle_module: module used for pickling metadata and objects
 
-    '''
-    if pickle_module is not None and pickle_module.__name__ == 'dill':
+    """
+    if pickle_module is not None and pickle_module.__name__ == "dill":
         required_dill_version = (0, 3, 1)
-        if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
-            raise ValueError((
-                "'torch' supports dill >= {}, but you have dill {}."
-                " Please upgrade dill or switch to 'pickle'"
-            ).format(
-                '.'.join([str(num) for num in required_dill_version]),
-                pickle_module.__version__
-            ))
+        if not check_module_version_greater_or_equal(
+            pickle_module, required_dill_version, False
+        ):
+            raise ValueError(
+                (
+                    "'torch' supports dill >= {}, but you have dill {}."
+                    " Please upgrade dill or switch to 'pickle'"
+                ).format(
+                    ".".join([str(num) for num in required_dill_version]),
+                    pickle_module.__version__,
+                )
+            )
 
 
 def _check_save_filelike(f):
-    if not _is_path(f) and not hasattr(f, 'write'):
+    if not _is_path(f) and not hasattr(f, "write"):
         raise AttributeError(
             "expected 'f' to be string, path, or a file-like object with "
-            "a 'write' attribute")
+            "a 'write' attribute"
+        )
 
 
 def save(
@@ -601,7 +675,7 @@
     pickle_module: Any = pickle,
     pickle_protocol: int = DEFAULT_PROTOCOL,
     _use_new_zipfile_serialization: bool = True,
-    _disable_byteorder_record: bool = False
+    _disable_byteorder_record: bool = False,
 ) -> None:
     # Reference: https://github.com/pytorch/pytorch/issues/54354
     # The first line of this docstring overrides the one Sphinx generates for the
@@ -649,15 +723,22 @@
 
     if _use_new_zipfile_serialization:
         with _open_zipfile_writer(f) as opened_zipfile:
-            _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
+            _save(
+                obj,
+                opened_zipfile,
+                pickle_module,
+                pickle_protocol,
+                _disable_byteorder_record,
+            )
             return
     else:
-        with _open_file_like(f, 'wb') as opened_file:
+        with _open_file_like(f, "wb") as opened_file:
             _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
 
 
 def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
     import torch.nn as nn
+
     serialized_container_types = {}
     serialized_storages = {}
 
@@ -680,12 +761,16 @@
             source_file = source = None
             try:
                 source_lines, _, source_file = get_source_lines_and_file(obj)
-                source = ''.join(source_lines)
-            except Exception:  # saving the source is optional, so we can ignore any errors
-                warnings.warn("Couldn't retrieve source code for container of "
-                              "type " + obj.__name__ + ". It won't be checked "
-                              "for correctness upon loading.")
-            return ('module', obj, source_file, source)
+                source = "".join(source_lines)
+            except (
+                Exception
+            ):  # saving the source is optional, so we can ignore any errors
+                warnings.warn(
+                    "Couldn't retrieve source code for container of "
+                    "type " + obj.__name__ + ". It won't be checked "
+                    "for correctness upon loading."
+                )
+            return ("module", obj, source_file, source)
 
         if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
             storage: torch.UntypedStorage
@@ -707,7 +792,7 @@
                 dtype = torch.uint8
                 storage_numel = storage.nbytes()
             else:
-                raise TypeError(f'type not recognized: {type(obj)}')
+                raise TypeError(f"type not recognized: {type(obj)}")
 
             # If storage is allocated, ensure that any other saved storages
             # pointing to the same data all have the same dtype. If storage is
@@ -716,8 +801,9 @@
                 if storage.data_ptr() in storage_dtypes:
                     if storage_dtype != storage_dtypes[storage.data_ptr()]:
                         raise RuntimeError(
-                            'Cannot save multiple tensors or storages that '
-                            'view the same data as different types')
+                            "Cannot save multiple tensors or storages that "
+                            "view the same data as different types"
+                        )
                 else:
                     storage_dtypes[storage.data_ptr()] = storage_dtype
 
@@ -766,18 +852,20 @@
             else:
                 view_metadata = None
 
-            res = ('storage',
-                   storage_type,
-                   storage_key,
-                   location,
-                   storage_numel,
-                   view_metadata)
+            res = (
+                "storage",
+                storage_type,
+                storage_key,
+                location,
+                storage_numel,
+                view_metadata,
+            )
             return res
         return None
 
     sys_info = dict(
         protocol_version=PROTOCOL_VERSION,
-        little_endian=sys.byteorder == 'little',
+        little_endian=sys.byteorder == "little",
         type_sizes=dict(
             short=SHORT_SIZE,
             int=INT_SIZE,
@@ -797,7 +885,9 @@
     f.flush()
     for key in serialized_storage_keys:
         storage, dtype = serialized_storages[key]
-        storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
+        storage._write_file(
+            f, _should_read_directly(f), True, torch._utils._element_size(dtype)
+        )
 
 
 def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
@@ -817,7 +907,6 @@
         # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
         # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
         if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
-
             if isinstance(obj, torch.storage.TypedStorage):
                 # TODO: Once we decide to break serialization FC, this case
                 # can be deleted
@@ -840,8 +929,9 @@
                 if storage.data_ptr() in storage_dtypes:
                     if storage_dtype != storage_dtypes[storage.data_ptr()]:
                         raise RuntimeError(
-                            'Cannot save multiple tensors or storages that '
-                            'view the same data as different types')
+                            "Cannot save multiple tensors or storages that "
+                            "view the same data as different types"
+                        )
                 else:
                     storage_dtypes[storage.data_ptr()] = storage_dtype
 
@@ -849,11 +939,7 @@
             location = location_tag(storage)
             serialized_storages[storage_key] = storage
 
-            return ('storage',
-                    storage_type,
-                    storage_key,
-                    location,
-                    storage_numel)
+            return ("storage", storage_type, storage_key, location, storage_numel)
 
         return None
 
@@ -863,23 +949,23 @@
     pickler.persistent_id = persistent_id
     pickler.dump(obj)
     data_value = data_buf.getvalue()
-    zip_file.write_record('data.pkl', data_value, len(data_value))
+    zip_file.write_record("data.pkl", data_value, len(data_value))
 
     # Write byte order marker
     if not _disable_byteorder_record:
-        if sys.byteorder not in ['little', 'big']:
-            raise ValueError('Unknown endianness type: ' + sys.byteorder)
+        if sys.byteorder not in ["little", "big"]:
+            raise ValueError("Unknown endianness type: " + sys.byteorder)
 
-        zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
+        zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder))
 
     # Write each tensor to a file named tensor/the_tensor_key in the zip archive
     for key in sorted(serialized_storages.keys()):
-        name = f'data/{key}'
+        name = f"data/{key}"
         storage = serialized_storages[key]
         # given that we copy things around anyway, we might use storage.cpu()
         # this means that to get tensors serialized, you need to implement
         # .cpu() on the underlying Storage
-        if storage.device.type != 'cpu':
+        if storage.device.type != "cpu":
             storage = storage.cpu()
         # Now that it is on the CPU we can directly copy it into the zip file
         num_bytes = storage.nbytes()
@@ -893,7 +979,7 @@
     *,
     weights_only: bool = False,
     mmap: Optional[bool] = None,
-    **pickle_load_args: Any
+    **pickle_load_args: Any,
 ) -> Any:
     # Reference: https://github.com/pytorch/pytorch/issues/54354
     # The first line of this docstring overrides the one Sphinx generates for the
@@ -1002,12 +1088,19 @@
         " WeightsUnpickler error: "
     )
     # Add ability to force safe only weight loads via environment variable
-    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
+    if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [
+        "1",
+        "y",
+        "yes",
+        "true",
+    ]:
         weights_only = True
 
     if weights_only:
         if pickle_module is not None:
-            raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
+            raise RuntimeError(
+                "Can not safely load weights when explicit pickle_module is specified"
+            )
     else:
         if pickle_module is None:
             pickle_module = pickle
@@ -1018,10 +1111,10 @@
 
     _check_dill_version(pickle_module)
 
-    if 'encoding' not in pickle_load_args.keys():
-        pickle_load_args['encoding'] = 'utf-8'
+    if "encoding" not in pickle_load_args.keys():
+        pickle_load_args["encoding"] = "utf-8"
 
-    with _open_file_like(f, 'rb') as opened_file:
+    with _open_file_like(f, "rb") as opened_file:
         if _is_zipfile(opened_file):
             # The zipfile reader is going to advance the current file position.
             # If we want to actually tail call to torch.jit.load, we need to
@@ -1030,61 +1123,81 @@
             overall_storage = None
             with _open_zipfile_reader(opened_file) as opened_zipfile:
                 if _is_torchscript_zip(opened_zipfile):
-                    warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
-                                  " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
-                                  " silence this warning)", UserWarning)
+                    warnings.warn(
+                        "'torch.load' received a zip file that looks like a TorchScript archive"
+                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
+                        " silence this warning)",
+                        UserWarning,
+                    )
                     opened_file.seek(orig_position)
                     return torch.jit.load(opened_file, map_location=map_location)
                 if mmap:
                     if not _is_path(f):
-                        raise ValueError("f must be a file path in order to use the mmap argument")
+                        raise ValueError(
+                            "f must be a file path in order to use the mmap argument"
+                        )
                     size = os.path.getsize(f)
                     if not IS_WINDOWS:
                         shared = get_default_mmap_options() == MAP_SHARED
                     else:
                         shared = False
-                    overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
+                    overall_storage = torch.UntypedStorage.from_file(
+                        os.fspath(f), shared, size
+                    )
                 if weights_only:
                     try:
-                        return _load(opened_zipfile,
-                                     map_location,
-                                     _weights_only_unpickler,
-                                     overall_storage=overall_storage,
-                                     **pickle_load_args)
+                        return _load(
+                            opened_zipfile,
+                            map_location,
+                            _weights_only_unpickler,
+                            overall_storage=overall_storage,
+                            **pickle_load_args,
+                        )
                     except RuntimeError as e:
                         raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-                return _load(opened_zipfile,
-                             map_location,
-                             pickle_module,
-                             overall_storage=overall_storage,
-                             **pickle_load_args)
+                return _load(
+                    opened_zipfile,
+                    map_location,
+                    pickle_module,
+                    overall_storage=overall_storage,
+                    **pickle_load_args,
+                )
         if mmap:
             f_name = "" if not isinstance(f, str) else f"{f}, "
-            raise RuntimeError("mmap can only be used with files saved with "
-                               f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
-                               "please torch.save your checkpoint with this option in order to use mmap.")
+            raise RuntimeError(
+                "mmap can only be used with files saved with "
+                f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
+                "please torch.save your checkpoint with this option in order to use mmap."
+            )
         if weights_only:
             try:
-                return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
+                return _legacy_load(
+                    opened_file,
+                    map_location,
+                    _weights_only_unpickler,
+                    **pickle_load_args,
+                )
             except RuntimeError as e:
                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
+        return _legacy_load(
+            opened_file, map_location, pickle_module, **pickle_load_args
+        )
 
 
 # Register pickling support for layout instances such as
 # torch.sparse_coo, etc
 def _get_layout(name):
-    """Get layout extension object from its string representation.
-    """
-    cache = _get_layout.cache   # type: ignore[attr-defined]
+    """Get layout extension object from its string representation."""
+    cache = _get_layout.cache  # type: ignore[attr-defined]
     if not cache:
         for v in torch.__dict__.values():
             if isinstance(v, torch.layout):
                 cache[str(v)] = v
     return cache[name]
 
+
 # There is not yet a good way to type annotate function attributes https://github.com/python/mypy/issues/2087
-_get_layout.cache = {}   # type: ignore[attr-defined]
+_get_layout.cache = {}  # type: ignore[attr-defined]
 copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
 
 
@@ -1094,9 +1207,8 @@
     restore_location = _get_restore_location(map_location)
 
     class UnpicklerWrapper(pickle_module.Unpickler):  # type: ignore[name-defined]
-
         def find_class(self, mod_name, name):
-            if type(name) is str and 'Storage' in name:
+            if type(name) is str and "Storage" in name:
                 try:
                     return StorageType(name)
                 except KeyError:
@@ -1105,41 +1217,52 @@
 
     def _check_container_source(container_type, source_file, original_source):
         try:
-            current_source = ''.join(get_source_lines_and_file(container_type)[0])
+            current_source = "".join(get_source_lines_and_file(container_type)[0])
         except Exception:  # saving the source is optional, so we can ignore any errors
-            warnings.warn("Couldn't retrieve source code for container of "
-                          "type " + container_type.__name__ + ". It won't be checked "
-                          "for correctness upon loading.")
+            warnings.warn(
+                "Couldn't retrieve source code for container of "
+                "type " + container_type.__name__ + ". It won't be checked "
+                "for correctness upon loading."
+            )
             return
         if original_source != current_source:
             if container_type.dump_patches:
-                file_name = container_type.__name__ + '.patch'
-                diff = difflib.unified_diff(current_source.split('\n'),
-                                            original_source.split('\n'),
-                                            source_file,
-                                            source_file, lineterm="")
-                lines = '\n'.join(diff)
+                file_name = container_type.__name__ + ".patch"
+                diff = difflib.unified_diff(
+                    current_source.split("\n"),
+                    original_source.split("\n"),
+                    source_file,
+                    source_file,
+                    lineterm="",
+                )
+                lines = "\n".join(diff)
                 try:
-                    with open(file_name, 'a+') as f:
+                    with open(file_name, "a+") as f:
                         file_size = f.seek(0, 2)
                         f.seek(0)
                         if file_size == 0:
                             f.write(lines)
                         elif file_size != len(lines) or f.read() != lines:
                             raise OSError
-                    msg = ("Saved a reverse patch to " + file_name + ". "
-                           "Run `patch -p0 < " + file_name + "` to revert your "
-                           "changes.")
+                    msg = (
+                        "Saved a reverse patch to " + file_name + ". "
+                        "Run `patch -p0 < " + file_name + "` to revert your "
+                        "changes."
+                    )
                 except OSError:
-                    msg = ("Tried to save a patch, but couldn't create a "
-                           "writable file " + file_name + ". Make sure it "
-                           "doesn't exist and your working directory is "
-                           "writable.")
+                    msg = (
+                        "Tried to save a patch, but couldn't create a "
+                        "writable file " + file_name + ". Make sure it "
+                        "doesn't exist and your working directory is "
+                        "writable."
+                    )
             else:
-                msg = ("you can retrieve the original source code by "
-                       "accessing the object's source attribute or set "
-                       "`torch.nn.Module.dump_patches = True` and use the "
-                       "patch tool to revert the changes.")
+                msg = (
+                    "you can retrieve the original source code by "
+                    "accessing the object's source attribute or set "
+                    "`torch.nn.Module.dump_patches = True` and use the "
+                    "patch tool to revert the changes."
+                )
             msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
             warnings.warn(msg, SourceChangeWarning)
 
@@ -1154,24 +1277,25 @@
                 return saved_id[0]
             return deserialized_objects[int(saved_id)]
 
-        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
-                mkdtemp() as tmpdir:
-
-            tar.extract('storages', path=tmpdir)
-            with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
+        with closing(
+            tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
+        ) as tar, mkdtemp() as tmpdir:
+            tar.extract("storages", path=tmpdir)
+            with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
                 num_storages = pickle_module.load(f, **pickle_load_args)
                 for i in range(num_storages):
                     args = pickle_module.load(f, **pickle_load_args)
                     key, location, storage_type = args
                     dtype = storage_type._dtype
-                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
+                    obj = cast(Storage, torch.UntypedStorage)._new_with_file(
+                        f, torch._utils._element_size(dtype)
+                    )
                     obj = restore_location(obj, location)
                     # TODO: Once we decide to break serialization FC, we can
                     # stop wrapping with TypedStorage
                     deserialized_objects[key] = torch.storage.TypedStorage(
-                        wrap_storage=obj,
-                        dtype=dtype,
-                        _internal=True)
+                        wrap_storage=obj, dtype=dtype, _internal=True
+                    )
 
                 storage_views = pickle_module.load(f, **pickle_load_args)
                 for target_cdata, root_cdata, offset, numel in storage_views:
@@ -1181,28 +1305,32 @@
                     # TODO: Once we decide to break serialization FC, we can
                     # stop wrapping with TypedStorage
                     deserialized_objects[target_cdata] = torch.storage.TypedStorage(
-                        wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
+                        wrap_storage=root._untyped_storage[
+                            offset_bytes : offset_bytes + numel * element_size
+                        ],
                         dtype=root.dtype,
-                        _internal=True)
+                        _internal=True,
+                    )
 
-            tar.extract('tensors', path=tmpdir)
-            with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
+            tar.extract("tensors", path=tmpdir)
+            with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
                 num_tensors = pickle_module.load(f, **pickle_load_args)
                 for _ in range(num_tensors):
                     args = pickle_module.load(f, **pickle_load_args)
                     key, storage_id, original_tensor_type = args
                     storage = deserialized_objects[storage_id]
-                    ndim, = struct.unpack('<i', f.read(4))
+                    (ndim,) = struct.unpack("<i", f.read(4))
                     # skip next 4 bytes; legacy encoding treated ndim as 8 bytes
                     f.read(4)
-                    numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
-                    stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
-                    storage_offset, = struct.unpack('<q', f.read(8))
+                    numel = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
+                    stride = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
+                    (storage_offset,) = struct.unpack("<q", f.read(8))
                     tensor = torch.empty((0,), dtype=storage.dtype).set_(
-                        storage._untyped_storage, storage_offset, numel, stride)
+                        storage._untyped_storage, storage_offset, numel, stride
+                    )
                     deserialized_objects[key] = tensor
 
-            pickle_file = tar.extractfile('pickle')
+            pickle_file = tar.extractfile("pickle")
             unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
             unpickler.persistent_load = persistent_load
             result = unpickler.load()
@@ -1215,12 +1343,12 @@
         typename = _maybe_decode_ascii(saved_id[0])
         data = saved_id[1:]
 
-        if typename == 'module':
+        if typename == "module":
             # Ignore containers that don't have any sources saved
             if all(data[1:]):
                 _check_container_source(*data)
             return data[0]
-        elif typename == 'storage':
+        elif typename == "storage":
             storage_type, root_key, location, numel, view_metadata = data
             location = _maybe_decode_ascii(location)
             dtype = storage_type.dtype
@@ -1229,7 +1357,7 @@
 
             if root_key not in deserialized_objects:
                 if torch._guards.active_fake_mode() is not None:
-                    obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
+                    obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
                 else:
                     obj = cast(Storage, torch.UntypedStorage(nbytes))
                     obj._torch_load_uninitialized = True
@@ -1237,9 +1365,8 @@
                 # TODO: Once we decide to break serialization FC, we can
                 # stop wrapping with TypedStorage
                 typed_storage = torch.storage.TypedStorage(
-                    wrap_storage=obj,
-                    dtype=dtype,
-                    _internal=True)
+                    wrap_storage=obj, dtype=dtype, _internal=True
+                )
                 deserialized_objects[root_key] = typed_storage
             else:
                 typed_storage = deserialized_objects[root_key]
@@ -1247,7 +1374,8 @@
                     typed_storage = torch.storage.TypedStorage(
                         device=typed_storage._untyped_storage.device,
                         dtype=dtype,
-                        _internal=True)
+                        _internal=True,
+                    )
 
             if view_metadata is not None:
                 view_key, offset, view_size = view_metadata
@@ -1257,9 +1385,12 @@
                     # TODO: Once we decide to break serialization FC, we can
                     # stop wrapping with TypedStorage
                     deserialized_objects[view_key] = torch.storage.TypedStorage(
-                        wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
+                        wrap_storage=typed_storage._untyped_storage[
+                            offset_bytes : offset_bytes + view_size_bytes
+                        ],
                         dtype=dtype,
-                        _internal=True)
+                        _internal=True,
+                    )
                 res = deserialized_objects[view_key]
 
             else:
@@ -1280,15 +1411,17 @@
             if _is_zipfile(f):
                 # .zip is used for torch.jit.save and will throw an un-pickling error here
                 raise RuntimeError(
-                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
+                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
+                ) from None
             # if not a tarfile, reset file offset and proceed
             f.seek(0)
 
-    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
+    if not hasattr(f, "readinto") and (3, 8, 0) <= sys.version_info < (3, 8, 2):
         raise RuntimeError(
             "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
             f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
-            "functionality.")
+            "functionality."
+        )
 
     magic_number = pickle_module.load(f, **pickle_load_args)
     if magic_number != MAGIC_NUMBER:
@@ -1310,8 +1443,11 @@
             assert key in deserialized_objects
             typed_storage = deserialized_objects[key]
             typed_storage._untyped_storage._set_from_file(
-                f, offset, f_should_read_directly,
-                torch._utils._element_size(typed_storage.dtype))
+                f,
+                offset,
+                f_should_read_directly,
+                torch._utils._element_size(typed_storage.dtype),
+            )
             if offset is not None:
                 offset = f.tell()
 
@@ -1328,7 +1464,7 @@
     # NOTE: This should only be used on internal keys (e.g., `typename` and
     #       `location` in `persistent_load` below!
     if isinstance(bytes_str, bytes):
-        return bytes_str.decode('ascii')
+        return bytes_str.decode("ascii")
     return bytes_str
 
 
@@ -1336,21 +1472,29 @@
     if map_location is None:
         restore_location = default_restore_location
     elif isinstance(map_location, dict):
+
         def restore_location(storage, location):
             location = map_location.get(location, location)
             return default_restore_location(storage, location)
+
     elif isinstance(map_location, (str, bytes)):
+
         def restore_location(storage, location):
             return default_restore_location(storage, map_location)
+
     elif isinstance(map_location, torch.device):
+
         def restore_location(storage, location):
             return default_restore_location(storage, str(map_location))
+
     else:
+
         def restore_location(storage, location):
             result = map_location(storage, location)
             if result is None:
                 result = default_restore_location(storage, location)
             return result
+
     return restore_location
 
 
@@ -1363,53 +1507,70 @@
         return self._dtype
 
     def __str__(self):
-        return f'StorageType(dtype={self.dtype})'
+        return f"StorageType(dtype={self.dtype})"
 
 
-def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
+def _load(
+    zip_file,
+    map_location,
+    pickle_module,
+    pickle_file="data.pkl",
+    overall_storage=None,
+    **pickle_load_args,
+):
     restore_location = _get_restore_location(map_location)
 
     loaded_storages = {}
 
     # check if byteswapping is needed
-    byteordername = 'byteorder'
+    byteordername = "byteorder"
     byteorderdata = None
     if zip_file.has_record(byteordername):
         byteorderdata = zip_file.get_record(byteordername)
-        if byteorderdata not in [b'little', b'big']:
-            raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
-    elif get_default_load_endianness() == LoadEndianness.LITTLE or \
-            get_default_load_endianness() is None:
-        byteorderdata = b'little'
+        if byteorderdata not in [b"little", b"big"]:
+            raise ValueError("Unknown endianness type: " + byteorderdata.decode())
+    elif (
+        get_default_load_endianness() == LoadEndianness.LITTLE
+        or get_default_load_endianness() is None
+    ):
+        byteorderdata = b"little"
     elif get_default_load_endianness() == LoadEndianness.BIG:
-        byteorderdata = b'big'
+        byteorderdata = b"big"
     elif get_default_load_endianness() == LoadEndianness.NATIVE:
         pass
     else:
-        raise ValueError('Invalid load endianness type')
+        raise ValueError("Invalid load endianness type")
 
-    if not zip_file.has_record(byteordername) and \
-            get_default_load_endianness() is None and \
-            sys.byteorder == 'big':
+    if (
+        not zip_file.has_record(byteordername)
+        and get_default_load_endianness() is None
+        and sys.byteorder == "big"
+    ):
         # Default behaviour was changed
         # See https://github.com/pytorch/pytorch/issues/101688
-        warnings.warn("The default load endianness for checkpoints without a byteorder mark "
-                      "on big endian machines was changed from 'native' to 'little' endian, "
-                      "to avoid this behavior please use "
-                      "torch.serialization.set_default_load_endianness to set "
-                      "the desired default load endianness",
-                      UserWarning)
+        warnings.warn(
+            "The default load endianness for checkpoints without a byteorder mark "
+            "on big endian machines was changed from 'native' to 'little' endian, "
+            "to avoid this behavior please use "
+            "torch.serialization.set_default_load_endianness to set "
+            "the desired default load endianness",
+            UserWarning,
+        )
 
     def load_tensor(dtype, numel, key, location):
-        name = f'data/{key}'
+        name = f"data/{key}"
         if torch._guards.detect_fake_mode(None) is not None:
             nbytes = numel * torch._utils._element_size(dtype)
-            storage = torch.UntypedStorage(nbytes, device='meta')
+            storage = torch.UntypedStorage(nbytes, device="meta")
         elif overall_storage is not None:
             storage_offset = zip_file.get_record_offset(name)
-            storage = overall_storage[storage_offset:storage_offset + numel]
+            storage = overall_storage[storage_offset : storage_offset + numel]
         else:
-            storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
+            storage = (
+                zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
+                ._typed_storage()
+                ._untyped_storage
+            )
         # swap here if byteswapping is needed
         if byteorderdata is not None:
             if byteorderdata.decode() != sys.byteorder:
@@ -1420,7 +1581,8 @@
         typed_storage = torch.storage.TypedStorage(
             wrap_storage=restore_location(storage, location),
             dtype=dtype,
-            _internal=True)
+            _internal=True,
+        )
 
         if typed_storage._data_ptr() != 0:
             loaded_storages[key] = typed_storage
@@ -1432,8 +1594,9 @@
         typename = _maybe_decode_ascii(saved_id[0])
         data = saved_id[1:]
 
-        assert typename == 'storage', \
-            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
+        assert (
+            typename == "storage"
+        ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
         storage_type, key, location, numel = data
         if storage_type is torch.UntypedStorage:
             dtype = torch.uint8
@@ -1444,13 +1607,15 @@
             typed_storage = loaded_storages[key]
         else:
             nbytes = numel * torch._utils._element_size(dtype)
-            typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
+            typed_storage = load_tensor(
+                dtype, nbytes, key, _maybe_decode_ascii(location)
+            )
 
         return typed_storage
 
     load_module_mapping: Dict[str, str] = {
         # See https://github.com/pytorch/pytorch/pull/51633
-        'torch.tensor': 'torch._tensor'
+        "torch.tensor": "torch._tensor"
     }
 
     # Need to subclass Unpickler instead of directly monkey-patching the find_class method
@@ -1461,7 +1626,7 @@
         # Lets us override the imports that pickle uses when unpickling an object.
         # This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
         def find_class(self, mod_name, name):
-            if type(name) is str and 'Storage' in name:
+            if type(name) is str and "Storage" in name:
                 try:
                     return StorageType(name)
                 except KeyError:
@@ -1488,4 +1653,4 @@
 
 
 def _is_torchscript_zip(zip_file):
-    return 'constants.pkl' in zip_file.get_all_records()
+    return "constants.pkl" in zip_file.get_all_records()
diff --git a/torch/torch_version.py b/torch/torch_version.py
index e8814cf..1b18f7d 100644
--- a/torch/torch_version.py
+++ b/torch/torch_version.py
@@ -2,8 +2,9 @@
 
 from typing import Any, Iterable
 
-from ._vendor.packaging.version import InvalidVersion, Version
-from .version import __version__ as internal_version
+from torch._vendor.packaging.version import InvalidVersion, Version
+from torch.version import __version__ as internal_version
+
 
 __all__ = ["TorchVersion"]
 
diff --git a/torch/types.py b/torch/types.py
index a522d62..67875f8 100644
--- a/torch/types.py
+++ b/torch/types.py
@@ -1,9 +1,14 @@
 # mypy: allow-untyped-defs
 import builtins
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
 
 import torch
 
+
+if TYPE_CHECKING:
+    from torch.autograd.graph import GradientEdge
+
+
 # Convenience aliases for common composite types that we need
 # to talk about in PyTorch
 
@@ -11,8 +16,8 @@
 _TensorOrTensorsOrGradEdge = Union[
     torch.Tensor,
     Sequence[torch.Tensor],
-    "torch.autograd.graph.GradientEdge",
-    Sequence["torch.autograd.graph.GradientEdge"],
+    "GradientEdge",
+    Sequence["GradientEdge"],
 ]
 
 # In some cases, these basic types are shadowed by corresponding