| # @lint-ignore-every PYTHON3COMPATIMPORTS | 
 |  | 
 | r""" | 
The torch package contains data structures for multi-dimensional
tensors and defines mathematical operations over these tensors.
 | Additionally, it provides many utilities for efficient serializing of | 
 | Tensors and arbitrary types, and other useful utilities. | 
 |  | 
 | It has a CUDA counterpart, that enables you to run your tensor computations | 
 | on an NVIDIA GPU with compute capability >= 3.0. | 
 | """ | 
 |  | 
 | import os | 
 | import sys | 
 | import platform | 
 | from ._utils import _import_dotted_name | 
 | from ._utils_internal import get_file_path, prepare_multiprocessing_environment | 
 | from .version import __version__  # noqa: F401 | 
 | from ._six import string_classes as _string_classes | 
 |  | 
# Names exported by ``from torch import *``. Only the names defined or imported
# in Python are listed here; the list is extended further down with the public
# names found in the C extension module.
__all__ = [
    'typename', 'is_tensor', 'is_storage', 'set_default_tensor_type',
    'set_rng_state', 'get_rng_state', 'manual_seed', 'initial_seed', 'seed',
    'save', 'load', 'set_printoptions', 'chunk', 'split', 'stack', 'matmul',
    'no_grad', 'enable_grad', 'rand', 'randn',
    'DoubleStorage', 'FloatStorage', 'LongStorage', 'IntStorage',
    'ShortStorage', 'CharStorage', 'ByteStorage', 'BoolStorage',
    'DoubleTensor', 'FloatTensor', 'LongTensor', 'IntTensor',
    'ShortTensor', 'CharTensor', 'ByteTensor', 'BoolTensor', 'Tensor',
]
 |  | 
 | ################################################################################ | 
 | # Load the extension module | 
 | ################################################################################ | 
 |  | 
# Loading the extension with the RTLD_GLOBAL option makes it unnecessary to
# link extension modules against the _C shared object. Their missing THP
# symbols will be automatically filled in by the dynamic loader.
#
# ``os`` is bound to the temporary alias ``_dl_flags`` because the code below
# may rebind that name to another module that provides the RTLD_* constants.
# The alias is deleted again once the loader has been configured.
import os as _dl_flags

# if we have numpy, it *must* be imported before the call to setdlopenflags()
# or there is risk that later c modules will segfault when importing numpy
try:
    import numpy as _np  # noqa: F401
except ImportError:
    # numpy is optional: when it is not installed there is nothing to pre-load
    pass
 |  | 
 | if platform.system() == 'Windows': | 
 |     # first get nvToolsExt PATH | 
 |     def get_nvToolsExt_path(): | 
 |         NVTOOLEXT_HOME = _dl_flags.getenv('NVTOOLSEXT_PATH', 'C:\\Program Files\\NVIDIA Corporation\\NvToolsExt') | 
 |  | 
 |         if _dl_flags.path.exists(NVTOOLEXT_HOME): | 
 |             return _dl_flags.path.join(NVTOOLEXT_HOME, 'bin', 'x64') | 
 |         else: | 
 |             return '' | 
 |  | 
 |     py_dll_path = _dl_flags.path.join(sys.exec_prefix, 'Library', 'bin') | 
 |     th_dll_path = _dl_flags.path.join(_dl_flags.path.dirname(__file__), 'lib') | 
 |  | 
 |     dll_paths = [th_dll_path, py_dll_path, get_nvToolsExt_path(), _dl_flags.environ['PATH']] | 
 |  | 
 |     # then add the path to env | 
 |     _dl_flags.environ['PATH'] = ';'.join(dll_paths) | 
 |  | 
 | else: | 
 |     # first check if the os package has the required flags | 
 |     if not hasattr(_dl_flags, 'RTLD_GLOBAL') or not hasattr(_dl_flags, 'RTLD_LAZY'): | 
 |         try: | 
 |             # next try if DLFCN exists | 
 |             import DLFCN as _dl_flags | 
 |         except ImportError: | 
 |             # as a last attempt, use compile-time constants | 
 |             import torch._dl as _dl_flags | 
 |  | 
 |     old_flags = sys.getdlopenflags() | 
 |     sys.setdlopenflags(_dl_flags.RTLD_GLOBAL | _dl_flags.RTLD_LAZY) | 
 |  | 
 | del _dl_flags | 
 |  | 
# Pull the public names of the C extension into this namespace.
from torch._C import *

# Export every name of _C that does not start with an underscore, leaving out
# the names that end in 'Base'.
__all__ += [name for name in dir(_C)
            if name[0] != '_' and
            not name.endswith('Base')]

