Fix DeviceContext bug (#133729)

Fixes https://github.com/pytorch/pytorch/issues/133666

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133729
Approved by: https://github.com/bdhirsh
ghstack dependencies: #133130
diff --git a/test/test_overrides.py b/test/test_overrides.py
index de3d524..22c04ed 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -22,6 +22,7 @@
     TorchFunctionMode,
     _get_current_function_mode,
     _get_current_function_mode_stack,
+    BaseTorchFunctionMode,
 )
 from torch.utils._mode_utils import all_same_mode
 from torch.utils._pytree import tree_map
@@ -1600,6 +1601,29 @@
             self.assertEqual(d_kwargs.type, "xla")
             self.assertEqual(d_kwargs.index, 0)
 
+    def test_device_context_semantics(self):
+        from torch._C import _len_torch_function_stack
+        from torch.utils._device import DeviceContext
+        torch.set_default_device("cuda")
+
+        def get_stack():
+            return [torch._C._get_function_stack_at(i) for i in range(_len_torch_function_stack())]
+
+        base_mode = BaseTorchFunctionMode()
+        with base_mode:
+            torch.set_default_device("cpu")
+            x = torch.ones(2, 2)
+            stack = get_stack()
+            self.assertIsInstance(stack[0], DeviceContext)
+            self.assertEqual(stack[0].device, torch.device("cpu"))
+
+        stack = get_stack()
+        self.assertIsInstance(stack[0], DeviceContext)
+        self.assertEqual(stack[0].device, torch.device("cpu"))
+
+
+
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 42ba358..3a1cef7 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -75,6 +75,11 @@
         from _pytest.python_api import RaisesContext
         from _pytest.recwarn import WarningsChecker
 
+        # TODO mlazos: Temporary to get this stack to pass
+        # remove in subsequent PR
+        from torch.overrides import BaseTorchFunctionMode
+
+        f_ctxs.append(BaseTorchFunctionMode)
         f_ctxs.append(RaisesContext)
         f_ctxs.append(WarningsChecker)
     except ImportError:
diff --git a/torch/utils/_device.py b/torch/utils/_device.py
index c852cd3..d9e8852 100644
--- a/torch/utils/_device.py
+++ b/torch/utils/_device.py
@@ -1,8 +1,9 @@
 # mypy: allow-untyped-defs
 from typing import Optional
 import torch
-from torch.overrides import TorchFunctionMode
+from torch.overrides import TorchFunctionMode, _pop_mode, _push_mode
 from torch.utils._contextlib import context_decorator
+from torch._C import _len_torch_function_stack
 import functools
 
 CURRENT_DEVICE: Optional[torch.device] = None
@@ -65,12 +66,38 @@
         global CURRENT_DEVICE
         self.old_device = CURRENT_DEVICE
         CURRENT_DEVICE = self.device
-        return super().__enter__()
+        # We need to put the device at the bottom of the stack.
+        # If we set the default device within a function mode context,
+        # exiting that context would otherwise pop the device function
+        # mode off of the stack incorrectly.
+        cur_stack = []
+        for _ in range(_len_torch_function_stack()):
+            cur_stack.append(_pop_mode())
+
+        _push_mode(self)
+
+        for mode in reversed(cur_stack):
+            _push_mode(mode)
+
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         global CURRENT_DEVICE
         CURRENT_DEVICE = self.old_device
-        return super().__exit__(exc_type, exc_val, exc_tb)
+        cur_stack = []
+        # Invariant: there should only be one DeviceContext on the stack at any time
+        # (at the bottom). Pop all modes until we hit the bottom, then assert it's a
+        # DeviceContext — or else someone else has popped it!
+        for _ in range(_len_torch_function_stack() - 1):
+            mode = _pop_mode()
+            assert not isinstance(mode, DeviceContext)
+            cur_stack.append(mode)
+
+        if _len_torch_function_stack() > 0:
+            mode = _pop_mode()
+            assert isinstance(mode, DeviceContext)
+
+        for mode in reversed(cur_stack):
+            _push_mode(mode)
 
     def __torch_function__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}