[BE] enable UFMT for top-level files `torch/*.py` (#127707)
Part of #123062
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127707
Approved by: https://github.com/ezyang
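
For reviewers skimming the hunks below: every change is mechanical, produced by the UFMT lint (usort for import ordering plus black for formatting) configured in `.lintrunner.toml`. The sketch below uses made-up names (`fit`, `mode`) purely to illustrate the before/after style; it is not a real torch API. Locally, the same result should be reproducible with `lintrunner -a --take UFMT` on the affected files, assuming the standard PyTorch lint setup.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor


# usort rewrites relative imports (`from ._lowrank import ...`) as sorted
# absolute `torch.` imports; black normalizes to double quotes and, via the
# magic trailing comma, keeps long signatures exploded one-argument-per-line.
def fit(
    input: Tensor,
    p: float = 2.0,
    mode: str = "fast",
    out: Optional[torch.Tensor] = None,
) -> Tuple[Tensor, Tensor]:
    raise NotImplementedError("illustrative only")
```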
diff --git a/.lintrunner.toml b/.lintrunner.toml
index 92a7fc0..c28399b 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -1556,7 +1556,6 @@
'torch/distributed/tensor/parallel/style.py',
'torch/fft/__init__.py',
'torch/func/__init__.py',
- 'torch/functional.py',
'torch/futures/__init__.py',
'torch/fx/__init__.py',
'torch/fx/_compatibility.py',
@@ -1642,8 +1641,6 @@
'torch/fx/subgraph_rewriter.py',
'torch/fx/tensor_type.py',
'torch/fx/traceback.py',
- 'torch/hub.py',
- 'torch/library.py',
'torch/linalg/__init__.py',
'torch/monitor/__init__.py',
'torch/nested/__init__.py',
@@ -1767,11 +1764,6 @@
'torch/nn/utils/rnn.py',
'torch/nn/utils/spectral_norm.py',
'torch/nn/utils/weight_norm.py',
- 'torch/overrides.py',
- 'torch/quasirandom.py',
- 'torch/random.py',
- 'torch/return_types.py',
- 'torch/serialization.py',
'torch/signal/__init__.py',
'torch/signal/windows/__init__.py',
'torch/signal/windows/windows.py',
diff --git a/torch/_guards.py b/torch/_guards.py
index 9204170..917f5dc 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -2,7 +2,6 @@
from __future__ import annotations
import contextlib
-
import dataclasses
import enum
import functools
@@ -31,6 +30,7 @@
from torch.utils._traceback import CapturedTraceback
from torch.utils.weak import WeakTensorKeyDictionary
+
log = logging.getLogger(__name__)
@@ -40,7 +40,6 @@
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
-
import torch
@@ -176,7 +175,7 @@
def sort_key(self):
# Put the duplicate input guards at the end. The duplicate guards have
# two sources while guard.name only considers one source.
- from ._dynamo.guards import GuardBuilder
+ from torch._dynamo.guards import GuardBuilder
is_duplicate_input = (
isinstance(self.create_fn, functools.partial)
diff --git a/torch/_lobpcg.py b/torch/_lobpcg.py
index 3f7bdf4..6a14883 100644
--- a/torch/_lobpcg.py
+++ b/torch/_lobpcg.py
@@ -7,9 +7,8 @@
from typing import Dict, Optional, Tuple
import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
__all__ = ["lobpcg"]
diff --git a/torch/_lowrank.py b/torch/_lowrank.py
index 4641c4c..bbe01ed 100644
--- a/torch/_lowrank.py
+++ b/torch/_lowrank.py
@@ -6,9 +6,8 @@
from typing import Optional, Tuple
import torch
-from torch import Tensor
-from . import _linalg_utils as _utils
-from .overrides import handle_torch_function, has_torch_function
+from torch import _linalg_utils as _utils, Tensor
+from torch.overrides import handle_torch_function, has_torch_function
def get_approximate_basis(
diff --git a/torch/_tensor.py b/torch/_tensor.py
index 5ea2985..36df3d6 100644
--- a/torch/_tensor.py
+++ b/torch/_tensor.py
@@ -761,22 +761,22 @@
return torch.norm(self, p, dim, keepdim, dtype=dtype)
def solve(self, other):
- from ._linalg_utils import solve
+ from torch._linalg_utils import solve
return solve(self, other)
def lstsq(self, other):
- from ._linalg_utils import lstsq
+ from torch._linalg_utils import lstsq
return lstsq(self, other)
def eig(self, eigenvectors=False):
- from ._linalg_utils import eig
+ from torch._linalg_utils import eig
return eig(self, eigenvectors=eigenvectors)
def symeig(self, eigenvectors=False):
- from ._linalg_utils import _symeig
+ from torch._linalg_utils import _symeig
return _symeig(self, eigenvectors=eigenvectors)
diff --git a/torch/functional.py b/torch/functional.py
index a836c06..20e1cf1 100644
--- a/torch/functional.py
+++ b/torch/functional.py
@@ -1,47 +1,46 @@
# mypy: allow-untyped-defs
-from typing import (
- List, Tuple, Optional, Union, Any, Sequence, TYPE_CHECKING
-)
-import operator
import itertools
+import operator
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
-from torch._C import _add_docstr
import torch.nn.functional as F
-from ._lowrank import svd_lowrank, pca_lowrank
-from .overrides import (
- has_torch_function, has_torch_function_unary, has_torch_function_variadic,
- handle_torch_function)
-from ._jit_internal import boolean_dispatch
-from ._jit_internal import _overload as overload
+from torch import _VF, Tensor
+from torch._C import _add_docstr
+from torch._jit_internal import _overload as overload, boolean_dispatch
+from torch._lowrank import pca_lowrank, svd_lowrank
+from torch.overrides import (
+ handle_torch_function,
+ has_torch_function,
+ has_torch_function_unary,
+ has_torch_function_variadic,
+)
-Tensor = torch.Tensor
-from torch import _VF
__all__ = [
- 'atleast_1d',
- 'atleast_2d',
- 'atleast_3d',
- 'align_tensors',
- 'broadcast_shapes',
- 'broadcast_tensors',
- 'cartesian_prod',
- 'block_diag',
- 'cdist',
- 'chain_matmul',
- 'einsum',
- 'istft',
- 'lu',
- 'norm',
- 'meshgrid',
- 'pca_lowrank',
- 'split',
- 'stft',
- 'svd_lowrank',
- 'tensordot',
- 'unique',
- 'unique_consecutive',
- 'unravel_index',
+ "atleast_1d",
+ "atleast_2d",
+ "atleast_3d",
+ "align_tensors",
+ "broadcast_shapes",
+ "broadcast_tensors",
+ "cartesian_prod",
+ "block_diag",
+ "cdist",
+ "chain_matmul",
+ "einsum",
+ "istft",
+ "lu",
+ "norm",
+ "meshgrid",
+ "pca_lowrank",
+ "split",
+ "stft",
+ "svd_lowrank",
+ "tensordot",
+ "unique",
+ "unique_consecutive",
+ "unravel_index",
]
@@ -124,16 +123,25 @@
if isinstance(shape, (tuple, list)):
for i in range(-1, -1 - len(shape), -1):
if shape[i] < 0:
- raise RuntimeError(f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})")
+ raise RuntimeError(
+ f"Trying to create tensor with negative dimension ({shape[i]}): ({shape[i]})"
+ )
# NB: result is initialized to 1 so this is effectively an
# equals one test
- if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(shape[i] == result[i]):
+ if guard_size_oblivious(shape[i] == 1) or guard_size_oblivious(
+ shape[i] == result[i]
+ ):
continue
if result[i] != 1:
- raise RuntimeError("Shape mismatch: objects cannot be broadcast to a single shape")
+ raise RuntimeError(
+ "Shape mismatch: objects cannot be broadcast to a single shape"
+ )
result[i] = shape[i]
else:
- raise RuntimeError("Input shapes should be of type ints, a tuple of ints, or a list of ints, got ", shape)
+ raise RuntimeError(
+ "Input shapes should be of type ints, a tuple of ints, or a list of ints, got ",
+ shape,
+ )
return torch.Size(result)
else:
# with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
@@ -188,7 +196,8 @@
"""
if has_torch_function_unary(tensor):
return handle_torch_function(
- split, (tensor,), tensor, split_size_or_sections, dim=dim)
+ split, (tensor,), tensor, split_size_or_sections, dim=dim
+ )
# Overwriting reason:
# This dispatches to two ATen functions depending on the type of
# split_size_or_sections. The branching code is in _tensor.py, which we
@@ -335,10 +344,13 @@
[ 0.3311, 5.5201, -3.0356]])
"""
import torch.backends.opt_einsum as opt_einsum
+
# This wrapper exists to support variadic args.
if len(args) < 2:
- raise ValueError('einsum(): must specify the equation string and at least one operand, '
- 'or at least one operand and its subscripts list')
+ raise ValueError(
+ "einsum(): must specify the equation string and at least one operand, "
+ "or at least one operand and its subscripts list"
+ )
equation = None
operands = None
@@ -350,19 +362,21 @@
# input operands into a tensorlist (List[Tensor]).
def parse_subscript(n: int) -> str:
if n == Ellipsis:
- return '...'
+ return "..."
if n >= 0 and n < 26:
- return chr(ord('A') + n)
+ return chr(ord("A") + n)
if n >= 26 and n < 52:
- return chr(ord('a') + n - 26)
- raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52)')
+ return chr(ord("a") + n - 26)
+ raise ValueError(
+ "einsum(): subscript in subscript list is not within the valid range [0, 52)"
+ )
# Parse subscripts for input operands
- equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2])
+ equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
# Parse optional output subscripts (provided when the number of arguments is odd)
if len(args) % 2 == 1:
- equation += '->' + ''.join(parse_subscript(s) for s in args[-1])
+ equation += "->" + "".join(parse_subscript(s) for s in args[-1])
operands = args[:-1:2]
else:
operands = args[::2]
@@ -388,7 +402,9 @@
path = None
if opt_einsum.is_available():
_opt_einsum = opt_einsum.get_opt_einsum()
- tupled_path = _opt_einsum.contract_path(equation, *operands, optimize=opt_einsum.strategy)[0]
+ tupled_path = _opt_einsum.contract_path(
+ equation, *operands, optimize=opt_einsum.strategy
+ )[0]
# flatten path for dispatching to C++
path = [item for pair in tupled_path for item in pair]
return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined]
@@ -397,10 +413,13 @@
# This wrapper exists to support variadic args.
if TYPE_CHECKING:
# The JIT doesn't understand Union, so only add type annotation for mypy
- def meshgrid(*tensors: Union[Tensor, List[Tensor]],
- indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
+ def meshgrid(
+ *tensors: Union[Tensor, List[Tensor]], indexing: Optional[str] = None
+ ) -> Tuple[Tensor, ...]:
return _meshgrid(*tensors, indexing=indexing)
+
else:
+
def meshgrid(*tensors, indexing: Optional[str] = None) -> Tuple[Tensor, ...]:
r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
@@ -509,15 +528,22 @@
# kwarg for forward compatibility reasons.
#
# Remove this two weeks after landing.
- kwargs = {} if indexing is None else {'indexing': indexing}
+ kwargs = {} if indexing is None else {"indexing": indexing}
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
-def stft(input: Tensor, n_fft: int, hop_length: Optional[int] = None,
- win_length: Optional[int] = None, window: Optional[Tensor] = None,
- center: bool = True, pad_mode: str = 'reflect', normalized: bool = False,
- onesided: Optional[bool] = None,
- return_complex: Optional[bool] = None) -> Tensor:
+def stft(
+ input: Tensor,
+ n_fft: int,
+ hop_length: Optional[int] = None,
+ win_length: Optional[int] = None,
+ window: Optional[Tensor] = None,
+ center: bool = True,
+ pad_mode: str = "reflect",
+ normalized: bool = False,
+ onesided: Optional[bool] = None,
+ return_complex: Optional[bool] = None,
+) -> Tensor:
r"""Short-time Fourier transform (STFT).
.. warning::
@@ -652,9 +678,19 @@
"""
if has_torch_function_unary(input):
return handle_torch_function(
- stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length,
- window=window, center=center, pad_mode=pad_mode, normalized=normalized,
- onesided=onesided, return_complex=return_complex)
+ stft,
+ (input,),
+ input,
+ n_fft,
+ hop_length=hop_length,
+ win_length=win_length,
+ window=window,
+ center=center,
+ pad_mode=pad_mode,
+ normalized=normalized,
+ onesided=onesided,
+ return_complex=return_complex,
+ )
# NOTE: Do not edit. This code will be removed once the forward-compatibility
# period is over for PR #73432
if center:
@@ -663,8 +699,16 @@
pad = int(n_fft // 2)
input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
input = input.view(input.shape[-signal_dim:])
- return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]
- normalized, onesided, return_complex)
+ return _VF.stft( # type: ignore[attr-defined]
+ input,
+ n_fft,
+ hop_length,
+ win_length,
+ window,
+ normalized,
+ onesided,
+ return_complex,
+ )
istft = _add_docstr(
@@ -746,7 +790,8 @@
Returns:
Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
`B?` is an optional batch dimension from the input tensor.
-""")
+""",
+)
if TYPE_CHECKING:
@@ -758,9 +803,13 @@
_unique_impl_out = Tuple[Tensor, Tensor, Tensor]
-def _unique_impl(input: Tensor, sorted: bool = True,
- return_inverse: bool = False, return_counts: bool = False,
- dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_impl(
+ input: Tensor,
+ sorted: bool = True,
+ return_inverse: bool = False,
+ return_counts: bool = False,
+ dim: Optional[int] = None,
+) -> _unique_impl_out:
r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> Tuple[Tensor, Tensor, Tensor]
Returns the unique elements of the input tensor.
@@ -896,8 +945,14 @@
"""
if has_torch_function_unary(input):
return handle_torch_function(
- unique, (input,), input, sorted=sorted, return_inverse=return_inverse,
- return_counts=return_counts, dim=dim)
+ unique,
+ (input,),
+ input,
+ sorted=sorted,
+ return_inverse=return_inverse,
+ return_counts=return_counts,
+ dim=dim,
+ )
if dim is not None:
output, inverse_indices, counts = _VF.unique_dim(
@@ -917,9 +972,12 @@
return output, inverse_indices, counts
-def _unique_consecutive_impl(input: Tensor, return_inverse: bool = False,
- return_counts: bool = False,
- dim: Optional[int] = None) -> _unique_impl_out:
+def _unique_consecutive_impl(
+ input: Tensor,
+ return_inverse: bool = False,
+ return_counts: bool = False,
+ dim: Optional[int] = None,
+) -> _unique_impl_out:
r"""Eliminates all but the first element from every consecutive group of equivalent elements.
.. note:: This function is different from :func:`torch.unique` in the sense that this function
@@ -971,14 +1029,22 @@
"""
if has_torch_function_unary(input):
return handle_torch_function(
- unique_consecutive, (input,), input, return_inverse=return_inverse,
- return_counts=return_counts, dim=dim)
+ unique_consecutive,
+ (input,),
+ input,
+ return_inverse=return_inverse,
+ return_counts=return_counts,
+ dim=dim,
+ )
output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined]
- input, return_inverse=return_inverse, return_counts=return_counts, dim=dim)
+ input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
+ )
return output, inverse_indices, counts
-def _return_counts(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_counts(
+ input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
@@ -988,7 +1054,9 @@
return output, counts
-def _return_output(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_output(
+ input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
if has_torch_function_unary(input):
@@ -998,59 +1066,72 @@
return output
-def _return_inverse(input, sorted=True, return_inverse=False, return_counts=False, dim=None):
+def _return_inverse(
+ input, sorted=True, return_inverse=False, return_counts=False, dim=None
+):
# type: (Tensor, bool, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_impl(input, sorted, return_inverse, return_counts, dim)
- output, inverse_indices, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
+ output, inverse_indices, _ = _unique_impl(
+ input, sorted, return_inverse, return_counts, dim
+ )
return output, inverse_indices
_return_inverse_false = boolean_dispatch(
- arg_name='return_counts',
+ arg_name="return_counts",
arg_index=3,
default=False,
if_true=_return_counts,
if_false=_return_output,
module_name=__name__,
- func_name='unique')
+ func_name="unique",
+)
_return_inverse_true = boolean_dispatch(
- arg_name='return_counts',
+ arg_name="return_counts",
arg_index=3,
default=False,
if_true=_unique_impl,
if_false=_return_inverse,
module_name=__name__,
- func_name='unique')
+ func_name="unique",
+)
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
# resolve the output type in TorchScript we need to statically know the value of both parameters
unique = boolean_dispatch(
- arg_name='return_inverse',
+ arg_name="return_inverse",
arg_index=2,
default=False,
if_true=_return_inverse_true,
if_false=_return_inverse_false,
module_name=__name__,
- func_name='unique')
+ func_name="unique",
+)
unique.__doc__ = _unique_impl.__doc__
-def _consecutive_return_counts(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_counts(
+ input, return_inverse=False, return_counts=False, dim=None
+):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
- output, _, counts = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+ output, _, counts = _unique_consecutive_impl(
+ input, return_inverse, return_counts, dim
+ )
return output, counts
-def _consecutive_return_output(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_output(
+ input, return_inverse=False, return_counts=False, dim=None
+):
# type: (Tensor, bool, bool, Optional[int]) -> Tensor
if has_torch_function_unary(input):
@@ -1060,45 +1141,52 @@
return output
-def _consecutive_return_inverse(input, return_inverse=False, return_counts=False, dim=None):
+def _consecutive_return_inverse(
+ input, return_inverse=False, return_counts=False, dim=None
+):
# type: (Tensor, bool, bool, Optional[int]) -> Tuple[Tensor, Tensor]
if has_torch_function_unary(input):
return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
- output, inverse_indices, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
+ output, inverse_indices, _ = _unique_consecutive_impl(
+ input, return_inverse, return_counts, dim
+ )
return output, inverse_indices
_consecutive_return_inverse_false = boolean_dispatch(
- arg_name='return_counts',
+ arg_name="return_counts",
arg_index=1,
default=False,
if_true=_consecutive_return_counts,
if_false=_consecutive_return_output,
module_name=__name__,
- func_name='unique_consecutive')
+ func_name="unique_consecutive",
+)
_consecutive_return_inverse_true = boolean_dispatch(
- arg_name='return_counts',
+ arg_name="return_counts",
arg_index=1,
default=False,
if_true=_unique_consecutive_impl,
if_false=_consecutive_return_inverse,
module_name=__name__,
- func_name='unique_consecutive')
+ func_name="unique_consecutive",
+)
# The return type of unique depends on `return_inverse`, and `return_counts` so in order to
# resolve the output type in TorchScript we need to statically know the value of both parameters
unique_consecutive = boolean_dispatch(
- arg_name='return_inverse',
+ arg_name="return_inverse",
arg_index=2,
default=False,
if_true=_consecutive_return_inverse_true,
if_false=_consecutive_return_inverse_false,
module_name=__name__,
- func_name='unique_consecutive')
+ func_name="unique_consecutive",
+)
unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
if TYPE_CHECKING:
@@ -1106,24 +1194,50 @@
# There's no good way to use this type annotation without breaking JIT
# overloads. So leave untyped for mypy for now.
else:
+
@overload
- def tensordot(a, b, dims: int = 2, out: Optional[torch.Tensor] = None):
+ def tensordot(
+ a,
+ b,
+ dims: int = 2,
+ out: Optional[torch.Tensor] = None,
+ ):
pass
- @overload # noqa: F811
- def tensordot(a, b, dims: Tuple[List[int], List[int]], out: Optional[torch.Tensor] = None): # noqa: F811
+ @overload
+ def tensordot( # noqa: F811
+ a,
+ b,
+ dims: Tuple[List[int], List[int]],
+ out: Optional[torch.Tensor] = None,
+ ):
pass
- @overload # noqa: F811
- def tensordot(a, b, dims: List[List[int]], out: Optional[torch.Tensor] = None): # noqa: F811
+ @overload
+ def tensordot( # noqa: F811
+ a,
+ b,
+ dims: List[List[int]],
+ out: Optional[torch.Tensor] = None,
+ ):
pass
- @overload # noqa: F811
- def tensordot(a, b, dims: torch.Tensor, out: Optional[torch.Tensor] = None): # noqa: F811
+ @overload
+ def tensordot( # noqa: F811
+ a,
+ b,
+ dims: torch.Tensor,
+ out: Optional[torch.Tensor] = None,
+ ):
pass
-def tensordot(a, b, dims=2, out: Optional[torch.Tensor] = None): # noqa: F811
+def tensordot( # noqa: F811
+ a,
+ b,
+ dims=2,
+ out: Optional[torch.Tensor] = None,
+):
r"""Returns a contraction of a and b over multiple dimensions.
:attr:`tensordot` implements a generalized matrix product.
@@ -1178,10 +1292,12 @@
return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
- raise RuntimeError("tensordot expects dims to be int or "
- + "Tuple[List[int], List[int]] or "
- + "List[List[int]] containing two lists, but got "
- + f"dims={dims}")
+ raise RuntimeError(
+ "tensordot expects dims to be int or "
+ + "Tuple[List[int], List[int]] or "
+ + "List[List[int]] containing two lists, but got "
+ + f"dims={dims}"
+ )
dims_a: List[int] = []
dims_b: List[int] = []
@@ -1206,7 +1322,9 @@
if dims < 0:
raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
if dims > min(a.dim(), b.dim()):
- raise RuntimeError(f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}")
+ raise RuntimeError(
+ f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
+ )
dims_a = list(range(-dims, 0))
dims_b = list(range(dims))
@@ -1287,7 +1405,7 @@
return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined]
-def cdist(x1, x2, p=2., compute_mode='use_mm_for_euclid_dist_if_necessary'):
+def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
# type: (Tensor, Tensor, float, str) -> (Tensor)
r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
@@ -1331,12 +1449,13 @@
"""
if has_torch_function_variadic(x1, x2):
return handle_torch_function(
- cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode)
- if compute_mode == 'use_mm_for_euclid_dist_if_necessary':
+ cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
+ )
+ if compute_mode == "use_mm_for_euclid_dist_if_necessary":
return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
- elif compute_mode == 'use_mm_for_euclid_dist':
+ elif compute_mode == "use_mm_for_euclid_dist":
return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined]
- elif compute_mode == 'donot_use_mm_for_euclid_dist':
+ elif compute_mode == "donot_use_mm_for_euclid_dist":
return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined]
else:
raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
@@ -1478,27 +1597,62 @@
# TODO: type dim as BroadcastingList when
# https://github.com/pytorch/pytorch/issues/33782 is fixed
@overload
- def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
+ def norm(
+ input,
+ p="fro",
+ dim=None,
+ keepdim=False,
+ out=None,
+ dtype=None,
+ ):
# type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
- @overload # noqa: F811
- def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
+ @overload
+ def norm( # noqa: F811
+ input,
+ p="fro",
+ dim=None,
+ keepdim=False,
+ out=None,
+ dtype=None,
+ ):
# type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
- @overload # noqa: F811
- def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
+ @overload
+ def norm( # noqa: F811
+ input,
+ p="fro",
+ dim=None,
+ keepdim=False,
+ out=None,
+ dtype=None,
+ ):
# type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
- @overload # noqa: F811
- def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
+ @overload
+ def norm( # noqa: F811
+ input,
+ p="fro",
+ dim=None,
+ keepdim=False,
+ out=None,
+ dtype=None,
+ ):
# type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
pass
-def norm(input, p: Optional[Union[float, str]] = "fro", dim=None, keepdim=False, out=None, dtype=None): # noqa: F811
+def norm( # noqa: F811
+ input,
+ p: Optional[Union[float, str]] = "fro",
+ dim=None,
+ keepdim=False,
+ out=None,
+ dtype=None,
+):
r"""Returns the matrix norm or vector norm of a given tensor.
.. warning::
@@ -1594,14 +1748,19 @@
if has_torch_function_unary(input):
return handle_torch_function(
- norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
+ norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
+ )
# NB. All the repeated code and weird python is to please TorchScript.
# For a more compact implementation see the relevant function in `_refs/__init__.py`
# We don't do this for MPS or sparse tensors
- if input.layout == torch.strided and input.device.type in \
- ("cpu", "cuda", "meta", torch.utils.backend_registration._privateuse1_backend_name):
+ if input.layout == torch.strided and input.device.type in (
+ "cpu",
+ "cuda",
+ "meta",
+ torch.utils.backend_registration._privateuse1_backend_name,
+ ):
if dim is not None:
if isinstance(dim, (int, torch.SymInt)):
_dim = [dim]
@@ -1611,11 +1770,17 @@
_dim = None # type: ignore[assignment]
if isinstance(p, str):
- if p == "fro" and (dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2):
+ if p == "fro" and (
+ dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
+ ):
if out is None:
- return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype)
+ return torch.linalg.vector_norm(
+ input, 2, _dim, keepdim, dtype=dtype
+ )
else:
- return torch.linalg.vector_norm(input, 2, _dim, keepdim, dtype=dtype, out=out)
+ return torch.linalg.vector_norm(
+ input, 2, _dim, keepdim, dtype=dtype, out=out
+ )
# Here we either call the nuclear norm, or we call matrix_norm with some arguments
# that will throw an error
@@ -1624,14 +1789,18 @@
if out is None:
return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
else:
- return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype, out=out)
+ return torch.linalg.matrix_norm(
+ input, p, _dim, keepdim, dtype=dtype, out=out
+ )
else:
# NB. p should be Union[str, number], not Optional!
_p = 2.0 if p is None else p
if out is None:
return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
else:
- return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype, out=out)
+ return torch.linalg.vector_norm(
+ input, _p, _dim, keepdim, dtype=dtype, out=out
+ )
ndim = input.dim()
@@ -1641,7 +1810,7 @@
if p == "fro":
return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
if not isinstance(p, str):
- _dim = [i for i in range(ndim)] # noqa: C416 TODO: rewrite as list(range(m))
+ _dim = list(range(ndim))
return _VF.norm(input, p, dim=_dim, keepdim=keepdim) # type: ignore[attr-defined]
# TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
@@ -1695,7 +1864,10 @@
else:
return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore[attr-defined]
-def unravel_index(indices: Tensor, shape: Union[int, Sequence[int], torch.Size]) -> Tuple[Tensor, ...]:
+
+def unravel_index(
+ indices: Tensor, shape: Union[int, Sequence[int], torch.Size]
+) -> Tuple[Tensor, ...]:
r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
index into an arbitrary tensor of the specified shape.
@@ -1745,19 +1917,23 @@
tensor([[34], [78]]))
"""
if has_torch_function_unary(indices):
- return handle_torch_function(
- unravel_index, (indices,), indices, shape=shape)
+ return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
res_tensor = _unravel_index(indices, shape)
return res_tensor.unbind(-1)
+
def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
torch._check_type(
- not indices.is_complex() and not indices.is_floating_point() and not indices.dtype == torch.bool,
- lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}")
+ not indices.is_complex()
+ and not indices.is_floating_point()
+ and not indices.dtype == torch.bool,
+ lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
+ )
torch._check_type(
isinstance(shape, (int, torch.SymInt, Sequence)),
- lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}")
+ lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
+ )
if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape])
@@ -1765,18 +1941,29 @@
for dim in shape:
torch._check_type(
isinstance(dim, (int, torch.SymInt)),
- lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}")
+ lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
+ )
shape = torch.Size(shape)
torch._check_value(
all(dim >= 0 for dim in shape),
- lambda: f"'shape' cannot have negative values, but got {tuple(shape)}")
+ lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
+ )
- coefs = list(reversed(list(itertools.accumulate(reversed(shape[1:] + torch.Size([1])), func=operator.mul))))
+ coefs = list(
+ reversed(
+ list(
+ itertools.accumulate(
+ reversed(shape[1:] + torch.Size([1])), func=operator.mul
+ )
+ )
+ )
+ )
return indices.unsqueeze(-1).floor_divide(
torch.tensor(coefs, device=indices.device, dtype=torch.int64)
) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
+
def chain_matmul(*matrices, out=None):
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
@@ -1923,6 +2110,7 @@
# If get_infos is True, then we don't need to check for errors and vice versa
return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
+
if TYPE_CHECKING:
_ListOrSeq = Sequence[Tensor]
else:
@@ -1932,16 +2120,21 @@
def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
get_infos_int = 1 if get_infos else 0
if out_len - get_infos_int != 2:
- raise TypeError(f"expected tuple of {2 + int(get_infos)} elements but got {out_len}")
+ raise TypeError(
+ f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
+ )
if not isinstance(out, (tuple, list)):
- raise TypeError(f"argument 'out' must be tuple of Tensors, not {type(out).__name__}")
+ raise TypeError(
+ f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
+ )
def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
# type: (Tensor, bool, bool, Optional[Tuple[Tensor, Tensor, Tensor]]) -> Tuple[Tensor, Tensor, Tensor]
if has_torch_function_unary(A):
return handle_torch_function(
- lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+ lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+ )
result = _lu_impl(A, pivot, get_infos, out)
if out is not None:
_check_list_size(len(out), get_infos, out)
@@ -1957,7 +2150,8 @@
# need to check for torch_function here so that we exit if
if has_torch_function_unary(A):
return handle_torch_function(
- lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out)
+ lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
+ )
result = _lu_impl(A, pivot, get_infos, out)
if out is not None:
_check_list_size(len(out), get_infos, out)
@@ -1967,18 +2161,20 @@
else:
return result[0], result[1] # A_LU, pivots
+
# The return type of lu depends on `get_infos`, so in order to resolve the output type
# of lu in TorchScript we need to statically know the value of `get_infos`
lu = boolean_dispatch(
- arg_name='get_infos',
+ arg_name="get_infos",
arg_index=2,
default=False,
if_true=_lu_with_infos,
if_false=_lu_no_infos,
module_name=__name__,
- func_name='lu')
+ func_name="lu",
+)
lu.__doc__ = _lu_impl.__doc__
def align_tensors(*tensors):
- raise RuntimeError('`align_tensors` not yet implemented.')
+ raise RuntimeError("`align_tensors` not yet implemented.")
diff --git a/torch/hub.py b/torch/hub.py
index 213a129..57a07e2 100644
--- a/torch/hub.py
+++ b/torch/hub.py
@@ -8,22 +8,22 @@
import shutil
import sys
import tempfile
-import torch
import uuid
import warnings
import zipfile
from pathlib import Path
-from typing import Dict, Optional, Any
+from typing import Any, Dict, Optional
from typing_extensions import deprecated
from urllib.error import HTTPError, URLError
-from urllib.request import urlopen, Request
from urllib.parse import urlparse # noqa: F401
+from urllib.request import Request, urlopen
+
+import torch
from torch.serialization import MAP_LOCATION
-class _Faketqdm: # type: ignore[no-redef]
- def __init__(self, total=None, disable=False,
- unit=None, *args, **kwargs):
+class _Faketqdm: # type: ignore[no-redef]
+ def __init__(self, total=None, disable=False, unit=None, *args, **kwargs):
self.total = total
self.disable = disable
self.n = 0
@@ -57,7 +57,8 @@
if self.disable:
return
- sys.stderr.write('\n')
+ sys.stderr.write("\n")
+
try:
from tqdm import tqdm # If tqdm is installed use it, otherwise use the fake wrapper
@@ -65,25 +66,30 @@
tqdm = _Faketqdm
__all__ = [
- 'download_url_to_file',
- 'get_dir',
- 'help',
- 'list',
- 'load',
- 'load_state_dict_from_url',
- 'set_dir',
+ "download_url_to_file",
+ "get_dir",
+ "help",
+ "list",
+ "load",
+ "load_state_dict_from_url",
+ "set_dir",
]
# matches bfd8deac from resnet18-bfd8deac.pth
-HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
+HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")
-_TRUSTED_REPO_OWNERS = ("facebookresearch", "facebookincubator", "pytorch", "fairinternal")
-ENV_GITHUB_TOKEN = 'GITHUB_TOKEN'
-ENV_TORCH_HOME = 'TORCH_HOME'
-ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
-DEFAULT_CACHE_DIR = '~/.cache'
-VAR_DEPENDENCY = 'dependencies'
-MODULE_HUBCONF = 'hubconf.py'
+_TRUSTED_REPO_OWNERS = (
+ "facebookresearch",
+ "facebookincubator",
+ "pytorch",
+ "fairinternal",
+)
+ENV_GITHUB_TOKEN = "GITHUB_TOKEN"
+ENV_TORCH_HOME = "TORCH_HOME"
+ENV_XDG_CACHE_HOME = "XDG_CACHE_HOME"
+DEFAULT_CACHE_DIR = "~/.cache"
+VAR_DEPENDENCY = "dependencies"
+MODULE_HUBCONF = "hubconf.py"
READ_DATA_CHUNK = 128 * 1024
_hub_dir: Optional[str] = None
@@ -101,6 +107,7 @@
def _import_module(name, path):
import importlib.util
from importlib.abc import Loader
+
spec = importlib.util.spec_from_file_location(name, path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
@@ -131,18 +138,20 @@
def _get_torch_home():
torch_home = os.path.expanduser(
- os.getenv(ENV_TORCH_HOME,
- os.path.join(os.getenv(ENV_XDG_CACHE_HOME,
- DEFAULT_CACHE_DIR), 'torch')))
+ os.getenv(
+ ENV_TORCH_HOME,
+ os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), "torch"),
+ )
+ )
return torch_home
def _parse_repo_info(github):
- if ':' in github:
- repo_info, ref = github.split(':')
+ if ":" in github:
+ repo_info, ref = github.split(":")
else:
repo_info, ref = github, None
- repo_owner, repo_name = repo_info.split('/')
+ repo_owner, repo_name = repo_info.split("/")
if ref is None:
# The ref wasn't specified by the user, so we need to figure out the
@@ -150,16 +159,18 @@
# then it's the default branch, otherwise it's master.
try:
with urlopen(f"https://github.com/{repo_owner}/{repo_name}/tree/main/"):
- ref = 'main'
+ ref = "main"
except HTTPError as e:
if e.code == 404:
- ref = 'master'
+ ref = "master"
else:
raise
except URLError as e:
# No internet connection, need to check for cache as last resort
for possible_ref in ("main", "master"):
- if os.path.exists(f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"):
+ if os.path.exists(
+ f"{get_dir()}/{repo_owner}_{repo_name}_{possible_ref}"
+ ):
ref = possible_ref
break
if ref is None:
@@ -172,35 +183,40 @@
def _read_url(url):
with urlopen(url) as r:
- return r.read().decode(r.headers.get_content_charset('utf-8'))
+ return r.read().decode(r.headers.get_content_charset("utf-8"))
def _validate_not_a_forked_repo(repo_owner, repo_name, ref):
# Use urlopen to avoid depending on local git.
- headers = {'Accept': 'application/vnd.github.v3+json'}
+ headers = {"Accept": "application/vnd.github.v3+json"}
token = os.environ.get(ENV_GITHUB_TOKEN)
if token is not None:
- headers['Authorization'] = f'token {token}'
+ headers["Authorization"] = f"token {token}"
for url_prefix in (
- f'https://api.github.com/repos/{repo_owner}/{repo_name}/branches',
- f'https://api.github.com/repos/{repo_owner}/{repo_name}/tags'):
+ f"https://api.github.com/repos/{repo_owner}/{repo_name}/branches",
+ f"https://api.github.com/repos/{repo_owner}/{repo_name}/tags",
+ ):
page = 0
while True:
page += 1
- url = f'{url_prefix}?per_page=100&page={page}'
+ url = f"{url_prefix}?per_page=100&page={page}"
response = json.loads(_read_url(Request(url, headers=headers)))
# Empty response means no more data to process
if not response:
break
for br in response:
- if br['name'] == ref or br['commit']['sha'].startswith(ref):
+ if br["name"] == ref or br["commit"]["sha"].startswith(ref):
return
- raise ValueError(f'Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. '
- 'If it\'s a commit from a forked repo, please call hub.load() with forked repo directly.')
+ raise ValueError(
+ f"Cannot find {ref} in https://github.com/{repo_owner}/{repo_name}. "
+ "If it's a commit from a forked repo, please call hub.load() with forked repo directly."
+ )
-def _get_cache_or_reload(github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False):
+def _get_cache_or_reload(
+ github, force_reload, trust_repo, calling_fn, verbose=True, skip_validation=False
+):
# Setup hub_dir to save downloaded files
hub_dir = get_dir()
os.makedirs(hub_dir, exist_ok=True)
@@ -210,27 +226,33 @@
# this causes confusion with path on both Linux and Windows.
# Backslash is not allowed in Github branch name so no need to
# to worry about it.
- normalized_br = ref.replace('/', '_')
+ normalized_br = ref.replace("/", "_")
# Github renames folder repo-v1.x.x to repo-1.x.x
# We don't know the repo name before downloading the zip file
# and inspect name from it.
# To check if cached repo exists, we need to normalize folder names.
- owner_name_branch = '_'.join([repo_owner, repo_name, normalized_br])
+ owner_name_branch = "_".join([repo_owner, repo_name, normalized_br])
repo_dir = os.path.join(hub_dir, owner_name_branch)
# Check that the repo is in the trusted list
- _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo=trust_repo, calling_fn=calling_fn)
+ _check_repo_is_trusted(
+ repo_owner,
+ repo_name,
+ owner_name_branch,
+ trust_repo=trust_repo,
+ calling_fn=calling_fn,
+ )
use_cache = (not force_reload) and os.path.exists(repo_dir)
if use_cache:
if verbose:
- sys.stderr.write(f'Using cache found in {repo_dir}\n')
+ sys.stderr.write(f"Using cache found in {repo_dir}\n")
else:
# Validate the tag/branch is from the original repo instead of a forked repo
if not skip_validation:
_validate_not_a_forked_repo(repo_owner, repo_name, ref)
- cached_file = os.path.join(hub_dir, normalized_br + '.zip')
+ cached_file = os.path.join(hub_dir, normalized_br + ".zip")
_remove_if_exists(cached_file)
try:
@@ -250,7 +272,9 @@
"refs/tags/tag_name as the ref. That might require using skip_validation=True."
)
disambiguated_branch_ref = f"refs/heads/{ref}"
- url = _git_archive_link(repo_owner, repo_name, ref=disambiguated_branch_ref)
+ url = _git_archive_link(
+ repo_owner, repo_name, ref=disambiguated_branch_ref
+ )
download_url_to_file(url, cached_file, progress=False)
else:
raise
@@ -269,7 +293,9 @@
return repo_dir
-def _check_repo_is_trusted(repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"):
+def _check_repo_is_trusted(
+ repo_owner, repo_name, owner_name_branch, trust_repo, calling_fn="load"
+):
hub_dir = get_dir()
filepath = os.path.join(hub_dir, "trusted_list")
@@ -282,7 +308,7 @@
# if a repo was already downloaded by torchhub, then it is already trusted (even if it's not in the allowlist)
trusted_repos_legacy = next(os.walk(hub_dir))[1]
- owner_name = '_'.join([repo_owner, repo_name])
+ owner_name = "_".join([repo_owner, repo_name])
is_trusted = (
owner_name in trusted_repos
or owner_name_branch in trusted_repos_legacy
@@ -298,13 +324,15 @@
"trust_repo=False) and a command prompt will appear asking for an explicit confirmation of trust, "
f"or {calling_fn}(..., trust_repo=True), which will assume that the prompt is to be answered with "
f"'yes'. You can also use {calling_fn}(..., trust_repo='check') which will only prompt for "
- f"confirmation if the repo is not already trusted. This will eventually be the default behaviour")
+ f"confirmation if the repo is not already trusted. This will eventually be the default behaviour"
+ )
return
if (trust_repo is False) or (trust_repo == "check" and not is_trusted):
response = input(
f"The repository {owner_name} does not belong to the list of trusted repositories and as such cannot be downloaded. "
- "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?")
+ "Do you trust this repository and wish to add it to the trusted list of repositories (y/N)?"
+ )
if response.lower() in ("y", "yes"):
if is_trusted:
print("The repository is already trusted.")
@@ -321,6 +349,7 @@
def _check_module_exists(name):
import importlib.util
+
return importlib.util.find_spec(name) is not None
@@ -335,7 +364,7 @@
def _load_entry_from_hubconf(m, model):
if not isinstance(model, str):
- raise ValueError('Invalid input: model should be a string of function name')
+ raise ValueError("Invalid input: model should be a string of function name")
# Note that if a missing dependency is imported at top level of hubconf, it will
# throw before this function. It's a chicken and egg situation where we have to
@@ -346,7 +375,7 @@
func = _load_attr_from_module(m, model)
if func is None or not callable(func):
- raise RuntimeError(f'Cannot find callable {model} in hubconf')
+ raise RuntimeError(f"Cannot find callable {model} in hubconf")
return func
@@ -362,12 +391,12 @@
variable is not set.
"""
# Issue warning to move data if old env is set
- if os.getenv('TORCH_HUB'):
- warnings.warn('TORCH_HUB is deprecated, please use env TORCH_HOME instead')
+ if os.getenv("TORCH_HUB"):
+ warnings.warn("TORCH_HUB is deprecated, please use env TORCH_HOME instead")
if _hub_dir is not None:
return _hub_dir
- return os.path.join(_get_torch_home(), 'hub')
+ return os.path.join(_get_torch_home(), "hub")
def set_dir(d):
@@ -381,7 +410,9 @@
_hub_dir = os.path.expanduser(d)
-def list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True):
+def list(
+ github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True
+):
r"""
List all callable entrypoints available in the repo specified by ``github``.
@@ -424,15 +455,25 @@
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
"""
- repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "list", verbose=verbose,
- skip_validation=skip_validation)
+ repo_dir = _get_cache_or_reload(
+ github,
+ force_reload,
+ trust_repo,
+ "list",
+ verbose=verbose,
+ skip_validation=skip_validation,
+ )
with _add_to_sys_path(repo_dir):
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
hub_module = _import_module(MODULE_HUBCONF, hubconf_path)
# We take functions starts with '_' as internal helper functions
- entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]
+ entrypoints = [
+ f
+ for f in dir(hub_module)
+ if callable(getattr(hub_module, f)) and not f.startswith("_")
+ ]
return entrypoints
@@ -474,8 +515,14 @@
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_HUB)
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
"""
- repo_dir = _get_cache_or_reload(github, force_reload, trust_repo, "help", verbose=True,
- skip_validation=skip_validation)
+ repo_dir = _get_cache_or_reload(
+ github,
+ force_reload,
+ trust_repo,
+ "help",
+ verbose=True,
+ skip_validation=skip_validation,
+ )
with _add_to_sys_path(repo_dir):
hubconf_path = os.path.join(repo_dir, MODULE_HUBCONF)
@@ -486,9 +533,17 @@
return entry.__doc__
-def load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True,
- skip_validation=False,
- **kwargs):
+def load(
+ repo_or_dir,
+ model,
+ *args,
+ source="github",
+ trust_repo=None,
+ force_reload=False,
+ verbose=True,
+ skip_validation=False,
+ **kwargs,
+):
r"""
Load a model from a github repo or a local directory.
@@ -559,13 +614,20 @@
"""
source = source.lower()
- if source not in ('github', 'local'):
+ if source not in ("github", "local"):
raise ValueError(
- f'Unknown source: "{source}". Allowed values: "github" | "local".')
+ f'Unknown source: "{source}". Allowed values: "github" | "local".'
+ )
- if source == 'github':
- repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, trust_repo, "load",
- verbose=verbose, skip_validation=skip_validation)
+ if source == "github":
+ repo_or_dir = _get_cache_or_reload(
+ repo_or_dir,
+ force_reload,
+ trust_repo,
+ "load",
+ verbose=verbose,
+ skip_validation=skip_validation,
+ )
model = _load_local(repo_or_dir, model, *args, **kwargs)
return model
@@ -601,8 +663,9 @@
return model
-def download_url_to_file(url: str, dst: str, hash_prefix: Optional[str] = None,
- progress: bool = True) -> None:
+def download_url_to_file(
+ url: str, dst: str, hash_prefix: Optional[str] = None, progress: bool = True
+) -> None:
r"""Download object at the given URL to a local path.
Args:
@@ -623,7 +686,7 @@
req = Request(url, headers={"User-Agent": "torch.hub"})
u = urlopen(req)
meta = u.info()
- if hasattr(meta, 'getheaders'):
+ if hasattr(meta, "getheaders"):
content_length = meta.getheaders("Content-Length")
else:
content_length = meta.get_all("Content-Length")
@@ -637,20 +700,25 @@
# file permissions being applied to the downloaded file.
dst = os.path.expanduser(dst)
for seq in range(tempfile.TMP_MAX):
- tmp_dst = dst + '.' + uuid.uuid4().hex + '.partial'
+ tmp_dst = dst + "." + uuid.uuid4().hex + ".partial"
try:
- f = open(tmp_dst, 'w+b')
+ f = open(tmp_dst, "w+b")
except FileExistsError:
continue
break
else:
- raise FileExistsError(errno.EEXIST, 'No usable temporary file name found')
+ raise FileExistsError(errno.EEXIST, "No usable temporary file name found")
try:
if hash_prefix is not None:
sha256 = hashlib.sha256()
- with tqdm(total=file_size, disable=not progress,
- unit='B', unit_scale=True, unit_divisor=1024) as pbar:
+ with tqdm(
+ total=file_size,
+ disable=not progress,
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as pbar:
while True:
buffer = u.read(READ_DATA_CHUNK)
if len(buffer) == 0:
@@ -663,8 +731,10 @@
f.close()
if hash_prefix is not None:
digest = sha256.hexdigest() # type: ignore[possibly-undefined]
- if digest[:len(hash_prefix)] != hash_prefix:
- raise RuntimeError(f'invalid hash value (expected "{hash_prefix}", got "{digest}")')
+ if digest[: len(hash_prefix)] != hash_prefix:
+ raise RuntimeError(
+ f'invalid hash value (expected "{hash_prefix}", got "{digest}")'
+ )
shutil.move(f.name, dst)
finally:
f.close()
@@ -683,23 +753,30 @@
@deprecated(
- 'Falling back to the old format < 1.6. This support will be '
- 'deprecated in favor of default zipfile format introduced in 1.6. '
- 'Please redo torch.save() to save it in the new zipfile format.',
+ "Falling back to the old format < 1.6. This support will be "
+ "deprecated in favor of default zipfile format introduced in 1.6. "
+ "Please redo torch.save() to save it in the new zipfile format.",
category=FutureWarning,
)
-def _legacy_zip_load(filename: str, model_dir: str, map_location: MAP_LOCATION, weights_only: bool) -> Dict[str, Any]:
+def _legacy_zip_load(
+ filename: str,
+ model_dir: str,
+ map_location: MAP_LOCATION,
+ weights_only: bool,
+) -> Dict[str, Any]:
# Note: extractall() defaults to overwrite file if exists. No need to clean up beforehand.
# We deliberately don't handle tarfile here since our legacy serialization format was in tar.
# E.g. resnet18-5c106cde.pth which is widely used.
with zipfile.ZipFile(filename) as f:
members = f.infolist()
if len(members) != 1:
- raise RuntimeError('Only one file(not dir) is allowed in the zipfile')
+ raise RuntimeError("Only one file(not dir) is allowed in the zipfile")
f.extractall(model_dir)
extraced_name = members[0].filename
extracted_file = os.path.join(model_dir, extraced_name)
- return torch.load(extracted_file, map_location=map_location, weights_only=weights_only)
+ return torch.load(
+ extracted_file, map_location=map_location, weights_only=weights_only
+ )
def load_state_dict_from_url(
@@ -742,12 +819,14 @@
"""
# Issue warning to move data if old env is set
- if os.getenv('TORCH_MODEL_ZOO'):
- warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+ if os.getenv("TORCH_MODEL_ZOO"):
+ warnings.warn(
+ "TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead"
+ )
if model_dir is None:
hub_dir = get_dir()
- model_dir = os.path.join(hub_dir, 'checkpoints')
+ model_dir = os.path.join(hub_dir, "checkpoints")
os.makedirs(model_dir, exist_ok=True)
diff --git a/torch/library.py b/torch/library.py
index d0a4cf2..bda34c2 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -1,28 +1,34 @@
# mypy: allow-untyped-defs
-from ._ops import OpOverload
-from typing import Any, Optional, Set, List, Union, Callable, Tuple, Dict, Sequence
-from typing_extensions import deprecated
-import traceback
-import torch
-import weakref
+import contextlib
import functools
import inspect
import re
-import contextlib
import sys
-from torch._library.custom_ops import custom_op, _maybe_get_opdef, device_types_t, CustomOpDef
+import traceback
+import weakref
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing_extensions import deprecated
+
+import torch
import torch._library as _library
+from torch._library.custom_ops import (
+ _maybe_get_opdef,
+ custom_op,
+ CustomOpDef,
+ device_types_t,
+)
+from torch._ops import OpOverload
__all__ = [
- 'Library',
- 'impl',
- 'define',
- 'fallthrough_kernel',
- 'impl_abstract',
- 'register_fake',
- 'get_ctx',
- 'custom_op',
+ "Library",
+ "impl",
+ "define",
+ "fallthrough_kernel",
+ "impl_abstract",
+ "register_fake",
+ "get_ctx",
+ "custom_op",
]
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
@@ -33,7 +39,8 @@
_defs: Set[str] = set()
# prim is reserved by TorchScript interpreter
-_reserved_namespaces = ['prim']
+_reserved_namespaces = ["prim"]
+
def fallthrough_kernel():
"""
@@ -41,6 +48,7 @@
"""
raise NotImplementedError("fallthrough_kernel() should never be called.")
+
class Library:
"""
A class to create libraries that can be used to register new operators or
@@ -59,16 +67,22 @@
kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
dispatch_key: PyTorch dispatch key (default: "")
"""
+
def __init__(self, ns, kind, dispatch_key=""):
- if kind not in ('IMPL', 'DEF', 'FRAGMENT'):
+ if kind not in ("IMPL", "DEF", "FRAGMENT"):
raise ValueError("Unsupported kind: ", kind)
- if ns in _reserved_namespaces and (kind == "DEF" or kind == 'FRAGMENT'):
- raise ValueError(ns, " is a reserved namespace. Please try creating a library with another name.")
+ if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
+ raise ValueError(
+ ns,
+ " is a reserved namespace. Please try creating a library with another name.",
+ )
frame = traceback.extract_stack(limit=3)[0]
filename, lineno = frame.filename, frame.lineno
- self.m: Optional[Any] = torch._C._dispatch_library(kind, ns, dispatch_key, filename, lineno)
+ self.m: Optional[Any] = torch._C._dispatch_library(
+ kind, ns, dispatch_key, filename, lineno
+ )
self.ns = ns
self._op_defs: Set[str] = set()
self._op_impls: Set[str] = set()
@@ -79,13 +93,21 @@
# Python __del__ can lead to weird things (globals and locals may already
# be gone when __del__ actually gets called!). finalizers help the
# situation because it lets us capture references and keeps them alive
- weakref.finalize(self, _del_library, _impls, self._op_impls, _defs, self._op_defs, self._registration_handles)
+ weakref.finalize(
+ self,
+ _del_library,
+ _impls,
+ self._op_impls,
+ _defs,
+ self._op_defs,
+ self._registration_handles,
+ )
def __repr__(self):
return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})>"
def define(self, schema, alias_analysis="", *, tags=()):
- r'''Defines a new operator and its semantics in the ns namespace.
+ r"""Defines a new operator and its semantics in the ns namespace.
Args:
schema: function schema to define a new operator.
@@ -102,7 +124,7 @@
Example::
>>> my_lib = Library("mylib", "DEF")
>>> my_lib.define("sum(Tensor self) -> Tensor")
- '''
+ """
# This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
# AliasAnalysis type in C++
if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
@@ -113,7 +135,9 @@
name = schema.split("(")[0]
packet_name = name.split(".")[0] if "." in name else name
- has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(getattr(torch.ops, self.ns), packet_name)
+ has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
+ getattr(torch.ops, self.ns), packet_name
+ )
result = self.m.define(schema, alias_analysis, tuple(tags))
name = schema.split("(")[0]
@@ -131,7 +155,7 @@
return result
def _register_fake(self, op_name, fn, _stacklevel=1):
- r'''Registers the fake impl for an operator defined in the library.'''
+ r"""Registers the fake impl for an operator defined in the library."""
source = torch._library.utils.get_source(_stacklevel + 1)
frame = sys._getframe(_stacklevel)
caller_module = inspect.getmodule(frame)
@@ -141,7 +165,9 @@
# TODO(rzou): We're gonna need to stage this change with torchvision,
# since torchvision is github first.
- if caller_module_name is not None and caller_module_name.startswith("torchvision."):
+ if caller_module_name is not None and caller_module_name.startswith(
+ "torchvision."
+ ):
caller_module_name = None
qualname = f"{self.ns}::{op_name}"
@@ -154,8 +180,8 @@
handle = entry.abstract_impl.register(func_to_register, source)
self._registration_handles.append(handle)
- def _impl_with_aoti_compile(self, op_name, dispatch_key=''):
- r'''Register the operator to use the AOTI-compiled implementation.
+ def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
+ r"""Register the operator to use the AOTI-compiled implementation.
Args:
op_name: operator name (along with the overload) or OpOverload object.
@@ -165,8 +191,8 @@
Example::
>>> my_lib = Library("aten", "IMPL")
>>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
- '''
- if dispatch_key == '':
+ """
+ if dispatch_key == "":
dispatch_key = self.dispatch_key
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
@@ -175,19 +201,24 @@
elif isinstance(op_name, OpOverload):
name = op_name._schema.name
overload_name = op_name._schema.overload_name
- if overload_name != '':
- name = name + '.' + overload_name
+ if overload_name != "":
+ name = name + "." + overload_name
else:
- raise RuntimeError("_impl_with_aoti_compile should be passed either a name or an OpOverload object "
- "as the first argument")
+ raise RuntimeError(
+ "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
+ "as the first argument"
+ )
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
if key in _impls:
# TODO: in future, add more info about where the existing function is registered (this info is
# today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
- raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
- "'s behavior for {} dispatch key and {} namespace.".
- format(name.split("::")[-1], dispatch_key, self.ns))
+ raise RuntimeError(
+ "This is not allowed since there's already a kernel registered from python overriding {}"
+ "'s behavior for {} dispatch key and {} namespace.".format(
+ name.split("::")[-1], dispatch_key, self.ns
+ )
+ )
assert self.m is not None
impl_fn: Callable = self.m.impl_with_aoti_compile
@@ -196,8 +227,8 @@
_impls.add(key)
self._op_impls.add(key)
- def impl(self, op_name, fn, dispatch_key='', *, with_keyset=False):
- r'''Registers the function implementation for an operator defined in the library.
+ def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
+ r"""Registers the function implementation for an operator defined in the library.
Args:
op_name: operator name (along with the overload) or OpOverload object.
@@ -211,10 +242,12 @@
>>> def div_cpu(self, other):
>>> return self * (1 / other)
>>> my_lib.impl("div.Tensor", div_cpu, "CPU")
- '''
+ """
if not callable(fn):
- raise TypeError(f"Input function is required to be a callable but found type {type(fn)}")
- if dispatch_key == '':
+ raise TypeError(
+ f"Input function is required to be a callable but found type {type(fn)}"
+ )
+ if dispatch_key == "":
dispatch_key = self.dispatch_key
if isinstance(op_name, str):
@@ -222,37 +255,50 @@
elif isinstance(op_name, OpOverload):
name = op_name._schema.name
overload_name = op_name._schema.overload_name
- if overload_name != '':
- name = name + '.' + overload_name
+ if overload_name != "":
+ name = name + "." + overload_name
else:
- raise RuntimeError("impl should be passed either a name or an OpOverload object as the first argument")
+ raise RuntimeError(
+ "impl should be passed either a name or an OpOverload object as the first argument"
+ )
key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
if key in _impls:
# TODO: in future, add more info about where the existing function is registered (this info is
# today already returned by the C++ warning when impl is called but we error out before that)
- raise RuntimeError("This is not allowed since there's already a kernel registered from python overriding {}"
- "'s behavior for {} dispatch key and {} namespace.".
- format(name.split("::")[-1], dispatch_key, self.ns))
+ raise RuntimeError(
+ "This is not allowed since there's already a kernel registered from python overriding {}"
+ "'s behavior for {} dispatch key and {} namespace.".format(
+ name.split("::")[-1], dispatch_key, self.ns
+ )
+ )
if dispatch_key == "Meta":
dispatcher_op_name = name
- if '::' not in dispatcher_op_name:
- dispatcher_op_name = f'{self.ns}::{dispatcher_op_name}'
+ if "::" not in dispatcher_op_name:
+ dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"
# Internally, we shouldn't be registering meta kernels for any operators that
# have CompositeImplicitAutograd kernels.
# Instead, we should be letting those decompositions run, and writing meta kernels
# only for the base operators.
- if torch._C._dispatch_has_kernel_for_dispatch_key(dispatcher_op_name, "CompositeImplicitAutograd"):
+ if torch._C._dispatch_has_kernel_for_dispatch_key(
+ dispatcher_op_name, "CompositeImplicitAutograd"
+ ):
raise RuntimeError(
f"We should not register a meta kernel directly to the operator '{name}',"
" because it has a CompositeImplicitAutograd kernel in core."
" Instead we should let the operator decompose, and ensure that we have meta kernels"
- " for the base ops that it decomposes into.")
+ " for the base ops that it decomposes into."
+ )
assert self.m is not None
- self.m.impl(name, dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd", fn, with_keyset)
+ self.m.impl(
+ name,
+ dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
+ fn,
+ with_keyset,
+ )
_impls.add(key)
self._op_impls.add(key)
@@ -283,7 +329,9 @@
delattr(namespace, name)
-def _del_library(captured_impls, op_impls, captured_defs, op_defs, registration_handles):
+def _del_library(
+ captured_impls, op_impls, captured_defs, op_defs, registration_handles
+):
captured_impls -= op_impls
captured_defs -= op_defs
for handle in registration_handles:
@@ -357,7 +405,8 @@
if not isinstance(qualname, str):
raise ValueError(
f"define(qualname, schema): expected qualname "
- f"to be instance of str, got {type(qualname)}")
+ f"to be instance of str, got {type(qualname)}"
+ )
namespace, name = torch._library.utils.parse_namespace(qualname)
if lib is None:
lib = Library(namespace, "FRAGMENT")
@@ -366,7 +415,8 @@
raise ValueError(
f"define(qualname, schema, ...): expected schema "
f'to look like e.g. "(Tensor x) -> Tensor" but '
- f'got "{schema}"')
+ f'got "{schema}"'
+ )
lib.define(name + schema, alias_analysis="", tags=tags)
@@ -375,10 +425,12 @@
"""The old torch.library.define.
We're keeping this around for BC reasons
"""
+
def wrap(f):
name = lib.define(schema, alias_analysis)
lib.impl(name, f)
return f
+
return wrap
@@ -460,9 +512,11 @@
@impl.register
def _(lib: Library, name, dispatch_key=""):
"""Legacy torch.library.impl API. Kept around for BC"""
+
def wrap(f):
lib.impl(name, f, dispatch_key)
return f
+
return wrap
@@ -480,16 +534,19 @@
return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)
-_op_identifier = Union[str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"]
+_op_identifier = Union[
+ str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
+]
def register_kernel(
- op: _op_identifier,
- device_types: device_types_t,
- func: Optional[Callable] = None,
- /,
- *,
- lib: Optional[Library] = None):
+ op: _op_identifier,
+ device_types: device_types_t,
+ func: Optional[Callable] = None,
+ /,
+ *,
+ lib: Optional[Library] = None,
+):
"""Register an implementation for a device type for this operator.
Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
@@ -530,7 +587,9 @@
"""
- if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
+ if not isinstance(
+ op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+ ):
raise ValueError("register_kernel(op): got unexpected type for op: {type(op)}")
if isinstance(op, torch._ops.OpOverload):
op = op._name
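# A minimal usage sketch of the decorator form supported by the signature
# reformatted above. The operator "mylib::clamp01" is hypothetical and exists
# only to illustrate the call; it is not part of this patch.
import torch
from torch import Tensor

torch.library.define("mylib::clamp01", "(Tensor x) -> Tensor")

@torch.library.register_kernel("mylib::clamp01", "cpu")
def _(x: Tensor) -> Tensor:
    # CPU-specific kernel; other device types would register their own.
    return x.clamp(0.0, 1.0)

assert torch.ops.mylib.clamp01(torch.tensor([-1.0, 2.0])).max() <= 1.0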
@@ -544,12 +603,13 @@
def register_fake(
- op: _op_identifier,
- func: Optional[Callable] = None,
- /,
- *,
- lib: Optional[Library] = None,
- _stacklevel: int = 1):
+ op: _op_identifier,
+ func: Optional[Callable] = None,
+ /,
+ *,
+ lib: Optional[Library] = None,
+ _stacklevel: int = 1,
+):
r"""Register a FakeTensor implementation ("fake impl") for this operator.
Also sometimes known as a "meta kernel", "abstract impl".
@@ -630,7 +690,9 @@
>>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
"""
- if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
+ if not isinstance(
+ op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+ ):
raise ValueError("register_fake(op): got unexpected type for op: {type(op)}")
if isinstance(op, torch._ops.OpOverload):
op = op._name
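# A minimal sketch of the decorator form of register_fake, mirroring the
# custom_nonzero example from the docstring above; the operator and its shape
# logic are hypothetical.
import torch

torch.library.define("mylib::custom_nonzero", "(Tensor x) -> Tensor")

@torch.library.register_fake("mylib::custom_nonzero")
def _(x):
    # A fake impl only computes output metadata: the data-dependent number of
    # nonzero elements becomes an unbacked symbolic size.
    ctx = torch.library.get_ctx()
    nnz = ctx.new_dynamic_size()
    return x.new_empty(nnz, x.dim(), dtype=torch.long)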
@@ -661,7 +723,14 @@
return register(func)
-def register_autograd(op: _op_identifier, backward: Callable, /, *, setup_context: Optional[Callable] = None, lib=None) -> None:
+def register_autograd(
+ op: _op_identifier,
+ backward: Callable,
+ /,
+ *,
+ setup_context: Optional[Callable] = None,
+ lib=None,
+) -> None:
r"""Register a backward formula for this custom op.
In order for an operator to work with autograd, you need to register
@@ -737,8 +806,12 @@
>>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
"""
- if not isinstance(op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)):
- raise ValueError(f"register_autograd(op): got unexpected type for op: {type(op)}")
+ if not isinstance(
+ op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
+ ):
+ raise ValueError(
+ f"register_autograd(op): got unexpected type for op: {type(op)}"
+ )
if isinstance(op, torch._ops.OpOverload):
op = op._name
opdef = _maybe_get_opdef(op)
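# A minimal, self-contained sketch of the register_autograd flow described in
# the docstring above. The operator "mylib::my_sin" and both helpers are
# hypothetical and only show how setup_context and backward plug in.
import torch
from torch import Tensor

torch.library.define("mylib::my_sin", "(Tensor x) -> Tensor")

@torch.library.register_kernel("mylib::my_sin", "cpu")
def _(x: Tensor) -> Tensor:
    return torch.sin(x)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward(ctx, grad):
    (x,) = ctx.saved_tensors
    return grad * x.cos()

torch.library.register_autograd(
    "mylib::my_sin", backward, setup_context=setup_context
)

x = torch.randn(3, requires_grad=True)
torch.ops.mylib.my_sin(x).sum().backward()
assert torch.allclose(x.grad, x.cos())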
@@ -760,7 +833,8 @@
raise NotImplementedError(
f"register_autograd with kwarg-only Tensor args. In the original "
f"definition of the op, please make your tensors not kwarg-only. "
- f"Got: {schema}")
+ f"Got: {schema}"
+ )
info = _library.autograd.Info(backward, setup_context)
autograd_kernel = _library.autograd.make_autograd_impl(op, info)
@@ -788,8 +862,8 @@
return func(*args, **kwargs)
maybe_pystub = torch._C._dispatch_pystub(
- op._schema.name,
- op._schema.overload_name)
+ op._schema.name, op._schema.overload_name
+ )
if maybe_pystub is None:
if torch._library.utils.requires_set_python_module():
namespace = op.namespace
@@ -800,7 +874,8 @@
f'companion C++ `m.set_python_module("{actual_module_name}")` '
f"call, but we could not find one. Please add that to "
f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
- f"operator was registered in ({cpp_filename})")
+ f"operator was registered in ({cpp_filename})"
+ )
else:
pystub_module = maybe_pystub[0]
if actual_module_name != pystub_module:
@@ -809,9 +884,11 @@
f"Operator '{qualname}' specified that its python fake impl "
f"is in the Python module '{pystub_module}' but it was actually found "
f"in '{actual_module_name}'. Please either move the fake impl "
- f"or correct the m.set_python_module call ({cpp_filename})")
+ f"or correct the m.set_python_module call ({cpp_filename})"
+ )
checked = True
return func(*args, **kwargs)
+
return inner
@@ -929,4 +1006,7 @@
"""
import torch.testing._internal.optests as optests
- return optests.opcheck(op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception)
+
+ return optests.opcheck(
+ op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+ )
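# A short sketch of calling the opcheck wrapper reformatted above on a custom
# operator. "mylib::my_mul" is hypothetical; test_utils is restricted to the
# schema and FakeTensor checks since no autograd formula is registered here.
import torch
from torch import Tensor

torch.library.define("mylib::my_mul", "(Tensor x, Tensor y) -> Tensor")

@torch.library.register_kernel("mylib::my_mul", "cpu")
def _(x: Tensor, y: Tensor) -> Tensor:
    return x * y

@torch.library.register_fake("mylib::my_mul")
def _(x, y):
    return torch.empty_like(x)

sample_args = (torch.randn(3), torch.randn(3))
torch.library.opcheck(
    torch.ops.mylib.my_mul.default,
    sample_args,
    test_utils=("test_schema", "test_faketensor"),
)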
diff --git a/torch/overrides.py b/torch/overrides.py
index 5095689..651912b 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -23,19 +23,26 @@
import __future__ # noqa: F404
import collections
+import contextlib
import functools
import types
import warnings
-from typing import Dict, Set, List, Any, Callable, Iterable, Type, Tuple
from functools import wraps
-import contextlib
+from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Type
import torch
from torch._C import (
- _has_torch_function, _has_torch_function_unary,
- _has_torch_function_variadic, _add_docstr,
- _push_on_torch_function_stack, _pop_torch_function_stack, _get_function_stack_at, _len_torch_function_stack,
- _is_torch_function_mode_enabled)
+ _add_docstr,
+ _get_function_stack_at,
+ _has_torch_function,
+ _has_torch_function_unary,
+ _has_torch_function_variadic,
+ _is_torch_function_mode_enabled,
+ _len_torch_function_stack,
+ _pop_torch_function_stack,
+ _push_on_torch_function_stack,
+)
+
__all__ = [
"get_ignored_functions",
@@ -52,7 +59,8 @@
def _disable_user_warnings(
- func: Callable, regex: str = '.*is deprecated, please use.*', module: str = 'torch') -> Callable:
+ func: Callable, regex: str = ".*is deprecated, please use.*", module: str = "torch"
+) -> Callable:
"""
Decorator that temporarily disables ``UserWarning``s for the given ``module`` if the warning message matches the
given ``regex`` pattern.
@@ -75,8 +83,11 @@
@wraps(func)
def wrapper(*args, **kwargs):
with warnings.catch_warnings():
- warnings.filterwarnings("ignore", category=UserWarning, message=regex, module=module)
+ warnings.filterwarnings(
+ "ignore", category=UserWarning, message=regex, module=module
+ )
return func(*args, **kwargs)
+
return wrapper
@@ -470,8 +481,9 @@
torch.batch_norm_update_stats: lambda input, running_mean, running_var, momentum: -1,
torch.bernoulli: lambda input, generator=None, out=None: -1,
torch.bilinear: lambda input1, input2, weight, bias: -1,
- torch.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None, reduce=None,
- reduction='mean', pos_weight=None: -1),
+ torch.binary_cross_entropy_with_logits: (
+ lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
+ ),
torch.bincount: lambda input, weights=None, minlength=0: -1,
torch.binomial: lambda count, prob, generator=None: -1,
torch.bitwise_and: lambda input, other, out=None: -1,
@@ -489,11 +501,11 @@
torch.cat: lambda tensors, dim=0, out=None: -1,
torch.concat: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
torch.concatenate: lambda tensors, dim=0, out=None: -1, # alias for torch.cat
- torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
+ torch.cdist: lambda x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary": -1,
torch.ceil: lambda input, out=None: -1,
- torch.celu: lambda input, alpha=1., inplace=False: -1,
+ torch.celu: lambda input, alpha=1.0, inplace=False: -1,
torch.chain_matmul: lambda *matrices, out=None: -1,
- torch.channel_shuffle: lambda input, groups : -1,
+ torch.channel_shuffle: lambda input, groups: -1,
torch.cholesky: lambda input, upper=False, out=None: -1,
torch.linalg.cholesky: lambda input, out=None: -1,
torch.linalg.cholesky_ex: lambda input, check_errors=False, out=None: -1,
@@ -528,14 +540,15 @@
torch.conv_transpose3d: lambda input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1: -1,
torch.corrcoef: lambda input: -1,
torch.cos: lambda input, out=None: -1,
- torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1,
+ torch.cosine_embedding_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1,
torch.cosh: lambda input, out=None: -1,
torch.cosine_similarity: lambda x1, x2, dim=1, eps=1e-8: -1,
torch.count_nonzero: lambda input: -1,
torch.cross: lambda input, other, dim=None, out=None: -1,
torch.linalg.cross: lambda input, other, dim=-1, out=None: -1,
- torch.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean',
- zero_infinity=False: -1),
+ torch.ctc_loss: (
+ lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
+ ),
torch.cummax: lambda input, dim, out=None: -1,
torch.cummin: lambda input, dim, out=None: -1,
torch.cumprod: lambda input, dim, out=None, dtype=None: -1,
@@ -570,10 +583,12 @@
torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
torch.linalg.eigvalsh: lambda input, UPLO="L", out=None: -1,
torch.einsum: lambda equation, *operands: -1,
- torch.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False,
- sparse=False: -1),
- torch.embedding_bag: (lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False,
- mode='mean', sparse=False, per_sample_weights=None, padding_idx=None: -1),
+ torch.embedding: (
+ lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
+ ),
+ torch.embedding_bag: (
+ lambda input, weight, offsets, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, padding_idx=None: -1 # noqa: B950
+ ),
torch.empty_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.eq: lambda input, other, out=None: -1,
torch.equal: lambda input, other: -1,
@@ -585,14 +600,15 @@
torch.expm1: lambda input, out=None: -1,
torch.fake_quantize_per_channel_affine: lambda input, scale, zero_point, axis, quant_min, quant_max: -1,
torch.fake_quantize_per_tensor_affine: lambda input, scale, zero_point, quant_min, quant_max: -1,
- torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
- running_max, scale, zero_point, quant_min, quant_max, ch_axis,
- per_row_fake_quant=False, symmetric_quant=False: -1),
+ torch.fused_moving_avg_obs_fake_quant: (
+ lambda x, observer_on, fake_quant_on, averaging_const, running_min, running_max, scale, zero_point, quant_min, quant_max, ch_axis, per_row_fake_quant=False, symmetric_quant=False: -1 # noqa: B950
+ ),
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
- torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
- torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
- weight_zero_point, bias: -1),
+ torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1, # noqa: B950
+ torch.fbgemm_linear_int8_weight_fp32_activation: (
+ lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1
+ ),
torch.fbgemm_linear_quantize_weight: lambda input: -1,
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
@@ -630,7 +646,7 @@
torch.fmod: lambda input, other, out=None: -1,
torch.frac: lambda input, out=None: -1,
torch.frexp: lambda input, out=None: -1,
- torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
+ torch.full_like: lambda input, fill_value, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1, # noqa: B950
torch._functional_assert_async: lambda input, msg, dep_token: -1,
torch.lu_unpack: lambda LU_data, LU_pivots, unpack_data=True, unpack_pivots=True: -1,
torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1,
@@ -653,7 +669,7 @@
torch.greater: lambda input, other, out=None: -1,
torch.hardshrink: lambda input, lambd=0.5: -1,
torch.heaviside: lambda input, values, out=None: -1,
- torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction='mean': -1,
+ torch.hinge_embedding_loss: lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
torch.histc: lambda input, bins=100, min=0, max=0, out=None: -1,
torch.histogram: lambda input, bins=100, min=None, max=None, weight=None, density=False, out=None: -1,
torch.histogramdd: lambda input, bins, range=None, weight=None, density=False: -1,
@@ -677,8 +693,9 @@
torch.isreal: lambda tensor: -1,
torch.isposinf: lambda input, out=None: -1,
torch.isneginf: lambda input, out=None: -1,
- torch.instance_norm: (lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps,
- cudnn_enabled: -1),
+ torch.instance_norm: (
+ lambda input, running_mean, running_var, weight, bias, use_input_stats, momentum, eps, cudnn_enabled: -1
+ ),
torch.int_repr: lambda input: -1,
torch.inverse: lambda input, out=None: -1,
torch.linalg.inv: lambda input, out=None: -1,
@@ -694,9 +711,10 @@
torch.is_signed: lambda input: -1,
torch.isclose: lambda input, other, rtol=1e-05, atol=1e-08, equal_nan=False: -1,
torch.isnan: lambda input: -1,
- torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
- normalized=False, onesided=None, length=None, return_complex=False: -1),
- torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
+ torch.istft: (
+ lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=None, length=None, return_complex=False: -1 # noqa: B950
+ ),
+ torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1,
torch.kron: lambda input, other: -1,
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
torch.linalg.ldl_factor_ex: lambda input, hermitian=False, check_errors=False, out=None: -1,
@@ -709,8 +727,7 @@
torch.less_equal: lambda input, other, out=None: -1,
torch.lerp: lambda input, end, weight, out=None: -1,
torch.lgamma: lambda input, out=None: -1,
- torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None,
- tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1,
+ torch.lobpcg: lambda input, k=None, B=None, X=None, n=None, iK=None, niter=None, tol=None, largest=None, method=None, tracker=None, ortho_iparams=None, ortho_fparams=None, ortho_bparams=None: -1, # noqa: B950
torch.log: lambda input, out=None: -1,
torch.log_softmax: lambda input, dim, dtype=None: -1,
torch.log10: lambda input, out=None: -1,
@@ -732,7 +749,7 @@
torch.less: lambda input, other, out=None: -1,
torch.lu: lambda A, pivot=True, get_infos=False, out=None: -1,
torch.lu_solve: lambda b, LU_data, LU_pivots, out=None: -1,
- torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction='mean': -1, # type: ignore[attr-defined] # noqa: B950
+ torch.margin_ranking_loss: lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1, # type: ignore[attr-defined] # noqa: B950
torch.masked_fill: lambda input, mask, value: -1,
torch.masked_scatter: lambda input, mask, source: -1,
torch.masked_select: lambda input, mask, out=None: -1,
@@ -754,8 +771,9 @@
torch.max_pool1d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool2d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
torch.max_pool3d: lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False: -1,
- torch.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False: -1),
+ torch.max_pool1d_with_indices: (
+ lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+ ),
torch.mean: lambda input, dim=None: -1,
torch.nanmean: lambda input, dim=None, keepdim=False, dtype=None, out=None: -1,
torch.median: lambda input, dim=None: -1,
@@ -764,17 +782,21 @@
torch.min: lambda input, out=None: -1,
torch.minimum: lambda input, other, out=None: -1,
torch.fmin: lambda input, other, out=None: -1,
- torch.miopen_batch_norm: (lambda input, weight, bias, running_mean, running_var, training,
- exponential_average_factor, epsilon: -1),
- torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1,
+ torch.miopen_batch_norm: (
+ lambda input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon: -1
+ ),
+ torch.miopen_convolution: lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1, # noqa: B950
torch.miopen_convolution_add_relu: lambda input, weight, z, alpha, bias, stride, padding, dilation, groups: -1,
torch.miopen_convolution_relu: lambda input, weight, bias, stride, padding, dilation, groups: -1,
- torch.miopen_convolution_transpose: (lambda input, weight, bias, padding, output_padding, stride, dilation,
- groups, benchmark, deterministic: -1),
- torch.miopen_depthwise_convolution: (lambda input, weight, bias, padding, stride, dilation, groups, benchmark,
- deterministic: -1),
- torch.miopen_rnn: (lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first,
- dropout, train, bidirectional, batch_sizes, dropout_state: -1),
+ torch.miopen_convolution_transpose: (
+ lambda input, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic: -1
+ ),
+ torch.miopen_depthwise_convolution: (
+ lambda input, weight, bias, padding, stride, dilation, groups, benchmark, deterministic: -1
+ ),
+ torch.miopen_rnn: (
+ lambda input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B950
+ ),
torch.mm: lambda input, mat2, out=None: -1,
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
torch.movedim: lambda input, source, destination: -1,
@@ -793,7 +815,7 @@
torch.native_layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.native_group_norm: lambda input, weight, bias, N, C, HxW, group, eps: -1,
torch.native_norm: lambda input, p=2, dim=None, keepdim=False, dtype=None: -1,
- torch.native_channel_shuffle: lambda input, groups : -1,
+ torch.native_channel_shuffle: lambda input, groups: -1,
torch.ne: lambda input, other, out=None: -1,
torch.not_equal: lambda input, other, out=None: -1,
torch.neg: lambda input, out=None: -1,
@@ -809,62 +831,76 @@
torch.nn.functional.adaptive_max_pool3d_with_indices: lambda input, output_size, return_indices=False: -1,
torch.nn.functional.affine_grid: lambda theta, size, align_corners=None: -1,
torch.nn.functional.alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
- torch.nn.functional.avg_pool2d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
- count_include_pad=True, divisor_override=None: -1),
- torch.nn.functional.avg_pool3d: (lambda input, kernel_size, stride=None, padding=0, ceil_mode=False,
- count_include_pad=True, divisor_override=None: -1),
- torch.nn.functional.batch_norm: (lambda input, running_mean, running_var, weight=None, bias=None, training=False,
- momentum=0.1, eps=1e-05: -1),
+ torch.nn.functional.avg_pool2d: (
+ lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
+ ),
+ torch.nn.functional.avg_pool3d: (
+ lambda input, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None: -1 # noqa: B950
+ ),
+ torch.nn.functional.batch_norm: (
+ lambda input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05: -1
+ ),
torch.nn.functional.bilinear: lambda input1, input2, weight, bias=None: -1,
- torch.nn.functional.binary_cross_entropy: (lambda input, target, weight=None, size_average=None, reduce=None,
- reduction="mean": -1),
- torch.nn.functional.binary_cross_entropy_with_logits: (lambda input, target, weight=None, size_average=None,
- reduce=None, reduction="mean", pos_weight=None: -1),
+ torch.nn.functional.binary_cross_entropy: (
+ lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
+ ),
+ torch.nn.functional.binary_cross_entropy_with_logits: (
+ lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean", pos_weight=None: -1
+ ),
torch.nn.functional.celu: lambda input, alpha=1.0, inplace=False: -1,
- torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
- reduce=None, reduction='mean': -1),
- torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
- reduce=None, reduction="mean", label_smoothing=0.0: -1),
- torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
- reduction='mean', zero_infinity=False: -1),
+ torch.nn.functional.cosine_embedding_loss: (
+ lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
+ ),
+ torch.nn.functional.cross_entropy: (
+ lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean", label_smoothing=0.0: -1 # noqa: B950
+ ),
+ torch.nn.functional.ctc_loss: (
+ lambda log_probs, targets, input_lengths, target_lengths, blank=0, reduction="mean", zero_infinity=False: -1
+ ),
torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout1d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout2d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.dropout3d: lambda input, p=0.5, training=True, inplace=False: -1,
torch.nn.functional.elu: lambda input, alpha=1.0, inplace=False: -1,
- torch.nn.functional.embedding: (lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0,
- scale_grad_by_freq=False, sparse=False: -1),
- torch.nn.functional.embedding_bag: (lambda input, weight, offsets=None, max_norm=None, norm_type=2,
- scale_grad_by_freq=False, mode='mean', sparse=False, per_sample_weights=None,
- include_last_offset=False, padding_idx=None: -1),
+ torch.nn.functional.embedding: (
+ lambda input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False: -1 # noqa: B950
+ ),
+ torch.nn.functional.embedding_bag: (
+ lambda input, weight, offsets=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, mode="mean", sparse=False, per_sample_weights=None, include_last_offset=False, padding_idx=None: -1 # noqa: B950
+ ),
torch.nn.functional.feature_alpha_dropout: lambda input, p=0.5, training=False, inplace=False: -1,
torch.nn.functional.fold: lambda input, output_size, kernel_size, dilation=1, padding=0, stride=1: -1,
- torch.nn.functional.fractional_max_pool2d: (lambda input, kernel_size, output_size=None, output_ratio=None,
- return_indices=False, _random_samples=None: -1),
+ torch.nn.functional.fractional_max_pool2d: (
+ lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
+ ),
torch.nn.functional.fractional_max_pool2d_with_indices: (
- lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
- _random_samples=None: -1),
- torch.nn.functional.fractional_max_pool3d: (lambda input, kernel_size, output_size=None, output_ratio=None,
- return_indices=False, _random_samples=None: -1),
+ lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
+ ),
+ torch.nn.functional.fractional_max_pool3d: (
+ lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
+ ),
torch.nn.functional.fractional_max_pool3d_with_indices: (
- lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False,
- _random_samples=None: -1),
- torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction='mean': -1,
- torch.nn.functional.gelu: lambda input, approximate='none': -1,
+ lambda input, kernel_size, output_size=None, output_ratio=None, return_indices=False, _random_samples=None: -1 # noqa: B950
+ ),
+ torch.nn.functional.gaussian_nll_loss: lambda input, target, var, full=False, eps=1e-06, reduction="mean": -1,
+ torch.nn.functional.gelu: lambda input, approximate="none": -1,
torch.nn.functional.glu: lambda input, dim=-1: -1,
- torch.nn.functional.grid_sample: lambda input, grid, mode='bilinear', padding_mode='zeros', align_corners=None: -1,
+ torch.nn.functional.grid_sample: lambda input, grid, mode="bilinear", padding_mode="zeros", align_corners=None: -1, # noqa: B950
torch.nn.functional.group_norm: lambda input, num_groups, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.gumbel_softmax: lambda logits, tau=1, hard=False, eps=1e-10, dim=-1: -1,
torch.nn.functional.hardshrink: lambda input, lambd=0.5: -1,
- torch.nn.functional.hardtanh: lambda input, min_val=-1., max_val=1., inplace=False: -1,
- torch.nn.functional.hinge_embedding_loss: (lambda input, target, margin=1.0, size_average=None, reduce=None,
- reduction='mean': -1),
- torch.nn.functional.instance_norm: (lambda input, running_mean=None, running_var=None, weight=None, bias=None,
- use_input_stats=True, momentum=0.1, eps=1e-05: -1),
- torch.nn.functional.interpolate: (lambda input, size=None, scale_factor=None, mode='nearest', align_corners=None,
- recompute_scale_factor=None, antialias=False: -1),
- torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
- torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
+ torch.nn.functional.hardtanh: lambda input, min_val=-1.0, max_val=1.0, inplace=False: -1,
+ torch.nn.functional.hinge_embedding_loss: (
+ lambda input, target, margin=1.0, size_average=None, reduce=None, reduction="mean": -1
+ ),
+ torch.nn.functional.instance_norm: (
+ lambda input, running_mean=None, running_var=None, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05: -1 # noqa: B950
+ ),
+ torch.nn.functional.interpolate: (
+ lambda input, size=None, scale_factor=None, mode="nearest", align_corners=None, recompute_scale_factor=None, antialias=False: -1 # noqa: B950
+ ),
+ torch.nn.functional.kl_div: lambda input, target, size_average=None, reduce=None, reduction="mean", log_target=False: -1, # noqa: B950
+ torch.nn.functional.l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
torch.nn.functional.layer_norm: lambda input, normalized_shape, weight=None, bias=None, eps=1e-05: -1,
torch.nn.functional.leaky_relu: lambda input, negative_slope=0.01, inplace=False: -1,
torch.nn.functional.linear: lambda input, weight, bias=None: -1,
@@ -874,55 +910,65 @@
torch.nn.functional.lp_pool1d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.lp_pool2d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
torch.nn.functional.lp_pool3d: lambda input, norm_type, kernel_size, stride=None, ceil_mode=False: -1,
- torch.nn.functional.margin_ranking_loss: (lambda input1, input2, target, margin=0, size_average=None,
- reduce=None, reduction='mean': -1),
- torch.nn.functional.max_pool1d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
- ceil_mode=False, return_indices=False: -1),
- torch.nn.functional.max_pool1d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False: -1),
- torch.nn.functional.max_pool2d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
- ceil_mode=False, return_indices=False: -1),
- torch.nn.functional.max_pool2d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False: -1),
- torch.nn.functional.max_pool3d: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False: -1),
- torch.nn.functional.max_pool3d_with_indices: (lambda input, kernel_size, stride=None, padding=0, dilation=1,
- return_indices=False, ceil_mode=False: -1),
- torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
- torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
- torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1,
- torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
+ torch.nn.functional.margin_ranking_loss: (
+ lambda input1, input2, target, margin=0, size_average=None, reduce=None, reduction="mean": -1
+ ),
+ torch.nn.functional.max_pool1d: (
+ lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
+ ),
+ torch.nn.functional.max_pool1d_with_indices: (
+ lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+ ),
+ torch.nn.functional.max_pool2d: (
+ lambda input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False, return_indices=False: -1
+ ),
+ torch.nn.functional.max_pool2d_with_indices: (
+ lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+ ),
+ torch.nn.functional.max_pool3d: (
+ lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+ ),
+ torch.nn.functional.max_pool3d_with_indices: (
+ lambda input, kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False: -1
+ ),
+ torch.nn.functional.max_unpool1d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
+ torch.nn.functional.max_unpool2d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
+ torch.nn.functional.max_unpool3d: lambda input, indices, kernel_size, stride=None, padding=0, output_size=None: -1, # noqa: B950
+ torch.nn.functional.mse_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1,
torch.nn.functional.multi_head_attention_forward: (
- lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
- add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None,
- need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None,
- v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1),
- torch.nn.functional.multi_margin_loss: (lambda input, target, p=1, margin=1.0, weight=None, size_average=None,
- reduce=None, reduction='mean': -1),
- torch.nn.functional.multilabel_margin_loss: (lambda input, target, size_average=None, reduce=None,
- reduction='mean': -1),
- torch.nn.functional.multilabel_soft_margin_loss: (lambda input, target, weight=None, size_average=None,
- reduce=None, reduction='mean': -1),
- torch.nn.functional.nll_loss: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
- reduce=None, reduction='mean': -1),
+ lambda query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training=True, key_padding_mask=None, need_weights=True, attn_mask=None, use_separate_proj_weight=False, q_proj_weight=None, k_proj_weight=None, v_proj_weight=None, static_k=None, static_v=None, average_attn_weights=None, is_causal=False: -1 # noqa: B950
+ ),
+ torch.nn.functional.multi_margin_loss: (
+ lambda input, target, p=1, margin=1.0, weight=None, size_average=None, reduce=None, reduction="mean": -1
+ ),
+ torch.nn.functional.multilabel_margin_loss: (
+ lambda input, target, size_average=None, reduce=None, reduction="mean": -1
+ ),
+ torch.nn.functional.multilabel_soft_margin_loss: (
+ lambda input, target, weight=None, size_average=None, reduce=None, reduction="mean": -1
+ ),
+ torch.nn.functional.nll_loss: (
+ lambda input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean": -1
+ ),
torch.nn.functional.normalize: lambda input, p=2, dim=1, eps=1e-12, out=None: -1,
torch.nn.functional.one_hot: lambda tensor, num_classes=-1: -1,
- torch.nn.functional.pad: lambda input, pad, mode='constant', value=0: -1,
+ torch.nn.functional.pad: lambda input, pad, mode="constant", value=0: -1,
torch.nn.functional.pairwise_distance: lambda x1, x2, p=2.0, eps=1e-06, keepdim=False: -1,
- torch.nn.functional.poisson_nll_loss: (lambda input, target, log_input=True, full=False, size_average=None,
- eps=1e-08, reduce=None, reduction='mean': -1),
+ torch.nn.functional.poisson_nll_loss: (
+ lambda input, target, log_input=True, full=False, size_average=None, eps=1e-08, reduce=None, reduction="mean": -1 # noqa: B950
+ ),
torch.nn.functional.prelu: lambda input, weight: -1,
torch.nn.functional.relu: lambda input, inplace=False: -1,
torch.nn.functional.relu6: lambda input, inplace=False: -1,
torch.nn.functional.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
- torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1,
+ torch.nn.functional.rrelu: lambda input, lower=0.125, upper=0.3333333333333333, training=False, inplace=False: -1, # noqa: B950
torch.nn.functional.selu: lambda input, inplace=False: -1,
torch.nn.functional.silu: lambda input, inplace=False: -1,
torch.nn.functional.mish: lambda input, inplace=False: -1,
torch.nn.functional.scaled_dot_product_attention: lambda query, key, value, attn_mask=None, dropout_p=0.0: -1,
- torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction='mean', beta=1.: -1,
- torch.nn.functional.huber_loss: lambda input, target, reduction='mean', delta=1.: -1,
- torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction='mean': -1,
+ torch.nn.functional.smooth_l1_loss: lambda input, target, size_average=None, reduce=None, reduction="mean", beta=1.0: -1, # noqa: B950
+ torch.nn.functional.huber_loss: lambda input, target, reduction="mean", delta=1.0: -1,
+ torch.nn.functional.soft_margin_loss: lambda input, target, size_average=None, reduce=None, reduction="mean": -1, # noqa: B950
torch.nn.functional.softmax: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softmin: lambda input, dim=None, _stacklevel=3, dtype=None: -1,
torch.nn.functional.softplus: lambda input, beta=1, threshold=20: -1,
@@ -930,25 +976,29 @@
torch.nn.functional.softsign: lambda input: -1,
torch.nn.functional.tanhshrink: lambda input: -1,
torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
- torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
- swap=False, size_average=None, reduce=None, reduction='mean': -1),
- torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
- distance_function=None, margin=1.0,
- swap=False, reduction='mean': -1),
+ torch.nn.functional.triplet_margin_loss: (
+ lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
+ ),
+ torch.nn.functional.triplet_margin_with_distance_loss: (
+ lambda anchor, positive, negative, *, distance_function=None, margin=1.0, swap=False, reduction="mean": -1
+ ),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
- torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
- torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
+ torch.nn.init.uniform_: lambda tensor, a=0.0, b=1.0, generator=None: -1,
+ torch.nn.init.normal_: lambda tensor, mean=0.0, std=1.0, generator=None: -1,
torch.nn.init.constant_: lambda tensor, val: -1,
- torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
+ torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode="fan_in", nonlinearity="leaky_relu", generator=None: -1, # noqa: B950
torch.nonzero: lambda input, as_tuple=False: -1,
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
torch.argwhere: lambda input: -1,
- torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
+ torch.norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.norm: lambda input, ord=None, dim=None, keepdim=False, out=None, dtype=None: -1,
torch.linalg.vector_norm: lambda input, ord=2, dim=None, keepdim=False, out=None, dtype=None: -1,
- torch.linalg.matrix_norm: lambda input, ord='fro', dim=(-2, -1), keepdim=False, out=None, dtype=None: -1,
+ torch.linalg.matrix_norm: lambda input, ord="fro", dim=(
+ -2,
+ -1,
+ ), keepdim=False, out=None, dtype=None: -1,
torch.norm_except_dim: lambda v, pow=2, dim=0: -1,
- torch.nuclear_norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
+ torch.nuclear_norm: lambda input, p="fro", dim=None, keepdim=False, out=None, dtype=None: -1,
torch.numel: lambda input: -1,
torch.orgqr: lambda input, tau: -1,
torch.ormqr: lambda input, input2, input3, left=True, transpose=False: -1,
@@ -975,28 +1025,43 @@
torch.q_scale: lambda input: -1,
torch.q_zero_point: lambda input: -1,
torch.qr: lambda input, some=True, out=None: -1,
- torch.linalg.qr: lambda input, mode='reduced', out=None: -1,
- torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
- torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation='linear', out=None: -1,
+ torch.linalg.qr: lambda input, mode="reduced", out=None: -1,
+ torch.quantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
+ torch.nanquantile: lambda input, q, dim=None, keepdim=False, interpolation="linear", out=None: -1,
torch.quantize_per_channel: lambda input, scales, zero_points, axis, dtype: -1,
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
- torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
- col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
-
- torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
- col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
- torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
- dilation=(1,), ceil_mode=False: -1),
- torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
- dilation=(1, 1), ceil_mode=False: -1),
- torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
- dilation=(1, 1, 1), ceil_mode=False: -1),
- torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
- col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
- torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
- col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
+ torch.quantized_gru_cell: (
+ lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
+ ),
+ torch.quantized_lstm_cell: (
+ lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
+ ),
+ torch.quantized_max_pool1d: (
+ lambda input, kernel_size, stride=tuple(), padding=(0,), dilation=(
+ 1,
+ ), ceil_mode=False: -1
+ ),
+ torch.quantized_max_pool2d: (
+ lambda input, kernel_size, stride=tuple(), padding=(0, 0), dilation=(
+ 1,
+ 1,
+ ), ceil_mode=False: -1
+ ),
+ torch.quantized_max_pool3d: (
+ lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0), dilation=(
+ 1,
+ 1,
+ 1,
+ ), ceil_mode=False: -1
+ ),
+ torch.quantized_rnn_relu_cell: (
+ lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
+ ),
+ torch.quantized_rnn_tanh_cell: (
+ lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1 # noqa: B950
+ ),
torch.rad2deg: lambda input, out=None: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
@@ -1014,16 +1079,16 @@
torch.repeat_interleave: lambda input, dim=None: -1,
torch.reshape: lambda input, shape: -1,
torch.rms_norm: lambda input, normalized_shape, weight=None, eps=1e-6: -1,
- torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
+ torch.rnn_relu: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
torch.rnn_relu_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
- torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1,
+ torch.rnn_tanh: lambda input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first: -1, # noqa: B950
torch.rnn_tanh_cell: lambda input, hx, w_ih, w_hh, b_ih=None, b_hh=None: -1,
torch.roll: lambda input, shifts, dims=None: -1,
torch.rot90: lambda input, k=1, dims=(0, 1): -1,
torch.round: lambda input, out=None: -1,
torch.row_stack: lambda tensors, out=None: -1, # alias for torch.vstack
torch._rowwise_prune: (lambda weight, mask, compressed_indices_dtype: -1),
- torch.rrelu: lambda input, lower=1. / 8, upper=1. / 3, training=False, inplace=False: -1,
+ torch.rrelu: lambda input, lower=1.0 / 8, upper=1.0 / 3, training=False, inplace=False: -1,
torch.rsqrt: lambda input, out=None: -1,
torch.rsub: lambda input, other, alpha=1: -1,
torch.saddmm: lambda input, mat1, mat2, beta=1, alpha=1, out=None: -1,
@@ -1031,7 +1096,7 @@
torch.scatter_add: lambda input, dim, index, src: -1,
torch.scatter_reduce: lambda input, dim, index, src, reduce, include_self=True: -1,
torch.searchsorted: lambda sorted_sequence, input, out_int32=False, right=False, out=None: -1,
- torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1,
+ torch._segment_reduce: lambda data, reduce="max", lengths=None, indices=None, offsets=None, axis=0, unsafe=False: -1, # noqa: B950
torch.select: lambda input, dim, index: -1,
torch.select_scatter: lambda input, src, dim, index: -1,
torch.slice_inverse: lambda input, src, dim=0, start=None, end=None, step=1: -1,
@@ -1061,8 +1126,9 @@
torch.stack: lambda tensors, dim=0, out=None: -1,
torch.std: lambda input, dim=None: -1,
torch.std_mean: lambda input, dim=None: -1,
- torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
- pad_mode='reflect', normalized=False, onesided=True, return_complex=None: -1),
+ torch.stft: (
+ lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=None: -1 # noqa: B950
+ ),
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,
torch.sum: lambda input, dim=None: -1,
@@ -1164,9 +1230,9 @@
torch.triangular_solve: lambda input, A, upper=True, transpose=False, unitriangular=False: -1,
torch.linalg.solve_triangular: lambda input, B, upper, left=True, unitriangular=False: -1,
torch.tril: lambda input, diagonal=0, out=None: -1,
- torch.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False,
-
- size_average=None, reduce=None, reduction='mean': -1),
+ torch.triplet_margin_loss: (
+ lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06, swap=False, size_average=None, reduce=None, reduction="mean": -1 # noqa: B950
+ ),
torch.triu: lambda input, diagonal=0, out=None: -1,
torch.true_divide: lambda input, other: -1,
torch.trunc: lambda input, out=None: -1,
@@ -1436,10 +1502,16 @@
torch.linalg.lstsq: lambda self, b, cond=None, driver=None: -1,
}
- privateuse1_backend_name = torch.utils.backend_registration._privateuse1_backend_name
+ privateuse1_backend_name = (
+ torch.utils.backend_registration._privateuse1_backend_name
+ )
if hasattr(Tensor, privateuse1_backend_name):
- ret[getattr(Tensor, privateuse1_backend_name)] = lambda self, device=None, non_blocking=False, **kwargs: -1
- ret[getattr(Tensor, f'is_{privateuse1_backend_name}').__get__] = lambda self: -1 # noqa: B009
+ ret[
+ getattr(Tensor, privateuse1_backend_name)
+ ] = lambda self, device=None, non_blocking=False, **kwargs: -1
+ ret[
+ getattr(Tensor, f"is_{privateuse1_backend_name}").__get__
+ ] = lambda self: -1 # noqa: B009
ret2 = {}
ignored = get_ignored_functions()
@@ -1457,12 +1529,10 @@
if k.__name__.startswith("bitwise_"):
# bitwise_<op> have dunder methods of the form __<op>__
# And so on.
- subname = k.__name__[len("bitwise_"):]
- names.extend([
- "__" + subname + "__",
- "__i" + subname + "__",
- "__r" + subname + "__"
- ])
+ subname = k.__name__[len("bitwise_") :]
+ names.extend(
+ ["__" + subname + "__", "__i" + subname + "__", "__r" + subname + "__"]
+ )
for name in names:
func = getattr(Tensor, name, None)
@@ -1472,6 +1542,7 @@
ret.update(ret2)
return ret
+
def wrap_torch_function(dispatcher: Callable):
"""Wraps a given function with ``__torch_function__`` -related functionality.
@@ -1495,6 +1566,7 @@
>>> def func(a): # This will make func dispatchable by __torch_function__
... return a + 0
"""
+
def inner(func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
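# A minimal sketch of wrap_torch_function in use, continuing the docstring
# example above; LoggingTensor is a hypothetical subclass used only to show
# that the wrapped function now participates in __torch_function__ dispatch.
import torch
from torch.overrides import wrap_torch_function

@wrap_torch_function(lambda a: (a,))
def add_zero(a):
    return a + 0

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatching {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

t = torch.randn(2).as_subclass(LoggingTensor)
add_zero(t)  # prints "dispatching add_zero" before running the plain body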
@@ -1508,7 +1580,10 @@
return inner
-def _get_overloaded_args(relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None) -> List[Any]:
+
+def _get_overloaded_args(
+ relevant_args: Iterable[Any], get_type_fn: Callable[[Any], Type] = None
+) -> List[Any]:
"""Returns a list of arguments on which to call __torch_function__.
Checks arguments in relevant_args for __torch_function__ implementations,
@@ -1559,8 +1634,11 @@
#
# NB: Important to exclude _disabled_torch_function_impl, otherwise
# https://github.com/pytorch/pytorch/issues/64687
- if (arg_type not in overloaded_types and hasattr(arg_type, '__torch_function__') and
- arg_type.__torch_function__ != torch._C._disabled_torch_function_impl):
+ if (
+ arg_type not in overloaded_types
+ and hasattr(arg_type, "__torch_function__")
+ and arg_type.__torch_function__ != torch._C._disabled_torch_function_impl
+ ):
# Create lists explicitly for the first type (usually the only one
# done) to avoid setting up the iterator for overloaded_args.
if overloaded_types:
@@ -1581,7 +1659,8 @@
def handle_torch_function(
- public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs) -> Any:
+ public_api: Callable, relevant_args: Iterable[Any], *args, **kwargs
+) -> Any:
"""Implement a function with checks for ``__torch_function__`` overrides.
See torch::autograd::handle_torch_function for the equivalent of this
@@ -1636,11 +1715,16 @@
# This call needs to become a classmethod call in the future.
# See https://github.com/pytorch/pytorch/issues/63767
torch_func_method = overloaded_arg.__torch_function__
- if hasattr(torch_func_method, "__self__") and torch_func_method.__self__ is overloaded_arg and \
- torch_func_method is not torch._C._disabled_torch_function_impl:
- warnings.warn("Defining your `__torch_function__ as a plain method is deprecated and "
- "will be an error in future, please define it as a classmethod.",
- DeprecationWarning)
+ if (
+ hasattr(torch_func_method, "__self__")
+ and torch_func_method.__self__ is overloaded_arg
+ and torch_func_method is not torch._C._disabled_torch_function_impl
+ ):
+ warnings.warn(
+ "Defining your `__torch_function__ as a plain method is deprecated and "
+ "will be an error in future, please define it as a classmethod.",
+ DeprecationWarning,
+ )
# Use `public_api` instead of `implementation` so __torch_function__
# implementations can do equality/identity comparisons.
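# A minimal sketch of the pattern described above: a user-defined function
# that defers to __torch_function__ overrides on its arguments before running
# its own implementation. scaled_add is hypothetical.
import torch
from torch.overrides import handle_torch_function, has_torch_function

def scaled_add(input, other, alpha=1.0):
    if has_torch_function((input, other)):
        return handle_torch_function(
            scaled_add, (input, other), input, other, alpha=alpha
        )
    return input + alpha * other

scaled_add(torch.randn(3), torch.randn(3), alpha=0.5)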
@@ -1649,15 +1733,16 @@
if result is not NotImplemented:
return result
- func_name = f'{public_api.__module__}.{public_api.__name__}'
+ func_name = f"{public_api.__module__}.{public_api.__name__}"
msg = (
f"no implementation found for '{func_name}' on types that implement "
- f'__torch_function__: {[type(arg) for arg in overloaded_args]}'
+ f"__torch_function__: {[type(arg) for arg in overloaded_args]}"
)
if _is_torch_function_mode_enabled():
msg += f" nor in mode {_get_current_function_mode()}"
raise TypeError(msg)
+
has_torch_function = _add_docstr(
_has_torch_function,
r"""Check for __torch_function__ implementations in the elements of an iterable
@@ -1678,7 +1763,7 @@
________
torch.is_tensor_like
Checks if something is a Tensor-like, including an exact ``Tensor``.
- """
+ """,
)
has_torch_function_unary = _add_docstr(
@@ -1689,7 +1774,7 @@
call:
`has_torch_function_unary(t)`
which skips unnecessary packing and unpacking work.
- """
+ """,
)
has_torch_function_variadic = _add_docstr(
@@ -1703,11 +1788,14 @@
call:
`has_torch_function_variadic(a, b)`
which skips unnecessary packing and unpacking work.
- """
+ """,
)
+
@functools.lru_cache(None)
-def _get_overridable_functions() -> Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]:
+def _get_overridable_functions() -> (
+ Tuple[Dict[Any, List[Callable]], Dict[Callable, str]]
+):
overridable_funcs = collections.defaultdict(list)
index = {}
tested_namespaces = [
@@ -1725,21 +1813,21 @@
ignore = False
# ignore private functions or functions that are deleted in torch.__init__
if namespace is not torch.Tensor:
- if func_name.startswith('__'):
+ if func_name.startswith("__"):
continue
- elif func_name.startswith('_'):
+ elif func_name.startswith("_"):
ignore = True
- elif func_name.endswith('_'):
+ elif func_name.endswith("_"):
ignore = True
elif not func_name[0].islower():
ignore = True
- elif func_name == 'unique_dim':
+ elif func_name == "unique_dim":
continue
else:
func = getattr(namespace, func_name)
if getattr(object, func_name, None) == func:
continue
- if func_name == '__weakref__':
+ if func_name == "__weakref__":
continue
func = getattr(namespace, func_name)
if namespace is torch.Tensor and getattr(object, func_name, None) == func:
@@ -1757,9 +1845,13 @@
if ignore:
continue
if func.__get__ in get_ignored_functions():
- msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
- "but still has an explicit override")
- assert func.__get__ not in get_testing_overrides(), msg.format(namespace, func.__name__)
+ msg = (
+ "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
+ "but still has an explicit override"
+ )
+ assert func.__get__ not in get_testing_overrides(), msg.format(
+ namespace, func.__name__
+ )
continue
else:
overridable_funcs[func].append(func.__get__)
@@ -1775,13 +1867,18 @@
# cannot be overridden by __torch_function__
if func in get_ignored_functions():
- msg = ("{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
- "but still has an explicit override")
- assert func not in get_testing_overrides(), msg.format(namespace, func.__name__)
+ msg = (
+ "{}.{} is in the tuple returned by torch._overrides.get_ignored_functions "
+ "but still has an explicit override"
+ )
+ assert func not in get_testing_overrides(), msg.format(
+ namespace, func.__name__
+ )
continue
overridable_funcs[namespace].append(func)
return overridable_funcs, index
+
@_disable_user_warnings
def get_overridable_functions() -> Dict[Any, List[Callable]]:
"""List functions that are overridable via __torch_function__
@@ -1794,6 +1891,7 @@
"""
return _get_overridable_functions()[0]
+
@_disable_user_warnings
def resolve_name(f):
"""Get a human readable string name for a function passed to
@@ -1814,13 +1912,15 @@
return str(f)
return _get_overridable_functions()[1].get(f)
+
@functools.lru_cache(None)
def _get_tensor_methods() -> Set[Callable]:
- """ Returns a set of the overridable methods on ``torch.Tensor`` """
+ """Returns a set of the overridable methods on ``torch.Tensor``"""
overridable_funcs = get_overridable_functions()
methods = set(overridable_funcs[torch.Tensor])
return methods
+
@_disable_user_warnings
def is_tensor_method_or_property(func: Callable) -> bool:
"""
@@ -1846,6 +1946,7 @@
"""
return func in _get_tensor_methods() or func.__name__ == "__get__"
+
def is_tensor_like(inp):
"""
Returns ``True`` if the passed-in input is a Tensor-like.
@@ -1882,6 +1983,7 @@
"""
return type(inp) is torch.Tensor or hasattr(inp, "__torch_function__")
+
class TorchFunctionMode:
"""
A ``TorchFunctionMode`` allows you to override the meaning of all
@@ -1912,6 +2014,7 @@
``enable_torch_function_mode(self, replace=self.inner)`` to make PyTorch
API self-referential (beware of infinite loops, in this case!)
"""
+
inner: "TorchFunctionMode"
# Force metaclass to generate constructor at the base of the hierarchy
@@ -1930,7 +2033,9 @@
@classmethod
def push(cls, *args, **kwargs):
- warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
+ warnings.warn(
+ "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
+ )
instance = cls(*args, **kwargs)
return instance
@@ -1944,6 +2049,7 @@
stack_len = _len_torch_function_stack()
return [_get_function_stack_at(i) for i in range(stack_len)]
+
def _push_mode(mode):
_push_on_torch_function_stack(mode)
@@ -1961,6 +2067,7 @@
finally:
_push_mode(old)
+
class BaseTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
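
As a quick orientation for the `__torch_function__` machinery touched above, here is a minimal, illustrative TorchFunctionMode sketch (the mode class, printed message, and tensor values are made up for this example and are not part of the diff):

import torch
from torch.overrides import TorchFunctionMode, resolve_name

class LoggingMode(TorchFunctionMode):
    # Every torch.* API call made while this mode is active is routed through __torch_function__.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"intercepted {resolve_name(func)}")
        return func(*args, **kwargs)

with LoggingMode():
    torch.add(torch.ones(2), torch.ones(2))  # prints the intercepted call, then runs normally
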
diff --git a/torch/quasirandom.py b/torch/quasirandom.py
index a121801..509be4f 100644
--- a/torch/quasirandom.py
+++ b/torch/quasirandom.py
@@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
-import torch
from typing import Optional
+import torch
+
class SobolEngine:
r"""
@@ -48,8 +49,10 @@
def __init__(self, dimension, scramble=False, seed=None):
if dimension > self.MAXDIM or dimension < 1:
- raise ValueError("Supported range of dimensionality "
- f"for SobolEngine is [1, {self.MAXDIM}]")
+ raise ValueError(
+ "Supported range of dimensionality "
+ f"for SobolEngine is [1, {self.MAXDIM}]"
+ )
self.seed = seed
self.scramble = scramble
@@ -57,7 +60,9 @@
cpu = torch.device("cpu")
- self.sobolstate = torch.zeros(dimension, self.MAXBIT, device=cpu, dtype=torch.long)
+ self.sobolstate = torch.zeros(
+ dimension, self.MAXBIT, device=cpu, dtype=torch.long
+ )
torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)
if not self.scramble:
@@ -66,11 +71,15 @@
self._scramble()
self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
- self._first_point = (self.quasi / 2 ** self.MAXBIT).reshape(1, -1)
+ self._first_point = (self.quasi / 2**self.MAXBIT).reshape(1, -1)
self.num_generated = 0
- def draw(self, n: int = 1, out: Optional[torch.Tensor] = None,
- dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+ def draw(
+ self,
+ n: int = 1,
+ out: Optional[torch.Tensor] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
r"""
Function to draw a sequence of :attr:`n` points from a Sobol sequence.
Note that the samples are dependent on the previous samples. The size
@@ -92,12 +101,22 @@
result = self._first_point.to(dtype)
else:
result, self.quasi = torch._sobol_engine_draw(
- self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
+ self.quasi,
+ n - 1,
+ self.sobolstate,
+ self.dimension,
+ self.num_generated,
+ dtype=dtype,
)
result = torch.cat((self._first_point.to(dtype), result), dim=-2)
else:
result, self.quasi = torch._sobol_engine_draw(
- self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1, dtype=dtype,
+ self.quasi,
+ n,
+ self.sobolstate,
+ self.dimension,
+ self.num_generated - 1,
+ dtype=dtype,
)
self.num_generated += n
@@ -108,8 +127,12 @@
return result
- def draw_base2(self, m: int, out: Optional[torch.Tensor] = None,
- dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+ def draw_base2(
+ self,
+ m: int,
+ out: Optional[torch.Tensor] = None,
+ dtype: Optional[torch.dtype] = None,
+ ) -> torch.Tensor:
r"""
Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
Note that the samples are dependent on the previous samples. The size
@@ -122,15 +145,16 @@
returned tensor.
Default: ``None``
"""
- n = 2 ** m
+ n = 2**m
total_n = self.num_generated + n
if not (total_n & (total_n - 1) == 0):
- raise ValueError("The balance properties of Sobol' points require "
- f"n to be a power of 2. {self.num_generated} points have been "
- f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
- "If you still want to do this, please use "
- "'SobolEngine.draw()' instead."
- )
+ raise ValueError(
+ "The balance properties of Sobol' points require "
+ f"n to be a power of 2. {self.num_generated} points have been "
+ f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
+ "If you still want to do this, please use "
+ "'SobolEngine.draw()' instead."
+ )
return self.draw(n=n, out=out, dtype=dtype)
def reset(self):
@@ -151,9 +175,13 @@
n (Int): The number of steps to fast-forward by.
"""
if self.num_generated == 0:
- torch._sobol_engine_ff_(self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated)
+ torch._sobol_engine_ff_(
+ self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
+ )
else:
- torch._sobol_engine_ff_(self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1)
+ torch._sobol_engine_ff_(
+ self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
+ )
self.num_generated += n
return self
@@ -166,8 +194,12 @@
cpu = torch.device("cpu")
# Generate shift vector
- shift_ints = torch.randint(2, (self.dimension, self.MAXBIT), device=cpu, generator=g)
- self.shift = torch.mv(shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu)))
+ shift_ints = torch.randint(
+ 2, (self.dimension, self.MAXBIT), device=cpu, generator=g
+ )
+ self.shift = torch.mv(
+ shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
+ )
# Generate lower triangular matrices (stacked across dimensions)
ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
@@ -176,9 +208,9 @@
torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)
def __repr__(self):
- fmt_string = [f'dimension={self.dimension}']
+ fmt_string = [f"dimension={self.dimension}"]
if self.scramble:
- fmt_string += ['scramble=True']
+ fmt_string += ["scramble=True"]
if self.seed is not None:
- fmt_string += [f'seed={self.seed}']
- return self.__class__.__name__ + '(' + ', '.join(fmt_string) + ')'
+ fmt_string += [f"seed={self.seed}"]
+ return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
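
For context, a short usage sketch of the SobolEngine API whose call sites are re-wrapped above (dimension, seed, and draw sizes are arbitrary placeholders):

import torch
from torch.quasirandom import SobolEngine

engine = SobolEngine(dimension=3, scramble=True, seed=0)
points = engine.draw(8)         # (8, 3) tensor of quasi-random points in [0, 1)
engine.reset()
points2 = engine.draw_base2(3)  # 2**3 = 8 points, preserving the power-of-two balance property
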
diff --git a/torch/random.py b/torch/random.py
index 0916fe1..7833311 100644
--- a/torch/random.py
+++ b/torch/random.py
@@ -1,10 +1,10 @@
# mypy: allow-untyped-defs
import contextlib
-from typing import Generator
import warnings
+from typing import Generator
-from torch._C import default_generator
import torch
+from torch._C import default_generator
def set_rng_state(new_state: torch.Tensor) -> None:
@@ -46,10 +46,12 @@
torch.cuda.manual_seed_all(seed)
import torch.mps
+
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
+
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
@@ -69,10 +71,12 @@
torch.cuda.manual_seed_all(seed)
import torch.mps
+
if not torch.mps._is_in_bad_fork():
torch.mps.manual_seed(seed)
import torch.xpu
+
if not torch.xpu._is_in_bad_fork():
torch.xpu.manual_seed_all(seed)
@@ -95,7 +99,9 @@
custom_device_mod = getattr(torch, custom_backend_name)
_bad_fork_name = "_is_in_bad_fork"
_seed_all_name = "manual_seed_all"
- if hasattr(custom_device_mod, _bad_fork_name) and hasattr(custom_device_mod, _seed_all_name):
+ if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
+ custom_device_mod, _seed_all_name
+ ):
if not getattr(custom_device_mod, _bad_fork_name)():
getattr(custom_device_mod, _seed_all_name)(seed)
else:
@@ -117,7 +123,13 @@
@contextlib.contextmanager
-def fork_rng(devices=None, enabled=True, _caller="fork_rng", _devices_kw="devices", device_type="cuda") -> Generator:
+def fork_rng(
+ devices=None,
+ enabled=True,
+ _caller="fork_rng",
+ _devices_kw="devices",
+ device_type="cuda",
+) -> Generator:
"""
Forks the RNG, so that when you return, the RNG is reset
to the state that it was previously in.
@@ -138,8 +150,10 @@
device_type = torch.device(device_type).type
device_mod = getattr(torch, device_type, None)
if device_mod is None:
- raise RuntimeError(f"torch has no module of `{device_type}`, you should register " +
- "a module by `torch._register_device_module`.")
+ raise RuntimeError(
+ f"torch has no module of `{device_type}`, you should register "
+ + "a module by `torch._register_device_module`."
+ )
global _fork_rng_warned_already
# Internal arguments:
@@ -153,17 +167,19 @@
if devices is None:
num_devices = device_mod.device_count()
if num_devices > 1 and not _fork_rng_warned_already:
- message = (f"{device_type.upper()} reports that you have {num_devices} available devices, and "
- f"you have used {_caller} without explicitly specifying which devices are being used. "
- f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
- f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
- f" making use of a few {device_type.upper()} devices, set the environment variable "
- f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
- "with the set of devices you are actually using. For example, if you are using CPU only, "
- "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
- f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
- f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
- f"`range(torch.{device_type}.device_count())`.")
+ message = (
+ f"{device_type.upper()} reports that you have {num_devices} available devices, and "
+ f"you have used {_caller} without explicitly specifying which devices are being used. "
+ f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
+ f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
+ f" making use of a few {device_type.upper()} devices, set the environment variable "
+ f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
+ "with the set of devices you are actually using. For example, if you are using CPU only, "
+ "set device.upper()_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
+ f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
+ f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
+ f"`range(torch.{device_type}.device_count())`."
+ )
warnings.warn(message)
_fork_rng_warned_already = True
devices = list(range(num_devices))
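
A hedged sketch of how torch.random.fork_rng is typically used (the seed, shapes, and device_type below are placeholders, not part of the change):

import torch

cpu_state = torch.get_rng_state()
with torch.random.fork_rng(devices=[], device_type="cuda"):
    torch.manual_seed(1234)
    sample = torch.randn(4)  # RNG consumption is confined to the forked region
assert torch.equal(torch.get_rng_state(), cpu_state)  # CPU generator state is restored on exit
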
diff --git a/torch/return_types.py b/torch/return_types.py
index a74150a..d456742 100644
--- a/torch/return_types.py
+++ b/torch/return_types.py
@@ -1,8 +1,9 @@
-import torch
import inspect
+import torch
from torch.utils._pytree import register_pytree_node, SequenceKey
+
__all__ = ["pytree_register_structseq", "all_return_types"]
all_return_types = []
@@ -10,6 +11,7 @@
# error: Module has no attribute "_return_types"
return_types = torch._C._return_types # type: ignore[attr-defined]
+
def pytree_register_structseq(cls):
def structseq_flatten(structseq):
return list(structseq), None
@@ -28,14 +30,15 @@
flatten_with_keys_fn=structseq_flatten_with_keys,
)
+
for name in dir(return_types):
- if name.startswith('__'):
+ if name.startswith("__"):
continue
_attr = getattr(return_types, name)
globals()[name] = _attr
- if not name.startswith('_'):
+ if not name.startswith("_"):
__all__.append(name)
all_return_types.append(_attr)
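
An illustrative sketch of the structseq return types registered above with the pytree machinery (the input values are made up):

import torch
from torch.utils._pytree import tree_flatten

result = torch.topk(torch.tensor([1.0, 3.0, 2.0]), k=2)
print(type(result))                  # <class 'torch.return_types.topk'>
print(result.values, result.indices)
leaves, spec = tree_flatten(result)  # registration lets the structseq flatten into its fields
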
diff --git a/torch/serialization.py b/torch/serialization.py
index 311aac2..738af26 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -1,70 +1,87 @@
# mypy: allow-untyped-defs
+import copyreg
import difflib
import functools
-import os
import io
+import os
+import pickle
import shutil
import struct
import sys
-import torch
import tarfile
import tempfile
import warnings
from contextlib import closing, contextmanager
from enum import Enum
-from ._utils import _import_dotted_name
-from torch._sources import get_source_lines_and_file
-from torch.types import Storage
-from torch.storage import _get_dtype_from_pickle_storage_type
-from typing import Any, BinaryIO, Callable, cast, Dict, Optional, Type, Tuple, Union, IO, List
+from typing import (
+ Any,
+ BinaryIO,
+ Callable,
+ cast,
+ Dict,
+ IO,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
from typing_extensions import TypeAlias, TypeGuard # Python 3.10+
-import copyreg
-import pickle
+
+import torch
import torch._weights_only_unpickler as _weights_only_unpickler
+from torch._sources import get_source_lines_and_file
+from torch._utils import _import_dotted_name
+from torch.storage import _get_dtype_from_pickle_storage_type
+from torch.types import Storage
+
+
+__all__ = [
+ "SourceChangeWarning",
+ "mkdtemp",
+ "register_package",
+ "check_module_version_greater_or_equal",
+ "validate_cuda_device",
+ "validate_hpu_device",
+ "location_tag",
+ "default_restore_location",
+ "normalize_storage_type",
+ "storage_to_tensor_type",
+ "save",
+ "load",
+ "StorageType",
+ "LoadEndianness",
+ "get_default_load_endianness",
+ "set_default_load_endianness",
+ "clear_safe_globals",
+ "get_safe_globals",
+ "add_safe_globals",
+]
+
DEFAULT_PROTOCOL = 2
-LONG_SIZE = struct.Struct('=l').size
-INT_SIZE = struct.Struct('=i').size
-SHORT_SIZE = struct.Struct('=h').size
+LONG_SIZE = struct.Struct("=l").size
+INT_SIZE = struct.Struct("=i").size
+SHORT_SIZE = struct.Struct("=h").size
-MAGIC_NUMBER = 0x1950a86a20f9469cfc6c
+MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
PROTOCOL_VERSION = 1001
-STORAGE_KEY_SEPARATOR = ','
+STORAGE_KEY_SEPARATOR = ","
FILE_LIKE: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
-MAP_LOCATION: TypeAlias = Optional[Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]]
+MAP_LOCATION: TypeAlias = Optional[
+ Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]
+]
STORAGE: TypeAlias = Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
IS_WINDOWS = sys.platform == "win32"
if not IS_WINDOWS:
- from mmap import MAP_SHARED, MAP_PRIVATE
+ from mmap import MAP_PRIVATE, MAP_SHARED
else:
MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment]
-__all__ = [
- 'SourceChangeWarning',
- 'mkdtemp',
- 'register_package',
- 'check_module_version_greater_or_equal',
- 'validate_cuda_device',
- 'validate_hpu_device',
- 'location_tag',
- 'default_restore_location',
- 'normalize_storage_type',
- 'storage_to_tensor_type',
- 'save',
- 'load',
- 'StorageType',
- 'LoadEndianness',
- 'get_default_load_endianness',
- 'set_default_load_endianness',
- 'clear_safe_globals',
- 'get_safe_globals',
- 'add_safe_globals',
-]
-
class SourceChangeWarning(Warning):
pass
@@ -79,17 +96,26 @@
shutil.rmtree(path)
-_package_registry: List[Tuple[int, Callable[[STORAGE], Optional[str]], Callable[[STORAGE, str], Optional[STORAGE]]]] = []
+_package_registry: List[
+ Tuple[
+ int,
+ Callable[[STORAGE], Optional[str]],
+ Callable[[STORAGE, str], Optional[STORAGE]],
+ ]
+] = []
+
class LoadEndianness(Enum):
NATIVE = 1
LITTLE = 2
BIG = 3
+
_default_load_endian: Optional[LoadEndianness] = None
+
def get_default_load_endianness() -> Optional[LoadEndianness]:
- '''
+ """
Get fallback byte order for loading files
If byteorder mark is not present in saved checkpoint,
@@ -98,11 +124,12 @@
Returns:
default_load_endian: Optional[LoadEndianness]
- '''
+ """
return _default_load_endian
+
def set_default_load_endianness(endianness):
- '''
+ """
Set fallback byte order for loading files
If byteorder mark is not present in saved checkpoint,
@@ -111,16 +138,18 @@
Args:
endianness: the new fallback byte order
- '''
+ """
global _default_load_endian
if not isinstance(endianness, LoadEndianness) and endianness is not None:
raise TypeError("Invalid argument type in function set_default_load_endianness")
_default_load_endian = endianness
+
_default_mmap_options: int = MAP_PRIVATE
+
def get_default_mmap_options() -> int:
- '''
+ """
Get default mmap options for :func:`torch.load` with ``mmap=True``.
Defaults to ``mmap.MAP_PRIVATE``.
@@ -128,11 +157,12 @@
Returns:
default_mmap_options: int
- '''
+ """
return _default_mmap_options
+
def set_default_mmap_options(flags: int):
- '''
+ """
Set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
For now, only either ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED`` are supported.
@@ -143,36 +173,44 @@
Args:
flags: ``mmap.MAP_PRIVATE`` or ``mmap.MAP_SHARED``
- '''
+ """
global _default_mmap_options
if IS_WINDOWS:
- raise RuntimeError("Changing the default mmap options is currently not supported for Windows")
- if (flags != MAP_PRIVATE and flags != MAP_SHARED):
- raise ValueError("Invalid argument in function set_default_mmap_options, "
- f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}")
+ raise RuntimeError(
+ "Changing the default mmap options is currently not supported for Windows"
+ )
+ if flags != MAP_PRIVATE and flags != MAP_SHARED:
+ raise ValueError(
+ "Invalid argument in function set_default_mmap_options, "
+ f"expected mmap.MAP_PRIVATE or mmap.MAP_SHARED, but got {flags}"
+ )
_default_mmap_options = flags
+
def clear_safe_globals() -> None:
- '''
+ """
Clears the list of globals that are safe for ``weights_only`` load.
- '''
+ """
_weights_only_unpickler._clear_safe_globals()
+
def get_safe_globals() -> List[Any]:
- '''
+ """
Returns the list of user-added globals that are safe for ``weights_only`` load.
- '''
+ """
return _weights_only_unpickler._get_safe_globals()
+
def add_safe_globals(safe_globals: List[Any]) -> None:
- '''
+ """
Marks the given globals as safe for ``weights_only`` load.
Args:
safe_globals (List[Any]): list of globals to mark as safe
- '''
+ """
_weights_only_unpickler._add_safe_globals(safe_globals)
+
def _is_zipfile(f) -> bool:
# This is a stricter implementation than zipfile.is_zipfile().
# zipfile.is_zipfile() is True if the magic number appears anywhere in the
@@ -183,7 +221,7 @@
start = f.tell()
# Read the first few bytes and match against the ZIP file signature
- local_header_magic_number = b'PK\x03\x04'
+ local_header_magic_number = b"PK\x03\x04"
read_bytes = f.read(len(local_header_magic_number))
f.seek(start)
return read_bytes == local_header_magic_number
@@ -192,9 +230,9 @@
def register_package(
priority: int,
tagger: Callable[[STORAGE], Optional[str]],
- deserializer: Callable[[STORAGE, str], Optional[STORAGE]]
+ deserializer: Callable[[STORAGE, str], Optional[STORAGE]],
):
- '''
+ """
Registers callables for tagging and deserializing storage objects with an associated priority.
Tagging associates a device with a storage object at save time while deserializing moves a
storage object to an appropriate device at load time. :attr:`tagger` and :attr:`deserializer`
@@ -228,14 +266,16 @@
>>> assert torch.ipu.is_available(), "ipu is not available"
>>> return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
- '''
+ """
queue_elem = (priority, tagger, deserializer)
_package_registry.append(queue_elem)
_package_registry.sort()
-def check_module_version_greater_or_equal(module, req_version_tuple, error_if_malformed=True):
- '''
+def check_module_version_greater_or_equal(
+ module, req_version_tuple, error_if_malformed=True
+):
+ """
Check if a module's version satisfies requirements
Usually, a module's version string will be like 'x.y.z', which would be represented
@@ -250,12 +290,13 @@
Returns:
requirement_is_met: bool
- '''
+ """
try:
- version_strs = module.__version__.split('.')
+ version_strs = module.__version__.split(".")
# Cast module version fields to match the types of the required version
module_version = tuple(
- type(req_field)(version_strs[idx]) for idx, req_field in enumerate(req_version_tuple)
+ type(req_field)(version_strs[idx])
+ for idx, req_field in enumerate(req_version_tuple)
)
requirement_is_met = module_version >= req_version_tuple
@@ -267,54 +308,54 @@
if error_if_malformed:
raise RuntimeError(message) from e
else:
- warnings.warn(message + ', but continuing assuming that requirement is met')
+ warnings.warn(message + ", but continuing assuming that requirement is met")
requirement_is_met = True
return requirement_is_met
def _cpu_tag(obj):
- if obj.device.type == 'cpu':
- return 'cpu'
+ if obj.device.type == "cpu":
+ return "cpu"
def _mps_tag(obj):
- if obj.device.type == 'mps':
- return 'mps'
+ if obj.device.type == "mps":
+ return "mps"
def _meta_tag(obj):
- if obj.device.type == 'meta':
- return 'meta'
+ if obj.device.type == "meta":
+ return "meta"
def _backend_tag(backend_name, obj):
- if backend_name == 'privateuse1':
+ if backend_name == "privateuse1":
backend_name = torch._C._get_privateuse1_backend_name()
if obj.device.type == backend_name:
if obj.device.index is None:
return backend_name
else:
- return backend_name + ':' + str(obj.device.index)
+ return backend_name + ":" + str(obj.device.index)
def _cpu_deserialize(obj, location):
- if location == 'cpu':
+ if location == "cpu":
return obj
def _mps_deserialize(obj, location):
- if location.startswith('mps'):
+ if location.startswith("mps"):
return obj.mps()
def _meta_deserialize(obj, location):
- if location == 'meta':
- return torch.UntypedStorage(obj.nbytes(), device='meta')
+ if location == "meta":
+ return torch.UntypedStorage(obj.nbytes(), device="meta")
def _validate_device(location, backend_name):
- '''
+ """
Check whether the device index of specified backend is valid
In case of privateuse1 backend, your must first register a device_module for
@@ -328,45 +369,53 @@
Returns:
device_index: int
- '''
+ """
if not hasattr(torch, backend_name):
- raise RuntimeError(f'The {backend_name.upper()} device module is not registered. '
- 'If you are running on a CPU-only machine, '
- 'please use torch.load with map_location=torch.device(\'cpu\') '
- 'to map your storages to the CPU.')
+ raise RuntimeError(
+ f"The {backend_name.upper()} device module is not registered. "
+ "If you are running on a CPU-only machine, "
+ "please use torch.load with map_location=torch.device('cpu') "
+ "to map your storages to the CPU."
+ )
device_module = getattr(torch, backend_name)
- if hasattr(device_module, '_utils') and hasattr(device_module._utils, '_get_device_index'):
+ if hasattr(device_module, "_utils") and hasattr(
+ device_module._utils, "_get_device_index"
+ ):
device_index = device_module._utils._get_device_index(location, True)
device = torch.device(backend_name, device_index)
else:
device = torch.device(location)
device_index = device.index if device.index else 0
- if hasattr(device_module, 'is_available') and not device_module.is_available():
- raise RuntimeError(f'Attempting to deserialize object on a {backend_name.upper()} '
- f'device but torch.{backend_name}.is_available() is False. '
- 'If you are running on a CPU-only machine, '
- 'please use torch.load with map_location=torch.device(\'cpu\') '
- 'to map your storages to the CPU.')
- if hasattr(device_module, 'device_count'):
+ if hasattr(device_module, "is_available") and not device_module.is_available():
+ raise RuntimeError(
+ f"Attempting to deserialize object on a {backend_name.upper()} "
+ f"device but torch.{backend_name}.is_available() is False. "
+ "If you are running on a CPU-only machine, "
+ "please use torch.load with map_location=torch.device('cpu') "
+ "to map your storages to the CPU."
+ )
+ if hasattr(device_module, "device_count"):
device_count = device_module.device_count()
if device_index >= device_count:
- raise RuntimeError(f'Attempting to deserialize object on {backend_name.upper()} device '
- f'{device_index} but torch.{backend_name}.device_count() is {device_count}. '
- 'Please use torch.load with map_location to map your storages '
- 'to an existing device.')
+ raise RuntimeError(
+ f"Attempting to deserialize object on {backend_name.upper()} device "
+ f"{device_index} but torch.{backend_name}.device_count() is {device_count}. "
+ "Please use torch.load with map_location to map your storages "
+ "to an existing device."
+ )
return device
def validate_cuda_device(location):
- return _validate_device(location, 'cuda').index
+ return _validate_device(location, "cuda").index
def validate_hpu_device(location):
- return _validate_device(location, 'hpu').index
+ return _validate_device(location, "hpu").index
def _deserialize(backend_name, obj, location):
- if backend_name == 'privateuse1':
+ if backend_name == "privateuse1":
backend_name = torch._C._get_privateuse1_backend_name()
if location.startswith(backend_name):
device = _validate_device(location, backend_name)
@@ -374,20 +423,34 @@
register_package(10, _cpu_tag, _cpu_deserialize)
-register_package(20, functools.partial(_backend_tag, 'cuda'), functools.partial(_deserialize, 'cuda'))
+register_package(
+ 20, functools.partial(_backend_tag, "cuda"), functools.partial(_deserialize, "cuda")
+)
register_package(21, _mps_tag, _mps_deserialize)
register_package(22, _meta_tag, _meta_deserialize)
-register_package(23, functools.partial(_backend_tag, 'privateuse1'), functools.partial(_deserialize, 'privateuse1'))
-register_package(24, functools.partial(_backend_tag, 'hpu'), functools.partial(_deserialize, 'hpu'))
-register_package(25, functools.partial(_backend_tag, 'xpu'), functools.partial(_deserialize, 'xpu'))
+register_package(
+ 23,
+ functools.partial(_backend_tag, "privateuse1"),
+ functools.partial(_deserialize, "privateuse1"),
+)
+register_package(
+ 24, functools.partial(_backend_tag, "hpu"), functools.partial(_deserialize, "hpu")
+)
+register_package(
+ 25, functools.partial(_backend_tag, "xpu"), functools.partial(_deserialize, "xpu")
+)
-def location_tag(storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]):
+
+def location_tag(
+ storage: Union[Storage, torch.storage.TypedStorage, torch.UntypedStorage]
+):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
return location
- raise RuntimeError("don't know how to determine data location of "
- + torch.typename(storage))
+ raise RuntimeError(
+ "don't know how to determine data location of " + torch.typename(storage)
+ )
def default_restore_location(storage, location):
@@ -414,9 +477,13 @@
result = fn(storage, location)
if result is not None:
return result
- raise RuntimeError("don't know how to restore data location of "
- + torch.typename(storage) + " (tagged with "
- + location + ")")
+ raise RuntimeError(
+ "don't know how to restore data location of "
+ + torch.typename(storage)
+ + " (tagged with "
+ + location
+ + ")"
+ )
def normalize_storage_type(storage_type):
@@ -426,7 +493,7 @@
def storage_to_tensor_type(storage):
storage_type = type(storage)
module = _import_dotted_name(storage_type.__module__)
- return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+ return getattr(module, storage_type.__name__.replace("Storage", "Tensor"))
def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]:
@@ -467,9 +534,9 @@
if _is_path(name_or_buffer):
return _open_file(name_or_buffer, mode)
else:
- if 'w' in mode:
+ if "w" in mode:
return _open_buffer_writer(name_or_buffer)
- elif 'r' in mode:
+ elif "r" in mode:
return _open_buffer_reader(name_or_buffer)
else:
raise RuntimeError(f"Expected 'r' or 'w' in mode but got {mode}")
@@ -485,12 +552,12 @@
self.file_stream = None
self.name = str(name)
try:
- self.name.encode('ascii')
+ self.name.encode("ascii")
except UnicodeEncodeError:
# PyTorchFileWriter only supports ascii filename.
# For filenames with non-ascii characters, we rely on Python
# for writing out the file.
- self.file_stream = io.FileIO(self.name, mode='w')
+ self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
else:
super().__init__(torch._C.PyTorchFileWriter(self.name))
@@ -526,7 +593,7 @@
def _is_compressed_file(f) -> bool:
- compress_modules = ['gzip']
+ compress_modules = ["gzip"]
try:
return f.__module__ in compress_modules
except AttributeError:
@@ -550,13 +617,15 @@
def _check_seekable(f) -> bool:
-
def raise_err_msg(patterns, e):
for p in patterns:
if p in str(e):
- msg = (str(e) + ". You can only torch.load from a file that is seekable."
- + " Please pre-load the data into a buffer like io.BytesIO and"
- + " try to load from it instead.")
+ msg = (
+ str(e)
+ + ". You can only torch.load from a file that is seekable."
+ + " Please pre-load the data into a buffer like io.BytesIO and"
+ + " try to load from it instead."
+ )
raise type(e)(msg)
raise e
@@ -569,30 +638,35 @@
def _check_dill_version(pickle_module) -> None:
- '''Checks if using dill as the pickle module, and if so, checks if it is the correct version.
+ """Checks if using dill as the pickle module, and if so, checks if it is the correct version.
If dill version is lower than 0.3.1, a ValueError is raised.
Args:
pickle_module: module used for pickling metadata and objects
- '''
- if pickle_module is not None and pickle_module.__name__ == 'dill':
+ """
+ if pickle_module is not None and pickle_module.__name__ == "dill":
required_dill_version = (0, 3, 1)
- if not check_module_version_greater_or_equal(pickle_module, required_dill_version, False):
- raise ValueError((
- "'torch' supports dill >= {}, but you have dill {}."
- " Please upgrade dill or switch to 'pickle'"
- ).format(
- '.'.join([str(num) for num in required_dill_version]),
- pickle_module.__version__
- ))
+ if not check_module_version_greater_or_equal(
+ pickle_module, required_dill_version, False
+ ):
+ raise ValueError(
+ (
+ "'torch' supports dill >= {}, but you have dill {}."
+ " Please upgrade dill or switch to 'pickle'"
+ ).format(
+ ".".join([str(num) for num in required_dill_version]),
+ pickle_module.__version__,
+ )
+ )
def _check_save_filelike(f):
- if not _is_path(f) and not hasattr(f, 'write'):
+ if not _is_path(f) and not hasattr(f, "write"):
raise AttributeError(
"expected 'f' to be string, path, or a file-like object with "
- "a 'write' attribute")
+ "a 'write' attribute"
+ )
def save(
@@ -601,7 +675,7 @@
pickle_module: Any = pickle,
pickle_protocol: int = DEFAULT_PROTOCOL,
_use_new_zipfile_serialization: bool = True,
- _disable_byteorder_record: bool = False
+ _disable_byteorder_record: bool = False,
) -> None:
# Reference: https://github.com/pytorch/pytorch/issues/54354
# The first line of this docstring overrides the one Sphinx generates for the
@@ -649,15 +723,22 @@
if _use_new_zipfile_serialization:
with _open_zipfile_writer(f) as opened_zipfile:
- _save(obj, opened_zipfile, pickle_module, pickle_protocol, _disable_byteorder_record)
+ _save(
+ obj,
+ opened_zipfile,
+ pickle_module,
+ pickle_protocol,
+ _disable_byteorder_record,
+ )
return
else:
- with _open_file_like(f, 'wb') as opened_file:
+ with _open_file_like(f, "wb") as opened_file:
_legacy_save(obj, opened_file, pickle_module, pickle_protocol)
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
import torch.nn as nn
+
serialized_container_types = {}
serialized_storages = {}
@@ -680,12 +761,16 @@
source_file = source = None
try:
source_lines, _, source_file = get_source_lines_and_file(obj)
- source = ''.join(source_lines)
- except Exception: # saving the source is optional, so we can ignore any errors
- warnings.warn("Couldn't retrieve source code for container of "
- "type " + obj.__name__ + ". It won't be checked "
- "for correctness upon loading.")
- return ('module', obj, source_file, source)
+ source = "".join(source_lines)
+ except (
+ Exception
+ ): # saving the source is optional, so we can ignore any errors
+ warnings.warn(
+ "Couldn't retrieve source code for container of "
+ "type " + obj.__name__ + ". It won't be checked "
+ "for correctness upon loading."
+ )
+ return ("module", obj, source_file, source)
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
storage: torch.UntypedStorage
@@ -707,7 +792,7 @@
dtype = torch.uint8
storage_numel = storage.nbytes()
else:
- raise TypeError(f'type not recognized: {type(obj)}')
+ raise TypeError(f"type not recognized: {type(obj)}")
# If storage is allocated, ensure that any other saved storages
# pointing to the same data all have the same dtype. If storage is
@@ -716,8 +801,9 @@
if storage.data_ptr() in storage_dtypes:
if storage_dtype != storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
- 'Cannot save multiple tensors or storages that '
- 'view the same data as different types')
+ "Cannot save multiple tensors or storages that "
+ "view the same data as different types"
+ )
else:
storage_dtypes[storage.data_ptr()] = storage_dtype
@@ -766,18 +852,20 @@
else:
view_metadata = None
- res = ('storage',
- storage_type,
- storage_key,
- location,
- storage_numel,
- view_metadata)
+ res = (
+ "storage",
+ storage_type,
+ storage_key,
+ location,
+ storage_numel,
+ view_metadata,
+ )
return res
return None
sys_info = dict(
protocol_version=PROTOCOL_VERSION,
- little_endian=sys.byteorder == 'little',
+ little_endian=sys.byteorder == "little",
type_sizes=dict(
short=SHORT_SIZE,
int=INT_SIZE,
@@ -797,7 +885,9 @@
f.flush()
for key in serialized_storage_keys:
storage, dtype = serialized_storages[key]
- storage._write_file(f, _should_read_directly(f), True, torch._utils._element_size(dtype))
+ storage._write_file(
+ f, _should_read_directly(f), True, torch._utils._element_size(dtype)
+ )
def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record):
@@ -817,7 +907,6 @@
# https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
# https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
-
if isinstance(obj, torch.storage.TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
@@ -840,8 +929,9 @@
if storage.data_ptr() in storage_dtypes:
if storage_dtype != storage_dtypes[storage.data_ptr()]:
raise RuntimeError(
- 'Cannot save multiple tensors or storages that '
- 'view the same data as different types')
+ "Cannot save multiple tensors or storages that "
+ "view the same data as different types"
+ )
else:
storage_dtypes[storage.data_ptr()] = storage_dtype
@@ -849,11 +939,7 @@
location = location_tag(storage)
serialized_storages[storage_key] = storage
- return ('storage',
- storage_type,
- storage_key,
- location,
- storage_numel)
+ return ("storage", storage_type, storage_key, location, storage_numel)
return None
@@ -863,23 +949,23 @@
pickler.persistent_id = persistent_id
pickler.dump(obj)
data_value = data_buf.getvalue()
- zip_file.write_record('data.pkl', data_value, len(data_value))
+ zip_file.write_record("data.pkl", data_value, len(data_value))
# Write byte order marker
if not _disable_byteorder_record:
- if sys.byteorder not in ['little', 'big']:
- raise ValueError('Unknown endianness type: ' + sys.byteorder)
+ if sys.byteorder not in ["little", "big"]:
+ raise ValueError("Unknown endianness type: " + sys.byteorder)
- zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder))
+ zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder))
# Write each tensor to a file named tensor/the_tensor_key in the zip archive
for key in sorted(serialized_storages.keys()):
- name = f'data/{key}'
+ name = f"data/{key}"
storage = serialized_storages[key]
# given that we copy things around anyway, we might use storage.cpu()
# this means to that to get tensors serialized, you need to implement
# .cpu() on the underlying Storage
- if storage.device.type != 'cpu':
+ if storage.device.type != "cpu":
storage = storage.cpu()
# Now that it is on the CPU we can directly copy it into the zip file
num_bytes = storage.nbytes()
@@ -893,7 +979,7 @@
*,
weights_only: bool = False,
mmap: Optional[bool] = None,
- **pickle_load_args: Any
+ **pickle_load_args: Any,
) -> Any:
# Reference: https://github.com/pytorch/pytorch/issues/54354
# The first line of this docstring overrides the one Sphinx generates for the
@@ -1002,12 +1088,19 @@
" WeightsUnpickler error: "
)
# Add ability to force safe only weight loads via environment variable
- if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in ['1', 'y', 'yes', 'true']:
+ if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [
+ "1",
+ "y",
+ "yes",
+ "true",
+ ]:
weights_only = True
if weights_only:
if pickle_module is not None:
- raise RuntimeError("Can not safely load weights when explicit pickle_module is specified")
+ raise RuntimeError(
+ "Can not safely load weights when explicit pickle_module is specified"
+ )
else:
if pickle_module is None:
pickle_module = pickle
@@ -1018,10 +1111,10 @@
_check_dill_version(pickle_module)
- if 'encoding' not in pickle_load_args.keys():
- pickle_load_args['encoding'] = 'utf-8'
+ if "encoding" not in pickle_load_args.keys():
+ pickle_load_args["encoding"] = "utf-8"
- with _open_file_like(f, 'rb') as opened_file:
+ with _open_file_like(f, "rb") as opened_file:
if _is_zipfile(opened_file):
# The zipfile reader is going to advance the current file position.
# If we want to actually tail call to torch.jit.load, we need to
@@ -1030,61 +1123,81 @@
overall_storage = None
with _open_zipfile_reader(opened_file) as opened_zipfile:
if _is_torchscript_zip(opened_zipfile):
- warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
- " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
- " silence this warning)", UserWarning)
+ warnings.warn(
+ "'torch.load' received a zip file that looks like a TorchScript archive"
+ " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
+ " silence this warning)",
+ UserWarning,
+ )
opened_file.seek(orig_position)
return torch.jit.load(opened_file, map_location=map_location)
if mmap:
if not _is_path(f):
- raise ValueError("f must be a file path in order to use the mmap argument")
+ raise ValueError(
+ "f must be a file path in order to use the mmap argument"
+ )
size = os.path.getsize(f)
if not IS_WINDOWS:
shared = get_default_mmap_options() == MAP_SHARED
else:
shared = False
- overall_storage = torch.UntypedStorage.from_file(os.fspath(f), shared, size)
+ overall_storage = torch.UntypedStorage.from_file(
+ os.fspath(f), shared, size
+ )
if weights_only:
try:
- return _load(opened_zipfile,
- map_location,
- _weights_only_unpickler,
- overall_storage=overall_storage,
- **pickle_load_args)
+ return _load(
+ opened_zipfile,
+ map_location,
+ _weights_only_unpickler,
+ overall_storage=overall_storage,
+ **pickle_load_args,
+ )
except RuntimeError as e:
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
- return _load(opened_zipfile,
- map_location,
- pickle_module,
- overall_storage=overall_storage,
- **pickle_load_args)
+ return _load(
+ opened_zipfile,
+ map_location,
+ pickle_module,
+ overall_storage=overall_storage,
+ **pickle_load_args,
+ )
if mmap:
f_name = "" if not isinstance(f, str) else f"{f}, "
- raise RuntimeError("mmap can only be used with files saved with "
- f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
- "please torch.save your checkpoint with this option in order to use mmap.")
+ raise RuntimeError(
+ "mmap can only be used with files saved with "
+ f"`torch.save({f_name}_use_new_zipfile_serialization=True), "
+ "please torch.save your checkpoint with this option in order to use mmap."
+ )
if weights_only:
try:
- return _legacy_load(opened_file, map_location, _weights_only_unpickler, **pickle_load_args)
+ return _legacy_load(
+ opened_file,
+ map_location,
+ _weights_only_unpickler,
+ **pickle_load_args,
+ )
except RuntimeError as e:
raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
- return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
+ return _legacy_load(
+ opened_file, map_location, pickle_module, **pickle_load_args
+ )
# Register pickling support for layout instances such as
# torch.sparse_coo, etc
def _get_layout(name):
- """Get layout extension object from its string representation.
- """
- cache = _get_layout.cache # type: ignore[attr-defined]
+ """Get layout extension object from its string representation."""
+ cache = _get_layout.cache # type: ignore[attr-defined]
if not cache:
for v in torch.__dict__.values():
if isinstance(v, torch.layout):
cache[str(v)] = v
return cache[name]
+
# There are yet not good way to type annotate function attributes https://github.com/python/mypy/issues/2087
-_get_layout.cache = {} # type: ignore[attr-defined]
+_get_layout.cache = {} # type: ignore[attr-defined]
copyreg.pickle(torch.layout, lambda obj: (_get_layout, (str(obj),)))
@@ -1094,9 +1207,8 @@
restore_location = _get_restore_location(map_location)
class UnpicklerWrapper(pickle_module.Unpickler): # type: ignore[name-defined]
-
def find_class(self, mod_name, name):
- if type(name) is str and 'Storage' in name:
+ if type(name) is str and "Storage" in name:
try:
return StorageType(name)
except KeyError:
@@ -1105,41 +1217,52 @@
def _check_container_source(container_type, source_file, original_source):
try:
- current_source = ''.join(get_source_lines_and_file(container_type)[0])
+ current_source = "".join(get_source_lines_and_file(container_type)[0])
except Exception: # saving the source is optional, so we can ignore any errors
- warnings.warn("Couldn't retrieve source code for container of "
- "type " + container_type.__name__ + ". It won't be checked "
- "for correctness upon loading.")
+ warnings.warn(
+ "Couldn't retrieve source code for container of "
+ "type " + container_type.__name__ + ". It won't be checked "
+ "for correctness upon loading."
+ )
return
if original_source != current_source:
if container_type.dump_patches:
- file_name = container_type.__name__ + '.patch'
- diff = difflib.unified_diff(current_source.split('\n'),
- original_source.split('\n'),
- source_file,
- source_file, lineterm="")
- lines = '\n'.join(diff)
+ file_name = container_type.__name__ + ".patch"
+ diff = difflib.unified_diff(
+ current_source.split("\n"),
+ original_source.split("\n"),
+ source_file,
+ source_file,
+ lineterm="",
+ )
+ lines = "\n".join(diff)
try:
- with open(file_name, 'a+') as f:
+ with open(file_name, "a+") as f:
file_size = f.seek(0, 2)
f.seek(0)
if file_size == 0:
f.write(lines)
elif file_size != len(lines) or f.read() != lines:
raise OSError
- msg = ("Saved a reverse patch to " + file_name + ". "
- "Run `patch -p0 < " + file_name + "` to revert your "
- "changes.")
+ msg = (
+ "Saved a reverse patch to " + file_name + ". "
+ "Run `patch -p0 < " + file_name + "` to revert your "
+ "changes."
+ )
except OSError:
- msg = ("Tried to save a patch, but couldn't create a "
- "writable file " + file_name + ". Make sure it "
- "doesn't exist and your working directory is "
- "writable.")
+ msg = (
+ "Tried to save a patch, but couldn't create a "
+ "writable file " + file_name + ". Make sure it "
+ "doesn't exist and your working directory is "
+ "writable."
+ )
else:
- msg = ("you can retrieve the original source code by "
- "accessing the object's source attribute or set "
- "`torch.nn.Module.dump_patches = True` and use the "
- "patch tool to revert the changes.")
+ msg = (
+ "you can retrieve the original source code by "
+ "accessing the object's source attribute or set "
+ "`torch.nn.Module.dump_patches = True` and use the "
+ "patch tool to revert the changes."
+ )
msg = f"source code of class '{torch.typename(container_type)}' has changed. {msg}"
warnings.warn(msg, SourceChangeWarning)
@@ -1154,24 +1277,25 @@
return saved_id[0]
return deserialized_objects[int(saved_id)]
- with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
- mkdtemp() as tmpdir:
-
- tar.extract('storages', path=tmpdir)
- with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
+ with closing(
+ tarfile.open(fileobj=f, mode="r:", format=tarfile.PAX_FORMAT)
+ ) as tar, mkdtemp() as tmpdir:
+ tar.extract("storages", path=tmpdir)
+ with open(os.path.join(tmpdir, "storages"), "rb", 0) as f:
num_storages = pickle_module.load(f, **pickle_load_args)
for i in range(num_storages):
args = pickle_module.load(f, **pickle_load_args)
key, location, storage_type = args
dtype = storage_type._dtype
- obj = cast(Storage, torch.UntypedStorage)._new_with_file(f, torch._utils._element_size(dtype))
+ obj = cast(Storage, torch.UntypedStorage)._new_with_file(
+ f, torch._utils._element_size(dtype)
+ )
obj = restore_location(obj, location)
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[key] = torch.storage.TypedStorage(
- wrap_storage=obj,
- dtype=dtype,
- _internal=True)
+ wrap_storage=obj, dtype=dtype, _internal=True
+ )
storage_views = pickle_module.load(f, **pickle_load_args)
for target_cdata, root_cdata, offset, numel in storage_views:
@@ -1181,28 +1305,32 @@
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[target_cdata] = torch.storage.TypedStorage(
- wrap_storage=root._untyped_storage[offset_bytes:offset_bytes + numel * element_size],
+ wrap_storage=root._untyped_storage[
+ offset_bytes : offset_bytes + numel * element_size
+ ],
dtype=root.dtype,
- _internal=True)
+ _internal=True,
+ )
- tar.extract('tensors', path=tmpdir)
- with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
+ tar.extract("tensors", path=tmpdir)
+ with open(os.path.join(tmpdir, "tensors"), "rb", 0) as f:
num_tensors = pickle_module.load(f, **pickle_load_args)
for _ in range(num_tensors):
args = pickle_module.load(f, **pickle_load_args)
key, storage_id, original_tensor_type = args
storage = deserialized_objects[storage_id]
- ndim, = struct.unpack('<i', f.read(4))
+ (ndim,) = struct.unpack("<i", f.read(4))
# skip next 4 bytes; legacy encoding treated ndim as 8 bytes
f.read(4)
- numel = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
- stride = struct.unpack(f'<{ndim}q', f.read(8 * ndim))
- storage_offset, = struct.unpack('<q', f.read(8))
+ numel = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
+ stride = struct.unpack(f"<{ndim}q", f.read(8 * ndim))
+ (storage_offset,) = struct.unpack("<q", f.read(8))
tensor = torch.empty((0,), dtype=storage.dtype).set_(
- storage._untyped_storage, storage_offset, numel, stride)
+ storage._untyped_storage, storage_offset, numel, stride
+ )
deserialized_objects[key] = tensor
- pickle_file = tar.extractfile('pickle')
+ pickle_file = tar.extractfile("pickle")
unpickler = UnpicklerWrapper(pickle_file, **pickle_load_args)
unpickler.persistent_load = persistent_load
result = unpickler.load()
@@ -1215,12 +1343,12 @@
typename = _maybe_decode_ascii(saved_id[0])
data = saved_id[1:]
- if typename == 'module':
+ if typename == "module":
# Ignore containers that don't have any sources saved
if all(data[1:]):
_check_container_source(*data)
return data[0]
- elif typename == 'storage':
+ elif typename == "storage":
storage_type, root_key, location, numel, view_metadata = data
location = _maybe_decode_ascii(location)
dtype = storage_type.dtype
@@ -1229,7 +1357,7 @@
if root_key not in deserialized_objects:
if torch._guards.active_fake_mode() is not None:
- obj = cast(Storage, torch.UntypedStorage(nbytes, device='meta'))
+ obj = cast(Storage, torch.UntypedStorage(nbytes, device="meta"))
else:
obj = cast(Storage, torch.UntypedStorage(nbytes))
obj._torch_load_uninitialized = True
@@ -1237,9 +1365,8 @@
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
typed_storage = torch.storage.TypedStorage(
- wrap_storage=obj,
- dtype=dtype,
- _internal=True)
+ wrap_storage=obj, dtype=dtype, _internal=True
+ )
deserialized_objects[root_key] = typed_storage
else:
typed_storage = deserialized_objects[root_key]
@@ -1247,7 +1374,8 @@
typed_storage = torch.storage.TypedStorage(
device=typed_storage._untyped_storage.device,
dtype=dtype,
- _internal=True)
+ _internal=True,
+ )
if view_metadata is not None:
view_key, offset, view_size = view_metadata
@@ -1257,9 +1385,12 @@
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
deserialized_objects[view_key] = torch.storage.TypedStorage(
- wrap_storage=typed_storage._untyped_storage[offset_bytes:offset_bytes + view_size_bytes],
+ wrap_storage=typed_storage._untyped_storage[
+ offset_bytes : offset_bytes + view_size_bytes
+ ],
dtype=dtype,
- _internal=True)
+ _internal=True,
+ )
res = deserialized_objects[view_key]
else:
@@ -1280,15 +1411,17 @@
if _is_zipfile(f):
# .zip is used for torch.jit.save and will throw an un-pickling error here
raise RuntimeError(
- f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)") from None
+ f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
+ ) from None
# if not a tarfile, reset file offset and proceed
f.seek(0)
- if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
+ if not hasattr(f, "readinto") and (3, 8, 0) <= sys.version_info < (3, 8, 2):
raise RuntimeError(
"torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
f'Received object of type "{type(f)}". Please update to Python 3.8.2 or newer to restore this '
- "functionality.")
+ "functionality."
+ )
magic_number = pickle_module.load(f, **pickle_load_args)
if magic_number != MAGIC_NUMBER:
@@ -1310,8 +1443,11 @@
assert key in deserialized_objects
typed_storage = deserialized_objects[key]
typed_storage._untyped_storage._set_from_file(
- f, offset, f_should_read_directly,
- torch._utils._element_size(typed_storage.dtype))
+ f,
+ offset,
+ f_should_read_directly,
+ torch._utils._element_size(typed_storage.dtype),
+ )
if offset is not None:
offset = f.tell()
@@ -1328,7 +1464,7 @@
# NOTE: This should only be used on internal keys (e.g., `typename` and
# `location` in `persistent_load` below!
if isinstance(bytes_str, bytes):
- return bytes_str.decode('ascii')
+ return bytes_str.decode("ascii")
return bytes_str
@@ -1336,21 +1472,29 @@
if map_location is None:
restore_location = default_restore_location
elif isinstance(map_location, dict):
+
def restore_location(storage, location):
location = map_location.get(location, location)
return default_restore_location(storage, location)
+
elif isinstance(map_location, (str, bytes)):
+
def restore_location(storage, location):
return default_restore_location(storage, map_location)
+
elif isinstance(map_location, torch.device):
+
def restore_location(storage, location):
return default_restore_location(storage, str(map_location))
+
else:
+
def restore_location(storage, location):
result = map_location(storage, location)
if result is None:
result = default_restore_location(storage, location)
return result
+
return restore_location
@@ -1363,53 +1507,70 @@
return self._dtype
def __str__(self):
- return f'StorageType(dtype={self.dtype})'
+ return f"StorageType(dtype={self.dtype})"
-def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall_storage=None, **pickle_load_args):
+def _load(
+ zip_file,
+ map_location,
+ pickle_module,
+ pickle_file="data.pkl",
+ overall_storage=None,
+ **pickle_load_args,
+):
restore_location = _get_restore_location(map_location)
loaded_storages = {}
# check if byteswapping is needed
- byteordername = 'byteorder'
+ byteordername = "byteorder"
byteorderdata = None
if zip_file.has_record(byteordername):
byteorderdata = zip_file.get_record(byteordername)
- if byteorderdata not in [b'little', b'big']:
- raise ValueError('Unknown endianness type: ' + byteorderdata.decode())
- elif get_default_load_endianness() == LoadEndianness.LITTLE or \
- get_default_load_endianness() is None:
- byteorderdata = b'little'
+ if byteorderdata not in [b"little", b"big"]:
+ raise ValueError("Unknown endianness type: " + byteorderdata.decode())
+ elif (
+ get_default_load_endianness() == LoadEndianness.LITTLE
+ or get_default_load_endianness() is None
+ ):
+ byteorderdata = b"little"
elif get_default_load_endianness() == LoadEndianness.BIG:
- byteorderdata = b'big'
+ byteorderdata = b"big"
elif get_default_load_endianness() == LoadEndianness.NATIVE:
pass
else:
- raise ValueError('Invalid load endianness type')
+ raise ValueError("Invalid load endianness type")
- if not zip_file.has_record(byteordername) and \
- get_default_load_endianness() is None and \
- sys.byteorder == 'big':
+ if (
+ not zip_file.has_record(byteordername)
+ and get_default_load_endianness() is None
+ and sys.byteorder == "big"
+ ):
# Default behaviour was changed
# See https://github.com/pytorch/pytorch/issues/101688
- warnings.warn("The default load endianness for checkpoints without a byteorder mark "
- "on big endian machines was changed from 'native' to 'little' endian, "
- "to avoid this behavior please use "
- "torch.serialization.set_default_load_endianness to set "
- "the desired default load endianness",
- UserWarning)
+ warnings.warn(
+ "The default load endianness for checkpoints without a byteorder mark "
+ "on big endian machines was changed from 'native' to 'little' endian, "
+ "to avoid this behavior please use "
+ "torch.serialization.set_default_load_endianness to set "
+ "the desired default load endianness",
+ UserWarning,
+ )
def load_tensor(dtype, numel, key, location):
- name = f'data/{key}'
+ name = f"data/{key}"
if torch._guards.detect_fake_mode(None) is not None:
nbytes = numel * torch._utils._element_size(dtype)
- storage = torch.UntypedStorage(nbytes, device='meta')
+ storage = torch.UntypedStorage(nbytes, device="meta")
elif overall_storage is not None:
storage_offset = zip_file.get_record_offset(name)
- storage = overall_storage[storage_offset:storage_offset + numel]
+ storage = overall_storage[storage_offset : storage_offset + numel]
else:
- storage = zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)._typed_storage()._untyped_storage
+ storage = (
+ zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
+ ._typed_storage()
+ ._untyped_storage
+ )
# swap here if byteswapping is needed
if byteorderdata is not None:
if byteorderdata.decode() != sys.byteorder:
@@ -1420,7 +1581,8 @@
typed_storage = torch.storage.TypedStorage(
wrap_storage=restore_location(storage, location),
dtype=dtype,
- _internal=True)
+ _internal=True,
+ )
if typed_storage._data_ptr() != 0:
loaded_storages[key] = typed_storage
@@ -1432,8 +1594,9 @@
typename = _maybe_decode_ascii(saved_id[0])
data = saved_id[1:]
- assert typename == 'storage', \
- f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
+ assert (
+ typename == "storage"
+ ), f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, numel = data
if storage_type is torch.UntypedStorage:
dtype = torch.uint8
@@ -1444,13 +1607,15 @@
typed_storage = loaded_storages[key]
else:
nbytes = numel * torch._utils._element_size(dtype)
- typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
+ typed_storage = load_tensor(
+ dtype, nbytes, key, _maybe_decode_ascii(location)
+ )
return typed_storage
load_module_mapping: Dict[str, str] = {
# See https://github.com/pytorch/pytorch/pull/51633
- 'torch.tensor': 'torch._tensor'
+ "torch.tensor": "torch._tensor"
}
# Need to subclass Unpickler instead of directly monkey-patching the find_class method
@@ -1461,7 +1626,7 @@
# Lets us override the imports that pickle uses when unpickling an object.
# This is useful for maintaining BC if we change a module path that tensor instantiation relies on.
def find_class(self, mod_name, name):
- if type(name) is str and 'Storage' in name:
+ if type(name) is str and "Storage" in name:
try:
return StorageType(name)
except KeyError:
@@ -1488,4 +1653,4 @@
def _is_torchscript_zip(zip_file):
- return 'constants.pkl' in zip_file.get_all_records()
+ return "constants.pkl" in zip_file.get_all_records()
diff --git a/torch/torch_version.py b/torch/torch_version.py
index e8814cf..1b18f7d 100644
--- a/torch/torch_version.py
+++ b/torch/torch_version.py
@@ -2,8 +2,9 @@
from typing import Any, Iterable
-from ._vendor.packaging.version import InvalidVersion, Version
-from .version import __version__ as internal_version
+from torch._vendor.packaging.version import InvalidVersion, Version
+from torch.version import __version__ as internal_version
+
__all__ = ["TorchVersion"]
diff --git a/torch/types.py b/torch/types.py
index a522d62..67875f8 100644
--- a/torch/types.py
+++ b/torch/types.py
@@ -1,9 +1,14 @@
# mypy: allow-untyped-defs
import builtins
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
+
+if TYPE_CHECKING:
+ from torch.autograd.graph import GradientEdge
+
+
# Convenience aliases for common composite types that we need
# to talk about in PyTorch
@@ -11,8 +16,8 @@
_TensorOrTensorsOrGradEdge = Union[
torch.Tensor,
Sequence[torch.Tensor],
- "torch.autograd.graph.GradientEdge",
- Sequence["torch.autograd.graph.GradientEdge"],
+ "GradientEdge",
+ Sequence["GradientEdge"],
]
# In some cases, these basic types are shadowed by corresponding