Fix DeviceContext bug (#133729)
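DeviceContext now inserts itself at the bottom of the torch function mode stack on enter, and removes itself from the bottom on exit, so that exiting an enclosing function mode no longer pops a DeviceContext that was installed inside it via torch.set_default_device.
Fixes https://github.com/pytorch/pytorch/issues/133666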
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133729
Approved by: https://github.com/bdhirsh
ghstack dependencies: #133130
diff --git a/test/test_overrides.py b/test/test_overrides.py
index de3d524..22c04ed 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -22,6 +22,7 @@
TorchFunctionMode,
_get_current_function_mode,
_get_current_function_mode_stack,
+ BaseTorchFunctionMode
)
from torch.utils._mode_utils import all_same_mode
from torch.utils._pytree import tree_map
@@ -1600,6 +1601,29 @@
self.assertEqual(d_kwargs.type, "xla")
self.assertEqual(d_kwargs.index, 0)
+ def test_device_context_semantics(self):
+ from torch._C import _len_torch_function_stack
+ from torch.utils._device import DeviceContext
+ torch.set_default_device("cuda")
+
+ def get_stack():
+ return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())]
+
+ base_mode = BaseTorchFunctionMode()
+ with base_mode:
+ torch.set_default_device("cpu")
+ x = torch.ones(2, 2)
+ stack = get_stack()
+ self.assertIsInstance(stack[0], DeviceContext)
+ self.assertEqual(stack[0].device, torch.device("cpu"))
+
+ stack = get_stack()
+ self.assertIsInstance(stack[0], DeviceContext)
+ self.assertEqual(stack[0].device, torch.device("cpu"))
+
+
+
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 42ba358..3a1cef7 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -75,6 +75,11 @@
from _pytest.python_api import RaisesContext
from _pytest.recwarn import WarningsChecker
+ # TODO(mlazos): temporary workaround to get this PR stack to pass;
+ # remove in a subsequent PR.
+ from torch.overrides import BaseTorchFunctionMode
+
+ f_ctxs.append(BaseTorchFunctionMode)
f_ctxs.append(RaisesContext)
f_ctxs.append(WarningsChecker)
except ImportError:
diff --git a/torch/utils/_device.py b/torch/utils/_device.py
index c852cd3..d9e8852 100644
--- a/torch/utils/_device.py
+++ b/torch/utils/_device.py
@@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
-from torch.overrides import TorchFunctionMode
+from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
from torch.utils._contextlib import context_decorator
+from torch._C import _len_torch_function_stack
import functools
CURRENT_DEVICE: Optional[torch.device] = None
@@ -65,12 +66,38 @@
global CURRENT_DEVICE
self.old_device = CURRENT_DEVICE
CURRENT_DEVICE = self.device
- return super().__enter__()
+ # We need to put the device at the bottom of the stack.
+ # If we set the default device within a function mode context,
+ # exiting that context would otherwise incorrectly pop the device
+ # function mode off of the stack.
+ cur_stack = []
+ for _ in range(_len_torch_function_stack()):
+ cur_stack.append(_pop_mode())
+
+ _push_mode(self)
+
+ for mode in reversed(cur_stack):
+ _push_mode(mode)
+
def __exit__(self, exc_type, exc_val, exc_tb):
global CURRENT_DEVICE
CURRENT_DEVICE = self.old_device
- return super().__exit__(exc_type, exc_val, exc_tb)
+ cur_stack = []
+ # Invariant: there should only be one DeviceContext on the stack at any time
+ # (at the bottom). Pop all modes until we hit the bottom, then assert it's a
+ # DeviceContext, or else someone else has popped it!
+ for _ in range(_len_torch_function_stack() - 1):
+ mode = _pop_mode()
+ assert not isinstance(mode, DeviceContext)
+ cur_stack.append(mode)
+
+ if _len_torch_function_stack() > 0:
+ mode = _pop_mode()
+ assert isinstance(mode, DeviceContext)
+
+ for mode in reversed(cur_stack):
+ _push_mode(mode)
def __torch_function__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}