Refactor code to share it across different devices (#120602)
# Motivation
Refactor the utils code so it can be shared across CUDA, XPU, and other backends.
# Solution
Move `_dummy_type` and `_LazySeedTracker` to `torch._utils`.
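With both helpers living in `torch._utils`, a backend module can consume them without keeping its own copies. A minimal sketch of the intended pattern (the `mybackend`-style module and the `_MyBackendStreamBase` attribute are hypothetical, used only for illustration):

```python
# Sketch only, not part of this PR: a hypothetical backend module reusing the
# helpers that now live in torch._utils instead of re-defining them locally.
import torch
from torch._utils import _dummy_type, _LazySeedTracker

# Track only the latest manual_seed / manual_seed_all callbacks, as
# torch.cuda and torch.xpu do during lazy initialization.
_lazy_seed_tracker = _LazySeedTracker()

# Fall back to a placeholder class when the native extension is unavailable;
# instantiating it raises "Tried to instantiate dummy base class ...".
if not hasattr(torch._C, "_MyBackendStreamBase"):  # hypothetical C attribute
    _MyBackendStreamBase = _dummy_type("_MyBackendStreamBase")
```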
# Additional Context
When upstreaming, these refactoring changes are isolated into a separate PR to minimize their impact on the CUDA code.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120602
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang
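For context, the ordering semantics of the moved `_LazySeedTracker` (taken from the code in the diff below) can be shown with a small usage sketch; the lambda callbacks stand in for the real seeding closures and are illustrative only:

```python
# Illustration only: how _LazySeedTracker keeps the latest seed callbacks in
# call order, so the most recent call is replayed last during lazy init.
from torch._utils import _LazySeedTracker

tracker = _LazySeedTracker()
tracker.queue_seed_all(lambda: "seed_all", traceback=None)
tracker.queue_seed(lambda: "seed", traceback=None)

# queue_seed was called last, so its callback sits last in the call order.
calls = tracker.get_calls()
assert calls[-1][0]() == "seed"
```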
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 4f4131c..885069b 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -2243,6 +2243,7 @@
"torch._register_device_module",
"torch._running_with_deploy",
"torch._sparse_coo_tensor_unsafe",
+ "torch._utils._dummy_type",
"torch._weights_only_unpickler._get_allowed_globals",
"torch._weights_only_unpickler.load",
"torch.align_tensors",
@@ -2389,7 +2390,6 @@
"torch.cuda._set_stream_by_id",
"torch.cuda._sleep",
"torch.cuda._transform_uuid_to_ordinals",
- "torch.cuda._utils._dummy_type",
"torch.cuda._utils._get_device_index",
"torch.cuda.amp.autocast_mode._cast",
"torch.cuda.amp.autocast_mode.custom_bwd",
diff --git a/torch/_utils.py b/torch/_utils.py
index e93ec10..5976a90 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -891,3 +891,43 @@
f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
)
return device_module
+
+
+def _dummy_type(name: str) -> type:
+ def get_err_fn(is_init: bool):
+ def err_fn(obj, *args, **kwargs):
+ if is_init:
+ class_name = obj.__class__.__name__
+ else:
+ class_name = obj.__name__
+ raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
+
+ return err_fn
+
+ return type(
+ name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
+ )
+
+
+class _LazySeedTracker:
+ # Since seeding is memory-less, only track the latest seed.
+ # Note: `manual_seed_all` followed by `manual_seed` overwrites
+ # the seed on current device. We track the order of **latest**
+ # calls between these two API.
+ def __init__(self):
+ self.manual_seed_all_cb = None
+ self.manual_seed_cb = None
+ self.call_order = []
+
+ def queue_seed_all(self, cb, traceback):
+ self.manual_seed_all_cb = (cb, traceback)
+ # update seed_all to be latest
+ self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
+
+ def queue_seed(self, cb, traceback):
+ self.manual_seed_cb = (cb, traceback)
+ # update seed to be latest
+ self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
+
+ def get_calls(self) -> List:
+ return self.call_order
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 29bc3c2..b042126 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -25,8 +25,8 @@
import torch._C
from torch.types import Device
from .. import device as _device
-from .._utils import classproperty
-from ._utils import _dummy_type, _get_device_index
+from .._utils import _dummy_type, _LazySeedTracker, classproperty
+from ._utils import _get_device_index
from .graphs import (
CUDAGraph,
graph,
@@ -59,31 +59,6 @@
except ImportError as err:
_PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
-
-class _LazySeedTracker:
- # Since seeding is memory-less, only track the latest seed.
- # Note: `manual_seed_all` followed by `manual_seed` overwrites
- # the seed on current device. We track the order of **latest**
- # calls between these two API.
- def __init__(self):
- self.manual_seed_all_cb = None
- self.manual_seed_cb = None
- self.call_order = []
-
- def queue_seed_all(self, cb, traceback):
- self.manual_seed_all_cb = (cb, traceback)
- # update seed_all to be latest
- self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
-
- def queue_seed(self, cb, traceback):
- self.manual_seed_cb = (cb, traceback)
- # update seed to be latest
- self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
-
- def get_calls(self) -> List:
- return self.call_order
-
-
_lazy_seed_tracker = _LazySeedTracker()
# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py
index 1794ca9..1d0ee88 100644
--- a/torch/cuda/_utils.py
+++ b/torch/cuda/_utils.py
@@ -36,19 +36,3 @@
if isinstance(device, torch.cuda.device):
return device.idx
return _torch_get_device_index(device, optional, allow_cpu)
-
-
-def _dummy_type(name: str) -> type:
- def get_err_fn(is_init: bool):
- def err_fn(obj, *args, **kwargs):
- if is_init:
- class_name = obj.__class__.__name__
- else:
- class_name = obj.__name__
- raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
-
- return err_fn
-
- return type(
- name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
- )
diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py
index 563450e..5e98a7a 100644
--- a/torch/cuda/graphs.py
+++ b/torch/cuda/graphs.py
@@ -3,7 +3,7 @@
import torch
from torch.utils import _pytree
-from ._utils import _dummy_type
+from .._utils import _dummy_type
if not hasattr(torch._C, "_CudaStreamBase"):
# Define dummy base classes
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index 55022ae..60440c5 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -14,10 +14,10 @@
from torch import _C
from torch.types import Device
+from .._utils import _dummy_type
from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized
from ._memory_viz import memory as _memory, segments as _segments
-from ._utils import _dummy_type
__all__ = [
"caching_allocator_alloc",
diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py
index 3d41795..22d541f 100644
--- a/torch/cuda/streams.py
+++ b/torch/cuda/streams.py
@@ -2,7 +2,7 @@
import torch
from torch._streambase import _EventBase, _StreamBase
-from ._utils import _dummy_type
+from .._utils import _dummy_type
if not hasattr(torch._C, "_CudaStreamBase"):
diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py
index 8a7969b..a6fb1a0 100644
--- a/torch/xpu/__init__.py
+++ b/torch/xpu/__init__.py
@@ -13,7 +13,8 @@
import torch
import torch._C
from .. import device as _device
-from ._utils import _dummy_type, _get_device_index
+from .._utils import _dummy_type, _LazySeedTracker
+from ._utils import _get_device_index
from .streams import Event, Stream
_initialized = False
@@ -24,29 +25,6 @@
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
-
-
-class _LazySeedTracker:
- # Since seeding is memory-less, only track the latest seed.
- def __init__(self):
- self.manual_seed_all_cb = None
- self.manual_seed_cb = None
- self.call_order = []
-
- def queue_seed_all(self, cb, traceback):
- self.manual_seed_all_cb = (cb, traceback)
- # update seed_all to be latest
- self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
-
- def queue_seed(self, cb, traceback):
- self.manual_seed_cb = (cb, traceback)
- # update seed to be latest
- self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
-
- def get_calls(self) -> List:
- return self.call_order
-
-
_lazy_seed_tracker = _LazySeedTracker()
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
diff --git a/torch/xpu/_utils.py b/torch/xpu/_utils.py
index d34db17..8f73826 100644
--- a/torch/xpu/_utils.py
+++ b/torch/xpu/_utils.py
@@ -37,19 +37,3 @@
if isinstance(device, torch.xpu.device):
return device.idx
return _torch_get_device_index(device, optional, allow_cpu)
-
-
-def _dummy_type(name: str) -> type:
- def get_err_fn(is_init: bool):
- def err_fn(obj, *args, **kwargs):
- if is_init:
- class_name = obj.__class__.__name__
- else:
- class_name = obj.__name__
- raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
-
- return err_fn
-
- return type(
- name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
- )
diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py
index 4fa639d..2c3c3a6 100644
--- a/torch/xpu/streams.py
+++ b/torch/xpu/streams.py
@@ -2,7 +2,7 @@
import torch
from torch._streambase import _EventBase, _StreamBase
-from ._utils import _dummy_type
+from .._utils import _dummy_type
if not hasattr(torch._C, "_XpuStreamBase"):