Avoid device casting for all singleton tensors in optimizer states (#91454)

Fixes #75224
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91454
Approved by: https://github.com/janeyx99
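
As a quick illustration of the intended behavior (not part of the diff below), here is a minimal sketch. It assumes a CUDA-capable build and an optimizer version where Adam stores state['step'] as a 0-dim tensor: after loading a CPU state_dict into a CUDA optimizer, the scalar step counter stays on CPU while non-scalar state such as exp_avg follows the parameter's device.

    # Minimal sketch (not part of this diff): a 0-dim 'step' tensor should keep its
    # device when a CPU state_dict is loaded into a CUDA optimizer, while non-scalar
    # state moves to the parameter's device.
    import torch

    param = torch.nn.Parameter(torch.randn(4))
    opt = torch.optim.Adam([param])
    param.grad = torch.randn(4)
    opt.step()                      # populate state: step, exp_avg, exp_avg_sq
    state_dict = opt.state_dict()   # everything lives on CPU here

    if torch.cuda.is_available():
        param_cuda = torch.nn.Parameter(torch.randn(4, device="cuda"))
        opt_cuda = torch.optim.Adam([param_cuda])
        opt_cuda.load_state_dict(state_dict)

        loaded = opt_cuda.state_dict()["state"][0]
        step = loaded["step"]
        if torch.is_tensor(step):                          # older versions keep a plain int here
            assert step.device.type == "cpu"               # singleton tensor: device preserved
        assert loaded["exp_avg"].device.type == "cuda"     # non-scalar state follows the param

The test change below checks the same property more generally: it flattens the whole state_dict with tree_flatten and asserts that every 0-dim tensor is still on CPU.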
diff --git a/test/test_optim.py b/test/test_optim.py
index 3ec534f..843d088 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -48,6 +48,7 @@
 )
 from typing import Dict, Any, Tuple
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
+from torch.utils._pytree import tree_flatten
 
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -300,13 +301,12 @@
         # Make sure state dict wasn't modified
         self.assertEqual(state_dict, state_dict_c)
 
-        # Make sure that device of state['step'] is still CPU
+        # Make sure that all singleton tensors in the state_dict are still on CPU.
         new_state_dict = optimizer_cuda.state_dict()
-        if "step" in state_dict["state"][0] and torch.is_tensor(
-            state_dict["state"][0]["step"]
-        ):
-            for state in new_state_dict["state"].values():
-                self.assertEqual(state["step"].device.type, "cpu")
+        flat_new_state_dict, _ = tree_flatten(new_state_dict)
+        for state_item in flat_new_state_dict:
+            if torch.is_tensor(state_item) and state_item.dim() == 0:
+                self.assertEqual(state_item.device.type, "cpu")
 
         for _i in range(20):
             optimizer.step(fn)
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index b2935b4..2440d69 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -339,10 +339,13 @@
             if isinstance(value, torch.Tensor):
                 # Floating-point types are a bit special here. They are the only ones
                 # that are assumed to always match the type of params.
-                # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
-                if (key != "step"):
-                    if param.is_floating_point():
-                        value = value.to(param.dtype)
+                if param.is_floating_point():
+                    value = value.to(param.dtype)
+
+                # Make sure singleton tensors (e.g. state['step']) do not change device.
+                # See https://github.com/pytorch/pytorch/issues/74424
+                is_singleton: bool = value.dim() == 0
+                if not is_singleton:
                     value = value.to(param.device)
                 return value
             elif isinstance(value, dict):