# The dlopen flags were only changed on non-Windows platforms; put the saved
# value back now that the extension has been imported.
if platform.system() != 'Windows':
    sys.setdlopenflags(old_flags)
    del old_flags
 |  | 
 | ################################################################################ | 
 | # Define basic utilities | 
 | ################################################################################ | 
 |  | 
 |  | 
 | def typename(o): | 
 |     if isinstance(o, torch.Tensor): | 
 |         return o.type() | 
 |  | 
 |     module = '' | 
 |     class_name = '' | 
 |     if hasattr(o, '__module__') and o.__module__ != 'builtins' \ | 
 |             and o.__module__ != '__builtin__' and o.__module__ is not None: | 
 |         module = o.__module__ + '.' | 
 |  | 
 |     if hasattr(o, '__qualname__'): | 
 |         class_name = o.__qualname__ | 
 |     elif hasattr(o, '__name__'): | 
 |         class_name = o.__name__ | 
 |     else: | 
 |         class_name = o.__class__.__name__ | 
 |  | 
 |     return module + class_name | 
 |  | 
 |  | 
 | def is_tensor(obj): | 
 |     r"""Returns True if `obj` is a PyTorch tensor. | 
 |  | 
 |     Args: | 
 |         obj (Object): Object to test | 
 |     """ | 
 |     return isinstance(obj, torch.Tensor) | 
 |  | 
 |  | 
 | def is_storage(obj): | 
 |     r"""Returns True if `obj` is a PyTorch storage object. | 
 |  | 
 |     Args: | 
 |         obj (Object): Object to test | 
 |     """ | 
 |     return type(obj) in _storage_classes | 
 |  | 
 |  | 
 | def set_default_tensor_type(t): | 
 |     r"""Sets the default ``torch.Tensor`` type to floating point tensor type | 
 |     ``t``. This type will also be used as default floating point type for | 
 |     type inference in :func:`torch.tensor`. | 
 |  | 
 |     The default floating point tensor type is initially ``torch.FloatTensor``. | 
 |  | 
 |     Args: | 
 |         t (type or string): the floating point tensor type or its name | 
 |  | 
 |     Example:: | 
 |  | 
 |         >>> torch.tensor([1.2, 3]).dtype    # initial default for floating point is torch.float32 | 
 |         torch.float32 | 
 |         >>> torch.set_default_tensor_type(torch.DoubleTensor) | 
 |         >>> torch.tensor([1.2, 3]).dtype    # a new floating point tensor | 
 |         torch.float64 | 
 |  | 
 |     """ | 
 |     if isinstance(t, _string_classes): | 
 |         t = _import_dotted_name(t) | 
 |     _C._set_default_tensor_type(t) | 
 |  | 
 |  | 
def set_default_dtype(d):
    r"""Sets the default floating point dtype to :attr:`d`. This type will be
    used as default floating point type for type inference in
    :func:`torch.tensor`.

    The default floating point dtype is initially ``torch.float32``.

    Args:
        d (:class:`torch.dtype`): the floating point dtype to make the default

    Example::

        >>> torch.tensor([1.2, 3]).dtype           # initial default for floating point is torch.float32
        torch.float32
        >>> torch.set_default_dtype(torch.float64)
        >>> torch.tensor([1.2, 3]).dtype           # a new floating point tensor
        torch.float64

    """
    # `d` is handed to the C extension unchanged; no checking is done here.
    _C._set_default_dtype(d)
 |  | 
 | # If you edit these imports, please update torch/__init__.py.in as well | 
 | from .random import set_rng_state, get_rng_state, manual_seed, initial_seed, seed | 
 | from .serialization import save, load | 
 | from ._tensor_str import set_printoptions | 
 |  | 
 | ################################################################################ | 
 | # Define Storage and Tensor classes | 
 | ################################################################################ | 
 |  | 
 | from .tensor import Tensor | 
 | from .storage import _StorageBase | 
 |  | 
 |  | 
# Storage classes. Each one derives from the matching ``<Name>StorageBase``
# class exported by the C extension and from ``_StorageBase`` (imported from
# ``.storage``). The class bodies are empty, so everything is inherited from
# those two bases.
class DoubleStorage(_C.DoubleStorageBase, _StorageBase):
    pass


class FloatStorage(_C.FloatStorageBase, _StorageBase):
    pass


class HalfStorage(_C.HalfStorageBase, _StorageBase):
    pass


class LongStorage(_C.LongStorageBase, _StorageBase):
    pass


class IntStorage(_C.IntStorageBase, _StorageBase):
    pass


class ShortStorage(_C.ShortStorageBase, _StorageBase):
    pass


class CharStorage(_C.CharStorageBase, _StorageBase):
    pass


class ByteStorage(_C.ByteStorageBase, _StorageBase):
    pass


class BoolStorage(_C.BoolStorageBase, _StorageBase):
    pass


class BFloat16Storage(_C.BFloat16StorageBase, _StorageBase):
    pass


# presumably the storages backing quantized tensors - TODO confirm
class QUInt8Storage(_C.QUInt8StorageBase, _StorageBase):
    pass

