[BE] update type annotations for basic utilities in `torch/__init__.py` (#129001)
Changes:
1. Make some arguments positional-only, now that the minimum supported Python version is 3.8 (which introduced the `/` positional-only marker).
2. Clean up the `torch.typename(obj)` implementation.
3. Update type annotations; in particular, annotate `is_tensor()` and `is_masked_tensor()` with `TypeGuard` (see the sketch below).
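As an illustrative sketch (not part of the diff below): with the `TypeGuard` annotation, static checkers narrow the argument type after an `is_tensor()` check, and the new positional-only markers reject keyword calls.

```python
import torch

def first_element(x: object) -> torch.Tensor:
    if torch.is_tensor(x):
        # With the new TypeGuard annotation, mypy narrows `x` to torch.Tensor
        # in this branch, so indexing type-checks without a cast.
        return x[0]
    raise TypeError(f"expected a tensor, got {torch.typename(x)}")

first_element(torch.arange(3))  # works as before
# torch.typename(obj=3.14)      # TypeError: `obj` is now positional-only
```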
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129001
Approved by: https://github.com/malfet
diff --git a/.lintrunner.toml b/.lintrunner.toml
index eedbf32..1498618 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -1188,9 +1188,6 @@
'torch/_export/serde/upgrade.py',
'torch/_export/trace.py',
'torch/_export/verifier.py',
- 'torch/_higher_order_ops/__init__.py',
- 'torch/_higher_order_ops/out_dtype.py',
- 'torch/_higher_order_ops/wrap.py',
'torch/_vendor/**',
'torch/ao/__init__.py',
'torch/ao/nn/__init__.py',
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index e5b36c4..bc04b86 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1036,7 +1036,7 @@
make_propagate_real_tensors_cls(FakeTensorConstHandling)
-def contains_type(type: torch._C.Type, maybe_contained_type: torch._C.Type):
+def contains_type(type: torch.Type, maybe_contained_type: torch.Type):
return maybe_contained_type.isSubtypeOf(type) or any(
contains_type(e, maybe_contained_type) for e in type.containedTypes()
)
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 65a5bf9..45d80a7 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -1,19 +1,26 @@
# Owner(s): ["module: autograd"]
-from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON, IS_WINDOWS, IS_MACOS, skipIfTorchDynamo
-from torch._utils_internal import get_file_path_2
-
-import pkgutil
-import torch
import importlib
-from typing import Callable
import inspect
import json
import os
+import pkgutil
import unittest
-from importlib import import_module
from itertools import chain
from pathlib import Path
+from typing import Callable
+
+import torch
+from torch._utils_internal import get_file_path_2
+from torch.testing._internal.common_utils import (
+ IS_JETSON,
+ IS_MACOS,
+ IS_WINDOWS,
+ run_tests,
+ skipIfTorchDynamo,
+ TestCase,
+)
+
def _find_all_importables(pkg):
"""Find all importables in the project.
@@ -56,6 +63,19 @@
class TestPublicBindings(TestCase):
+ def test_no_new_reexport_callables(self):
+ """
+ This test aims to stop the introduction of new re-exported callables into
+ torch whose names do not start with _. Such callables are made available as
+ torch.XXX, which may not be desirable.
+ """
+ reexported_callables = sorted(
+ k
+ for k, v in vars(torch).items()
+ if callable(v) and not v.__module__.startswith('torch')
+ )
+ self.assertTrue(all(k.startswith('_') for k in reexported_callables), reexported_callables)
+
def test_no_new_bindings(self):
"""
This test aims to stop the introduction of new JIT bindings into torch._C
@@ -278,7 +298,6 @@
return False
return True
-
@unittest.skipIf(IS_WINDOWS or IS_MACOS, "Inductor/Distributed modules hard fail on windows and macos")
@skipIfTorchDynamo("Broken and not relevant for now")
def test_modules_can_be_imported(self):
@@ -289,7 +308,7 @@
# which calls sys.exit() when we try to import it
if "__main__" in modname:
continue
- import_module(modname)
+ importlib.import_module(modname)
except Exception as e:
# Some current failures are not ImportError
failures.append((modname, type(e)))
diff --git a/tools/pyi/gen_pyi.py b/tools/pyi/gen_pyi.py
index 33456b3..0604fce 100644
--- a/tools/pyi/gen_pyi.py
+++ b/tools/pyi/gen_pyi.py
@@ -1086,12 +1086,12 @@
"def __init__(self, other: Tensor) -> None: ...",
f"def __init__(self, size: _size, *, {DEVICE_PARAM}) -> None: ...",
],
- "as_subclass": ["def as_subclass(self, cls: Type[S]) -> S: ..."],
+ "as_subclass": ["def as_subclass(self, cls: _Type[S]) -> S: ..."],
"_make_subclass": [
"@staticmethod \ndef _make_subclass({}) -> S: ...".format(
", ".join(
[
- "cls: Type[S]",
+ "cls: _Type[S]",
"data: Tensor",
"require_grad: _bool = False",
"dispatch_strides: _bool = False",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index a4c1b1a..721849b 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -24,7 +24,7 @@
Set,
SupportsIndex,
Tuple,
- Type,
+ Type as _Type,
TypeVar,
Union,
overload,
@@ -35,12 +35,17 @@
import numpy
import torch
-from torch import inf, SymInt, Tensor
+from torch import SymInt, Tensor, inf
+from torch._prims_common import DeviceLikeType
from torch.autograd.graph import Node as _Node
from torch.package import PackageExporter
-from torch.storage import UntypedStorage, TypedStorage
+from torch.storage import TypedStorage, UntypedStorage
from torch.types import (
+ Device,
+ Number,
+ Storage,
_bool,
+ _bytes,
_complex,
_device,
_dispatchkey,
@@ -50,17 +55,23 @@
_layout,
_qscheme,
_size,
- Device,
- Number,
- Storage,
+ _str,
)
-
-from torch._prims_common import DeviceLikeType
from torch.utils._python_dispatch import TorchDispatchMode
-# This module is defined in torch/csrc/Module.cpp
+from . import (
+ _aoti,
+ _cpu,
+ _functorch,
+ _lazy,
+ _lazy_ts_backend,
+ _nn,
+ _onnx,
+ _VariableFunctions,
+ _verbose,
+)
-from . import _functorch, _lazy, _lazy_ts_backend, _nn, _onnx, _VariableFunctions, _cpu, _aoti, _verbose
+# This module is defined in torch/csrc/Module.cpp
K = TypeVar("K")
T = TypeVar("T")
@@ -1105,7 +1116,7 @@
def _initExtension(shm_manager_path: str) -> None: ... # THPModule_initExtension
def _autograd_init() -> _bool: ... # THPAutograd_initExtension
def _add_docstr(obj: T, doc_obj: str) -> T: ... # THPModule_addDocStr
-def _init_names(arg: Sequence[Type]) -> None: ... # THPModule_initNames
+def _init_names(arg: Sequence[_Type]) -> None: ... # THPModule_initNames
def _has_distributed() -> _bool: ... # THPModule_hasDistributed
def _set_default_tensor_type(type) -> None: ... # THPModule_setDefaultTensorType
def _set_default_dtype(d: _dtype) -> None: ... # THPModule_setDefaultDtype
@@ -1235,13 +1246,13 @@
def _demangle(str) -> str: ... # c10::demangle
def _disabled_torch_function_impl(
func: Callable,
- types: Iterable[Type],
+ types: Iterable[_Type],
args: Tuple,
kwargs: Dict,
) -> Any: ... # THPModule_disable_torch_function
def _disabled_torch_dispatch_impl(
func: Callable,
- types: Iterable[Type],
+ types: Iterable[_Type],
args: Tuple,
kwargs: Dict,
) -> Any: ... # THPModule_disable_dispatch_function
@@ -1455,7 +1466,7 @@
class Generator:
device: _device
def __init__(self, device: Optional[DeviceLikeType] = None) -> None: ...
- def __reduce__(self) -> Tuple[Type[Generator], Tuple[_device], Tuple[_int, Optional[_int], Tensor]]: ...
+ def __reduce__(self) -> Tuple[_Type[Generator], Tuple[_device], Tuple[_int, Optional[_int], Tensor]]: ...
def __setstate__(self, state: Tuple[_int, Optional[_int], Tensor]) -> None: ...
def get_state(self) -> Tensor: ...
def set_state(self, _new_state: Tensor) -> Generator: ...
@@ -2146,6 +2157,24 @@
R = TypeVar("R", bound=JitType)
+class Type(JitType):
+ def str(self) -> _str: ...
+ def containedTypes(self) -> List[JitType]: ...
+ def dim(self) -> Optional[_int]: ...
+ def undefined(self) -> Optional[_bool]: ...
+ def sizes(self) -> Optional[List[_int]]: ...
+ def symbol_sizes(self) -> Optional[List[_int]]: ...
+ def varyingSizes(self) -> Optional[List[Optional[_int]]]: ...
+ def strides(self) -> Optional[List[_int]]: ...
+ def contiguous(self) -> Self: ...
+ def device(self) -> Optional[_device]: ...
+ def __eq__(self, other: object) -> _bool: ...
+ __hash__ = None # type: ignore[assignment]
+ def is_interface_type(self) -> _bool: ...
+ def requires_grad(self) -> _bool: ...
+ @property
+ def annotation_string(self) -> _str: ...
+
class AnyType(JitType):
@staticmethod
def get() -> AnyType: ...
diff --git a/torch/__init__.py b/torch/__init__.py
index 8042862..0633471 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -14,7 +14,6 @@
import ctypes
import glob
import importlib
-import importlib.util
import inspect
import math
import os
@@ -22,13 +21,24 @@
import sys
import textwrap
import threading
-from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TYPE_CHECKING, Union
+from typing import (
+ Any as _Any,
+ Callable as _Callable,
+ Dict as _Dict,
+ Optional as _Optional,
+ Set as _Set,
+ Tuple as _Tuple,
+ Type as _Type,
+ TYPE_CHECKING,
+ Union as _Union,
+)
+from typing_extensions import TypeGuard as _TypeGuard
# multipy/deploy is setting this import before importing torch, this is the most
# reliable way we have to detect if we're running within deploy.
# https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137
-def _running_with_deploy():
+def _running_with_deploy() -> builtins.bool:
return sys.modules.get("torch._meta_registrations", None) is object
@@ -131,7 +141,7 @@
if sys.platform == "win32":
- def _load_dll_libraries():
+ def _load_dll_libraries() -> None:
import sysconfig
from torch.version import cuda as cuda_version
@@ -246,7 +256,7 @@
del _load_dll_libraries
-def _preload_cuda_deps(lib_folder, lib_name):
+def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None:
"""Preloads cuda deps if they could not be found otherwise."""
# Should only be called on Linux if default path resolution have failed
assert platform.system() == "Linux", "Should only be called on Linux"
@@ -284,7 +294,7 @@
except OSError as err:
# Can only happen for wheel with cuda libs as PYPI deps
# As PyTorch is not purelib, but nvidia-*-cu12 is
- cuda_libs: Dict[str, str] = {
+ cuda_libs: _Dict[str, str] = {
"cublas": "libcublas.so.*[0-9]",
"cudnn": "libcudnn.so.*[0-9]",
"cuda_nvrtc": "libnvrtc.so.*[0-9]",
@@ -391,14 +401,14 @@
def __floordiv__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
- return torch.sym_float(math.floor(sym_float(self) / other))
+ return sym_float(math.floor(sym_float(self) / other))
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
return self.__int_floordiv__(other)
def __rfloordiv__(self, other):
if isinstance(other, (builtins.float, SymFloat)):
- return torch.sym_float(math.floor(other / sym_float(self)))
+ return sym_float(math.floor(other / sym_float(self)))
if not isinstance(other, (builtins.int, SymInt)):
return NotImplemented
return self.__rint_floordiv__(other)
@@ -528,12 +538,12 @@
def __floordiv__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
- return torch.sym_float(math.floor(self / sym_float(other)))
+ return sym_float(math.floor(self / sym_float(other)))
def __rfloordiv__(self, other):
if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)):
return NotImplemented
- return torch.sym_float(math.floor(sym_float(other) / self))
+ return sym_float(math.floor(sym_float(other) / self))
def __bool__(self):
return self.node.bool_()
@@ -858,7 +868,7 @@
__name, __candidate = "", None
for __name in dir(_C):
__candidate = getattr(_C, __name)
- if type(__candidate) is type(_C):
+ if inspect.ismodule(__candidate):
# submodule
sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate)
@@ -870,44 +880,42 @@
################################################################################
-def typename(o):
+def typename(obj: _Any, /) -> str:
"""
String representation of the type of an object.
This function returns a fully qualified string representation of an object's type.
Args:
- o (Object): The object whose type to represent
+ obj (object): The object whose type to represent
Returns:
str: the type of the object `o`
Example:
- >>> x = torch.tensor([1,2,3])
+ >>> x = torch.tensor([1, 2, 3])
>>> torch.typename(x)
'torch.LongTensor'
+ >>> torch.typename(torch.nn.Parameter)
+ 'torch.nn.parameter.Parameter'
"""
- if isinstance(o, torch.Tensor):
- return o.type()
+ if isinstance(obj, torch.Tensor):
+ return obj.type()
- module = ""
- class_name = ""
- if (
- hasattr(o, "__module__")
- and o.__module__ != "builtins"
- and o.__module__ != "__builtin__"
- and o.__module__ is not None
- ):
- module = o.__module__ + "."
+ module = getattr(obj, "__module__", "") or ""
+ qualname = ""
- if hasattr(o, "__qualname__"):
- class_name = o.__qualname__
- elif hasattr(o, "__name__"):
- class_name = o.__name__
+ if hasattr(obj, "__qualname__"):
+ qualname = obj.__qualname__
+ elif hasattr(obj, "__name__"):
+ qualname = obj.__name__
else:
- class_name = o.__class__.__name__
+ module = obj.__class__.__module__ or ""
+ qualname = obj.__class__.__qualname__
- return module + class_name
+ if module in {"", "builtins"}:
+ return qualname
+ return f"{module}.{qualname}"
-def is_tensor(obj):
+def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]:
r"""Returns True if `obj` is a PyTorch tensor.
Note that this function is simply doing ``isinstance(obj, Tensor)``.
@@ -916,7 +924,7 @@
``is_tensor``.
Args:
- obj (Object): Object to test
+ obj (object): Object to test
Example::
>>> x = torch.tensor([1, 2, 3])
@@ -927,7 +935,7 @@
return isinstance(obj, torch.Tensor)
-def is_storage(obj):
+def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]:
r"""Returns True if `obj` is a PyTorch storage object.
Args:
@@ -942,6 +950,7 @@
def get_default_device() -> "torch.device":
r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
global _GLOBAL_DEVICE_CONTEXT
+
if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
device = _GLOBAL_DEVICE_CONTEXT.device_context.device
if device.index is not None:
@@ -954,7 +963,9 @@
return torch.device("cpu")
-def set_default_device(device):
+def set_default_device(
+ device: _Optional[_Union["torch.device", str, builtins.int]],
+) -> None:
"""Sets the default ``torch.Tensor`` to be allocated on ``device``. This
does not affect factory function calls which are called with an explicit
``device`` argument. Factory calls will be performed as if they
@@ -1016,7 +1027,7 @@
_GLOBAL_DEVICE_CONTEXT.device_context = device_context
-def set_default_tensor_type(t):
+def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None:
r"""
.. warning::
@@ -1047,7 +1058,7 @@
_C._set_default_tensor_type(t)
-def set_default_dtype(d):
+def set_default_dtype(d: "torch.dtype", /) -> None:
r"""
Sets the default floating point dtype to :attr:`d`. Supports floating point dtype
@@ -1257,7 +1268,7 @@
return _C._get_deterministic_algorithms_warn_only()
-def set_deterministic_debug_mode(debug_mode: Union[builtins.int, str]) -> None:
+def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None:
r"""Sets the debug mode for deterministic operations.
.. note:: This is an alternative interface for
@@ -1316,7 +1327,7 @@
return 0
-def get_float32_matmul_precision() -> builtins.str:
+def get_float32_matmul_precision() -> str:
r"""Returns the current value of float32 matrix multiplication precision. Refer to
:func:`torch.set_float32_matmul_precision` documentation for more details.
"""
@@ -1389,7 +1400,7 @@
_C._set_float32_matmul_precision(precision)
-def set_warn_always(b: builtins.bool) -> None:
+def set_warn_always(b: builtins.bool, /) -> None:
r"""When this flag is False (default) then some PyTorch warnings may only
appear once per process. This helps avoid excessive warning information.
Setting it to True causes these warnings to always appear, which may be
@@ -1419,10 +1430,10 @@
def _check_with(
error_type,
- cond: Union[builtins.bool, SymBool],
- message: Callable[[], str],
+ cond: _Union[builtins.bool, SymBool],
+ message: _Callable[[], str],
): # noqa: F811
- if not isinstance(cond, (builtins.bool, torch.SymBool)):
+ if not isinstance(cond, (builtins.bool, SymBool)):
raise TypeError(f"cond must be a bool, but got {type(cond)}")
from torch.fx.experimental.symbolic_shapes import expect_true
@@ -1557,13 +1568,13 @@
def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811
- if not torch.is_tensor(cond):
+ if not is_tensor(cond):
raise TypeError(f"cond must be a tensor, but got {type(cond)}")
if not cond.dtype == torch.bool:
raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}")
- _check_with(error_type, cond._is_all_true().item(), message)
+ _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type]
# C++ equivalent: `TORCH_CHECK_TENSOR_ALL`
@@ -1614,10 +1625,9 @@
UntypedStorage,
)
+
# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage.TypedStorage directly.
-
-
class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
@@ -1805,7 +1815,7 @@
return torch.quint2x4
-_storage_classes = {
+_storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = {
UntypedStorage,
DoubleStorage,
FloatStorage,
@@ -1828,7 +1838,7 @@
}
# The _tensor_classes set is initialized by the call to initialize_python_bindings.
-_tensor_classes: Set[Type] = set()
+_tensor_classes: _Set[_Type["torch.Tensor"]] = set()
# If you edit these imports, please update torch/__init__.py.in as well
from torch import amp as amp, random as random, serialization as serialization
@@ -2067,7 +2077,7 @@
compiler_name = "inductor"
def __init__(self, mode, options, dynamic):
- self.config: Dict[str, Any] = dict()
+ self.config: _Dict[str, _Any] = dict()
self.dynamic = dynamic
self.apply_mode(mode)
self.apply_options(options)
@@ -2091,7 +2101,7 @@
and self.dynamic == other.dynamic
)
- def apply_mode(self, mode: Optional[str]):
+ def apply_mode(self, mode: _Optional[str]):
if mode is None or mode == "default":
pass
elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}:
@@ -2103,13 +2113,13 @@
f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs"
)
- def apply_options(self, options: Optional[Dict[str, Any]]):
+ def apply_options(self, options: _Optional[_Dict[str, _Any]]):
if not options:
return
from torch._inductor import config
- current_config: Dict[str, Any] = config.shallow_copy_dict()
+ current_config: _Dict[str, _Any] = config.shallow_copy_dict()
for key, val in options.items():
attr_name = key.replace("-", "_")
@@ -2181,15 +2191,15 @@
def compile(
- model: Optional[Callable] = None,
+ model: _Optional[_Callable] = None,
*,
fullgraph: builtins.bool = False,
- dynamic: Optional[builtins.bool] = None,
- backend: Union[str, Callable] = "inductor",
- mode: Union[str, None] = None,
- options: Optional[Dict[str, Union[str, builtins.int, builtins.bool]]] = None,
+ dynamic: _Optional[builtins.bool] = None,
+ backend: _Union[str, _Callable] = "inductor",
+ mode: _Union[str, None] = None,
+ options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None,
disable: builtins.bool = False,
-) -> Callable:
+) -> _Callable:
"""
Optimizes given model/function using TorchDynamo and specified backend.
If you are compiling an :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
@@ -2281,7 +2291,7 @@
# Decorator mode
if model is None:
- def fn(model: Callable):
+ def fn(model: _Callable):
if model is None:
raise RuntimeError("Model can't be None")
return compile(
@@ -2315,11 +2325,6 @@
)(model)
-from torch import export as export
-
-from torch._higher_order_ops import cond, while_loop
-
-
def _register_device_module(device_type, module):
r"""Register an external runtime module of the specific :attr:`device_type`
supported by torch.
@@ -2340,8 +2345,14 @@
sys.modules[torch_module_name] = module
-# expose return_types
-from torch import library as library, return_types as return_types
+from torch import (
+ export as export,
+ func as func,
+ library as library,
+ return_types as return_types,
+)
+from torch._higher_order_ops import cond as cond, while_loop as while_loop
+from torch.func import vmap as vmap
if not TYPE_CHECKING:
from torch import _meta_registrations
@@ -2355,10 +2366,6 @@
# Populate magic methods on SymInt and SymFloat
import torch.fx.experimental.sym_node
-from torch import func as func
-from torch.func import vmap as vmap
-
-
# Register MPS specific decomps
torch.backends.mps._init()
@@ -2367,7 +2374,7 @@
class _TritonLibrary:
lib = torch.library.Library("triton", "DEF")
- ops_table: Dict[Tuple[str, str], Callable] = {}
+ ops_table: _Dict[_Tuple[str, str], _Callable] = {}
@classmethod
def registerOp(cls, op_key, full_schema, op_impl, dispatch_key):
@@ -2421,7 +2428,7 @@
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
-def get_device_module(device: Optional[Union[torch.device, str]] = None):
+def get_device_module(device: _Optional[_Union[torch.device, str]] = None):
"""
Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
If no device is given, return the module for the current accelerator or CPU if none is present.
@@ -2447,8 +2454,8 @@
def _constrain_as_size(
symbol,
- min: Optional[builtins.int] = None,
- max: Optional[builtins.int] = None,
+ min: _Optional[builtins.int] = None,
+ max: _Optional[builtins.int] = None,
):
"""
This indicates that a given int is size-like, and can be used in any context where a size is expected.
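For reference, a minimal sketch of the cleaned-up `torch.typename()` behavior introduced in the hunks above (illustrative only, not part of the patch):

```python
import torch

print(torch.typename(torch.tensor([1, 2, 3])))  # torch.LongTensor
print(torch.typename(torch.nn.Parameter))       # torch.nn.parameter.Parameter
print(torch.typename(3.14))                     # float  (the 'builtins' prefix is dropped)
```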
diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py
index 99b3577..5bfca7a 100644
--- a/torch/_higher_order_ops/__init__.py
+++ b/torch/_higher_order_ops/__init__.py
@@ -1,3 +1,14 @@
-from .cond import cond
-from .while_loop import while_loop
-from .flex_attention import flex_attention, flex_attention_backward
+from torch._higher_order_ops.cond import cond
+from torch._higher_order_ops.flex_attention import (
+ flex_attention,
+ flex_attention_backward,
+)
+from torch._higher_order_ops.while_loop import while_loop
+
+
+__all__ = [
+ "cond",
+ "while_loop",
+ "flex_attention",
+ "flex_attention_backward",
+]
diff --git a/torch/_higher_order_ops/associative_scan.py b/torch/_higher_order_ops/associative_scan.py
index 0d88aa0..3f9c6c1 100644
--- a/torch/_higher_order_ops/associative_scan.py
+++ b/torch/_higher_order_ops/associative_scan.py
@@ -4,12 +4,9 @@
from typing import Callable, List
import torch
-
import torch._prims_common as utils
import torch._subclasses.functional_tensor
-
import torch.utils._pytree as pytree
-
from torch._C import DispatchKey
from torch._C._functorch import _add_batch_dim, get_unwrapped, maybe_get_bdim
from torch._higher_order_ops.utils import (
@@ -18,7 +15,6 @@
reenter_make_fx,
unique_graph_id,
)
-
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@@ -27,6 +23,7 @@
track_tensor_tree,
)
+
aten = torch._ops.ops.aten
diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py
index f4fe64d..ff2f061 100644
--- a/torch/_higher_order_ops/cond.py
+++ b/torch/_higher_order_ops/cond.py
@@ -3,9 +3,7 @@
import torch
import torch._subclasses.functional_tensor
-
import torch.utils._pytree as pytree
-
from torch._C import DispatchKey
from torch._C._functorch import (
_add_batch_dim,
@@ -15,7 +13,6 @@
)
from torch._functorch.utils import exposed_in
from torch._guards import detect_fake_mode
-
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
@@ -25,7 +22,6 @@
unique_graph_id,
UnsupportedAliasMutationException,
)
-
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
diff --git a/torch/_higher_order_ops/effects.py b/torch/_higher_order_ops/effects.py
index c5a4488..826932c 100644
--- a/torch/_higher_order_ops/effects.py
+++ b/torch/_higher_order_ops/effects.py
@@ -5,6 +5,7 @@
import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
+from torch._higher_order_ops.torchbind import call_torchbind
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
@@ -12,7 +13,6 @@
ProxyTorchDispatchMode,
track_tensor_tree,
)
-from .torchbind import call_torchbind
class _EffectType(Enum):
diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py
index 38b01f7..f800237 100644
--- a/torch/_higher_order_ops/flex_attention.py
+++ b/torch/_higher_order_ops/flex_attention.py
@@ -18,7 +18,6 @@
track_tensor_tree,
)
from torch.fx.graph_module import GraphModule
-
from torch.overrides import TorchFunctionMode
@@ -288,7 +287,6 @@
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
-
from torch._subclasses.functional_tensor import disable_functional_mode
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py
index f5bf1d4..943740c 100644
--- a/torch/_higher_order_ops/map.py
+++ b/torch/_higher_order_ops/map.py
@@ -4,7 +4,6 @@
from torch._C import DispatchKey
from torch._dispatch.python import suspend_functionalization
from torch._functorch.aot_autograd import AOTConfig, create_joint, from_fun
-
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
diff --git a/torch/_higher_order_ops/out_dtype.py b/torch/_higher_order_ops/out_dtype.py
index a3f5e21..58bf02b 100644
--- a/torch/_higher_order_ops/out_dtype.py
+++ b/torch/_higher_order_ops/out_dtype.py
@@ -2,17 +2,18 @@
import torch
import torch.utils._pytree as pytree
+from torch._C import DispatchKey
+from torch._higher_order_ops.utils import autograd_not_implemented
+from torch._ops import HigherOrderOperator
+from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
+from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
+ maybe_handle_decomp,
ProxyTorchDispatchMode,
track_tensor_tree,
- maybe_handle_decomp,
)
-from torch._C import DispatchKey
-from torch._ops import HigherOrderOperator
-from torch._subclasses.fake_tensor import FakeTensorMode
-from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND
-from torch._higher_order_ops.utils import autograd_not_implemented
+
# TODO to figure out a more generic approach
ALLOWABLE_OPS = [
@@ -43,7 +44,6 @@
3. Cast the output to `out_dtype`
"""
-
def __init__(self):
super().__init__("out_dtype")
# TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
@@ -55,10 +55,12 @@
if not isinstance(op, torch._ops.OpOverload):
raise ValueError("out_dtype's first argument must be an OpOverload")
if op._schema.is_mutable:
- raise ValueError("out_dtype's first argument needs to be a functional operator")
+ raise ValueError(
+ "out_dtype's first argument needs to be a functional operator"
+ )
if not (
- len(op._schema.returns) == 1 and
- isinstance(op._schema.returns[0].type, torch.TensorType)
+ len(op._schema.returns) == 1
+ and isinstance(op._schema.returns[0].type, torch.TensorType)
):
raise ValueError(
"out_dtype's can only apply to ops that return a single tensor"
@@ -77,6 +79,7 @@
out_dtype = OutDtypeOperator()
+
def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
# NB: Long-term we should put the decomposition logic into
# ProxyTorchDispatchMode so that people do not need to call maybe_handle_decomp
@@ -99,11 +102,7 @@
@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd)
-def out_dtype_dense(
- op: torch._ops.OpOverload,
- output_dtype: torch.dtype,
- *args
-):
+def out_dtype_dense(op: torch._ops.OpOverload, output_dtype: torch.dtype, *args):
if is_int_mm(op, output_dtype, args):
return torch._int_mm(*args)
return out_dtype_fallback(op, output_dtype, *args)
@@ -111,13 +110,13 @@
def is_int_mm(op, output_dtype, args):
return (
- op == torch.ops.aten.mm.default and
- output_dtype == torch.int32 and
- len(args) == 2 and
- args[0].dtype == torch.int8 and
- args[1].dtype == torch.int8 and
- args[0].is_cuda and
- args[1].is_cuda
+ op == torch.ops.aten.mm.default
+ and output_dtype == torch.int32
+ and len(args) == 2
+ and args[0].dtype == torch.int8
+ and args[1].dtype == torch.int8
+ and args[0].is_cuda
+ and args[1].is_cuda
)
@@ -135,7 +134,9 @@
return res
-out_dtype.py_impl(DispatchKey.Autograd)(autograd_not_implemented(out_dtype, deferred_error=True))
+out_dtype.py_impl(DispatchKey.Autograd)(
+ autograd_not_implemented(out_dtype, deferred_error=True)
+)
@out_dtype.py_impl(ProxyTorchDispatchMode)
@@ -143,7 +144,7 @@
mode: ProxyTorchDispatchMode,
op: torch._ops.OpOverload,
output_dtype: torch.dtype,
- *args
+ *args,
):
if mode.enable_tracing:
return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
@@ -156,7 +157,7 @@
mode: FakeTensorMode,
op: torch._ops.OpOverload,
output_dtype: torch.dtype,
- *args
+ *args,
):
with mode:
return out_dtype_dense(op, output_dtype, *args)
diff --git a/torch/_higher_order_ops/strict_mode.py b/torch/_higher_order_ops/strict_mode.py
index d781248..1922519 100644
--- a/torch/_higher_order_ops/strict_mode.py
+++ b/torch/_higher_order_ops/strict_mode.py
@@ -1,12 +1,9 @@
# mypy: allow-untyped-defs
import torch
import torch._subclasses.functional_tensor
-
import torch.utils._pytree as pytree
-
from torch._C import DispatchKey
from torch._functorch.utils import exposed_in
-
from torch._higher_order_ops.utils import _set_compilation_env, autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
diff --git a/torch/_higher_order_ops/torchbind.py b/torch/_higher_order_ops/torchbind.py
index e44ba28..8db1b8f 100644
--- a/torch/_higher_order_ops/torchbind.py
+++ b/torch/_higher_order_ops/torchbind.py
@@ -13,6 +13,7 @@
from torch.fx.node import has_side_effect
from torch.utils import _pytree as pytree
+
log = logging.getLogger(__name__)
# The call_torchbind operator represents a method invocation on a torchbind
diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py
index 5552ef1..a335912 100644
--- a/torch/_higher_order_ops/triton_kernel_wrap.py
+++ b/torch/_higher_order_ops/triton_kernel_wrap.py
@@ -18,6 +18,7 @@
track_tensor_tree,
)
+
log = logging.getLogger("torch._dynamo")
diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py
index 4577036..baf5500 100644
--- a/torch/_higher_order_ops/while_loop.py
+++ b/torch/_higher_order_ops/while_loop.py
@@ -3,9 +3,7 @@
import torch
import torch.utils._pytree as pytree
-
from torch._C import DispatchKey
-
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
diff --git a/torch/_higher_order_ops/wrap.py b/torch/_higher_order_ops/wrap.py
index e7fe553..946dc11 100644
--- a/torch/_higher_order_ops/wrap.py
+++ b/torch/_higher_order_ops/wrap.py
@@ -4,15 +4,16 @@
import logging
import torch
+import torch._dynamo.config
from torch._ops import HigherOrderOperator
from torch.utils.checkpoint import checkpoint
-import torch._dynamo.config
log = logging.getLogger(__name__)
uid = itertools.count(1)
+
# Used for testing the HigherOrderOperator mechanism
class Wrap(HigherOrderOperator):
def __init__(self):
@@ -31,8 +32,10 @@
return wrapper()
+
wrap = Wrap()
+
class WrapWithSetGradEnabled(HigherOrderOperator):
def __init__(self):
super().__init__("wrap_with_set_grad_enabled")
@@ -47,10 +50,13 @@
def wrapper():
with torch.set_grad_enabled(enable_grad):
return wrapped_func(*args, **kwargs)
+
return wrapper()
+
wrap_with_set_grad_enabled = WrapWithSetGradEnabled()
+
class WrapActivationCheckpoint(HigherOrderOperator):
"""
This operator is used to wrap torch.utils.checkpoint. This avoids
@@ -68,6 +74,7 @@
that duplication/recomputation is done as a compiler pass in the
partitioners. See TagActivationCheckpoint for more information.
"""
+
def __init__(self):
super().__init__("wrap_activation_checkpoint")
@@ -77,14 +84,17 @@
# version of checkpointing.
import torch.fx.traceback as fx_traceback
from torch.fx import Interpreter
+
kwargs["use_reentrant"] = False
kwargs["preserve_rng_state"] = False
# Using interpreter allows preservation of metadata through torch.compile stack.
with fx_traceback.preserve_node_meta():
return checkpoint(Interpreter(function).run, *args, **kwargs)
+
wrap_activation_checkpoint = WrapActivationCheckpoint()
+
class TagActivationCheckpoint(HigherOrderOperator):
"""
This operator is supposed to be used only with torch.compile stack. This
@@ -136,8 +146,12 @@
# `preserve_rng_state` is not a regular kwarg
checkpoint_keys.add("preserve_rng_state")
- checkpoint_kwargs = {name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys}
- gmod_kwargs = {name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys}
+ checkpoint_kwargs = {
+ name: kwargs[name] for name in kwargs.keys() if name in checkpoint_keys
+ }
+ gmod_kwargs = {
+ name: kwargs[name] for name in kwargs.keys() if name not in checkpoint_keys
+ }
return checkpoint_kwargs, gmod_kwargs
def tag_nodes(self, gmod):
@@ -150,13 +164,17 @@
def __call__(self, gmod, *args, **kwargs):
import torch.fx.traceback as fx_traceback
from torch.fx import Interpreter
+
if "_checkpoint_context_fn" in gmod.meta:
- assert torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint, \
- "Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile"
- log.warning("""
+ assert (
+ torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint
+ ), "Passing context_fn to torch.utils.checkpoint is currently not supported under torch.compile"
+ log.warning(
+ """
Detected that context_fn is passed to torch.utils.checkpoint under torch.compile.
Please make sure the checkpointed region does not contain in-place ops (e.g. torch.relu_).
-""")
+"""
+ )
# use_reentrant is set to False because this op is going to be traced.
# And we ensure that AOT Autograd traces through the non reentrant
# version of checkpointing.
@@ -183,4 +201,5 @@
with fx_traceback.preserve_node_meta():
return Interpreter(gmod).run(*args)
+
tag_activation_checkpoint = TagActivationCheckpoint()
diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py
index 0933a80..69850df 100644
--- a/torch/masked/maskedtensor/core.py
+++ b/torch/masked/maskedtensor/core.py
@@ -2,6 +2,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import warnings
+from typing import Any
+from typing_extensions import TypeGuard
import torch
from torch.overrides import get_default_nowrap_functions
@@ -13,7 +15,7 @@
]
-def is_masked_tensor(a):
+def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]:
r"""Returns True if the input is a MaskedTensor, else False
Args:
@@ -29,7 +31,7 @@
>>> is_masked_tensor(mt)
True
"""
- return isinstance(a, MaskedTensor)
+ return isinstance(obj, MaskedTensor)
def _tensors_match(a, b, exact=True, rtol=1e-05, atol=1e-08):
@@ -147,13 +149,14 @@
if is_masked_tensor(mask) or not torch.is_tensor(mask):
raise TypeError("mask must be a Tensor")
# Use a Tensor that of the give size for the wrapper.
- kwargs = {}
- kwargs["device"] = data.device
- kwargs["dtype"] = data.dtype
- kwargs["layout"] = data.layout
- kwargs["requires_grad"] = requires_grad
- kwargs["dispatch_sizes_strides_policy"] = "strides"
- kwargs["dispatch_layout"] = True
+ kwargs = {
+ "device": data.device,
+ "dtype": data.dtype,
+ "layout": data.layout,
+ "requires_grad": requires_grad,
+ "dispatch_sizes_strides_policy": "strides",
+ "dispatch_layout": True,
+ }
warnings.warn(
(
"The PyTorch API of MaskedTensors is in prototype stage "
@@ -162,12 +165,14 @@
"module for further information about the project."
),
UserWarning,
+ stacklevel=2,
)
if data.requires_grad:
warnings.warn(
"It is not recommended to create a MaskedTensor with a tensor that requires_grad. "
"To avoid this, you can use data.clone().detach()",
UserWarning,
+ stacklevel=2,
)
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
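A minimal sketch of how the new `is_masked_tensor()` `TypeGuard` is meant to be used (illustrative; assumes the prototype `torch.masked` namespace re-exports `masked_tensor` and `is_masked_tensor`, as in the MaskedTensor tutorials):

```python
import torch
from torch.masked import is_masked_tensor, masked_tensor

def dense_data(x: object) -> torch.Tensor:
    if is_masked_tensor(x):
        # The TypeGuard narrows `x` to MaskedTensor here, so the
        # MaskedTensor-specific accessor type-checks cleanly.
        return x.get_data()
    raise TypeError(f"expected a MaskedTensor, got {torch.typename(x)}")

mt = masked_tensor(torch.arange(3.0), torch.tensor([True, False, True]))
print(dense_data(mt))
```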
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index 3fdcbe7..2a4db35 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1598,7 +1598,9 @@
treespec,
output_is_rref,
) = _tree_flatten_with_rref(output)
- output_placeholders = [None for _ in range(len(output_tensor_list))]
+ output_placeholders: List[Optional[torch.Tensor]] = [
+ None for _ in range(len(output_tensor_list))
+ ]
# Do not touch tensors that have no grad_fn, which can cause issues
# such as https://github.com/pytorch/pytorch/issues/60733
for i, output in enumerate(output_tensor_list):
diff --git a/torch/types.py b/torch/types.py
index 67875f8..c15ba04 100644
--- a/torch/types.py
+++ b/torch/types.py
@@ -1,5 +1,17 @@
# mypy: allow-untyped-defs
-import builtins
+
+# In some cases, these basic types are shadowed by corresponding
+# top-level values. The underscore variants let us refer to these
+# types. See https://github.com/python/mypy/issues/4146 for why these
+# workarounds is necessary
+from builtins import ( # noqa: F401
+ bool as _bool,
+ bytes as _bytes,
+ complex as _complex,
+ float as _float,
+ int as _int,
+ str as _str,
+)
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
@@ -20,71 +32,75 @@
Sequence["GradientEdge"],
]
-# In some cases, these basic types are shadowed by corresponding
-# top-level values. The underscore variants let us refer to these
-# types. See https://github.com/python/mypy/issues/4146 for why these
-# workarounds is necessary
-_int = builtins.int
-_float = builtins.float
-_bool = builtins.bool
-_complex = builtins.complex
-
_dtype = torch.dtype
_device = torch.device
_qscheme = torch.qscheme
-_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
_layout = torch.layout
-_dispatchkey = Union[str, torch._C.DispatchKey]
+_size = Union[torch.Size, List[_int], Tuple[_int, ...]]
+_dispatchkey = Union[_str, torch._C.DispatchKey]
# Meta-type for "numeric" things; matches our docs
-Number = Union[builtins.int, builtins.float, builtins.bool]
+Number = Union[_int, _float, _bool]
# Meta-type for "device-like" things. Not to be confused with 'device' (a
# literal device object). This nomenclature is consistent with PythonArgParser.
# None means use the default device (typically CPU)
-Device = Optional[Union[_device, str, _int]]
+Device = Optional[Union[_device, _str, _int]]
del Optional
# Storage protocol implemented by ${Type}StorageBase classes
class Storage:
- _cdata: int
+ _cdata: _int
device: torch.device
dtype: torch.dtype
- _torch_load_uninitialized: bool
+ _torch_load_uninitialized: _bool
- def __deepcopy__(self, memo) -> "Storage": # type: ignore[empty-body]
+ def __deepcopy__(self, memo: dict) -> "Storage": # type: ignore[empty-body]
...
- def _new_shared(self, int) -> "Storage": # type: ignore[empty-body]
+ def _new_shared(self, size: _int) -> "Storage": # type: ignore[empty-body]
...
def _write_file(
- self, f: Any, is_real_file: _bool, save_size: _bool, element_size: int
+ self,
+ f: Any,
+ is_real_file: _bool,
+ save_size: _bool,
+ element_size: _int,
) -> None:
...
- def element_size(self) -> int: # type: ignore[empty-body]
+ def element_size(self) -> _int: # type: ignore[empty-body]
...
- def is_shared(self) -> bool: # type: ignore[empty-body]
+ def is_shared(self) -> _bool: # type: ignore[empty-body]
...
def share_memory_(self) -> "Storage": # type: ignore[empty-body]
...
- def nbytes(self) -> int: # type: ignore[empty-body]
+ def nbytes(self) -> _int: # type: ignore[empty-body]
...
def cpu(self) -> "Storage": # type: ignore[empty-body]
...
- def data_ptr(self) -> int: # type: ignore[empty-body]
+ def data_ptr(self) -> _int: # type: ignore[empty-body]
...
- def from_file(self, filename: str, shared: bool = False, nbytes: int = 0) -> "Storage": # type: ignore[empty-body]
+ def from_file( # type: ignore[empty-body]
+ self,
+ filename: _str,
+ shared: _bool = False,
+ nbytes: _int = 0,
+ ) -> "Storage":
...
- def _new_with_file(self, f: Any, element_size: int) -> "Storage": # type: ignore[empty-body]
+ def _new_with_file( # type: ignore[empty-body]
+ self,
+ f: Any,
+ element_size: _int,
+ ) -> "Storage":
...
diff --git a/torchgen/fuse/gen_patterns.py b/torchgen/fuse/gen_patterns.py
index 18562f5..0861c88 100644
--- a/torchgen/fuse/gen_patterns.py
+++ b/torchgen/fuse/gen_patterns.py
@@ -7,12 +7,11 @@
if __name__ == "__main__":
# Start by deleting all the existing patterns.
- for file in os.listdir(pattern_matcher.SERIALIZED_PATTERN_PATH):
- if file in ("__init__.py", "__pycache__"):
+ for path in pattern_matcher.SERIALIZED_PATTERN_PATH.iterdir():
+ if path.name in {"__init__.py", "__pycache__"}:
continue
- file = pattern_matcher.SERIALIZED_PATTERN_PATH / file
- if file.is_file():
- file.unlink()
+ if path.is_file():
+ path.unlink()
# Now have joint_graph load all known patterns and tell the pattern matcher
# to serialize the patterns as it goes.