class QInt8Storage(_C.QInt8StorageBase, _StorageBase):
    pass

class QInt32Storage(_C.QInt32StorageBase, _StorageBase):
    pass
 |  | 
 |  | 
# Set of all storage classes of this module. is_storage() tests membership in
# it, and further down it is handed to _C._init_names() as a list.
_storage_classes = {
    DoubleStorage, FloatStorage, LongStorage, IntStorage, ShortStorage,
    CharStorage, ByteStorage, HalfStorage, BoolStorage, QUInt8Storage, QInt8Storage,
    QInt32Storage, BFloat16Storage
}

# The _tensor_classes set is initialized by the call to _C._initialize_tensor_type_bindings()
# (it starts out empty here; this file does not add to it itself)
_tensor_classes = set()
 |  | 
 |  | 
 | ################################################################################ | 
 | # Initialize extension | 
 | ################################################################################ | 
 |  | 
 | def manager_path(): | 
 |     if platform.system() == 'Windows': | 
 |         return b"" | 
 |     path = get_file_path('torch', 'bin', 'torch_shm_manager') | 
 |     prepare_multiprocessing_environment(get_file_path('torch')) | 
 |     if not os.path.exists(path): | 
 |         raise RuntimeError("Unable to find torch_shm_manager at " + path) | 
 |     return path.encode('utf-8') | 
 |  | 
 |  | 
# Shared memory manager needs to know the exact location of manager executable
_C._initExtension(manager_path())
# manager_path() was only needed for the call above; remove it from the namespace
del manager_path

# Copy every attribute of _C._VariableFunctions whose name does not start with
# a double underscore into this module's globals.
# NOTE(review): nothing here deletes the loop variable, so `name` stays bound
# at module level after the loop.
for name in dir(_C._VariableFunctions):
    if name.startswith('__'):
        continue
    globals()[name] = getattr(_C._VariableFunctions, name)
 |  | 
 | ################################################################################ | 
 | # Import interface functions defined in Python | 
 | ################################################################################ | 
 |  | 
 | # needs to be after the above ATen bindings so we can overwrite from Python side | 
 | from .functional import * | 
 |  | 
 |  | 
 | ################################################################################ | 
 | # Remove unnecessary members | 
 | ################################################################################ | 
 |  | 
# Drop the bare *StorageBase names from this namespace (presumably they got
# here through the star-import of torch._C). The storage classes refer to
# their bases as ``_C.<Name>StorageBase``, not through these names.
# NOTE(review): HalfStorageBase, QInt8StorageBase and QInt32StorageBase are
# used as base classes in this file too but are not deleted here - confirm
# whether that is intentional.
del DoubleStorageBase
del FloatStorageBase
del LongStorageBase
del IntStorageBase
del ShortStorageBase
del CharStorageBase
del ByteStorageBase
del BoolStorageBase
del QUInt8StorageBase
del BFloat16StorageBase
 |  | 
 | ################################################################################ | 
 | # Import most common subpackages | 
 | ################################################################################ | 
 |  | 
 | import torch.cuda | 
 | import torch.autograd | 
 | from torch.autograd import no_grad, enable_grad, set_grad_enabled  # noqa: F401 | 
 | import torch.nn | 
 | import torch.nn._intrinsic | 
 | import torch.nn.quantized | 
 | import torch.optim | 
 | import torch.multiprocessing | 
 | import torch.sparse | 
 | import torch.utils.backcompat | 
 | import torch.onnx | 
 | import torch.jit | 
 | import torch.hub | 
 | import torch.random | 
 | import torch.distributions | 
 | import torch.testing | 
 | import torch.backends.cuda | 
 | import torch.backends.mkl | 
 | import torch.backends.openmp | 
 | import torch.backends.quantized | 
 | import torch.quantization | 
 | import torch.utils.data | 
 | import torch.__config__ | 
 | import torch.__future__ | 
 |  | 
# Hand the storage classes to the C extension as a list.
# NOTE(review): presumably this lets the extension set up the names of these
# classes - confirm against the implementation of _init_names.
_C._init_names(list(torch._storage_classes))
 |  | 
 | # attach docstrings to torch and tensor functions | 
 | from . import _torch_docs, _tensor_docs, _storage_docs | 
 | del _torch_docs, _tensor_docs, _storage_docs | 
 |  | 
 |  | 
def compiled_with_cxx11_abi():
    r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1"""
    # The value is read from the flag of the same name on the C extension module.
    return _C._GLIBCXX_USE_CXX11_ABI
 |  | 
 |  | 
 | # Import the ops "namespace" | 
 | from torch._ops import ops  # noqa: F401 | 
 | from torch._classes import classes  # noqa: F401 | 
 |  | 
 | # Import the quasi random sampler | 
 | import torch.quasirandom |