Avoid device casting for all singleton tensors in optimizer states (#91454)
Fixes #75224
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91454
Approved by: https://github.com/janeyx99
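
In short: when a state dict saved from a CPU optimizer is loaded into an optimizer whose
parameters live on another device, zero-dimensional ("singleton") state tensors such as
state['step'] now keep their original device, while other floating-point state still follows
the parameter's dtype and device. A minimal sketch of the preserved behavior, assuming a
CUDA-capable build; the optimizer choice and variable names below are illustrative only,
not part of this PR:

import torch

# 0-dim ("singleton") state tensors such as state["step"] stay on their
# original device after load_state_dict, while other floating-point state
# (e.g. exp_avg, exp_avg_sq) is still moved to the parameter's device.
param_cpu = torch.nn.Parameter(torch.randn(2, 2))
opt_cpu = torch.optim.Adam([param_cpu])
param_cpu.grad = torch.randn_like(param_cpu)
opt_cpu.step()  # populates state: step (0-dim), exp_avg, exp_avg_sq
state_dict = opt_cpu.state_dict()

param_cuda = torch.nn.Parameter(torch.randn(2, 2, device="cuda"))
opt_cuda = torch.optim.Adam([param_cuda])
opt_cuda.load_state_dict(state_dict)

for state in opt_cuda.state_dict()["state"].values():
    for value in state.values():
        if torch.is_tensor(value) and value.dim() == 0:
            assert value.device.type == "cpu"   # e.g. "step" is not moved
        elif torch.is_tensor(value):
            assert value.device.type == "cuda"  # e.g. "exp_avg" follows the param
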
diff --git a/test/test_optim.py b/test/test_optim.py
index 3ec534f..843d088 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -48,6 +48,7 @@
)
from typing import Dict, Any, Tuple
from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
+from torch.utils._pytree import tree_flatten
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@@ -300,13 +301,12 @@
# Make sure state dict wasn't modified
self.assertEqual(state_dict, state_dict_c)
- # Make sure that device of state['step'] is still CPU
+ # Make sure that all singleton tensors in the state_dict are still on CPU.
new_state_dict = optimizer_cuda.state_dict()
- if "step" in state_dict["state"][0] and torch.is_tensor(
- state_dict["state"][0]["step"]
- ):
- for state in new_state_dict["state"].values():
- self.assertEqual(state["step"].device.type, "cpu")
+ flat_new_state_dict, _ = tree_flatten(new_state_dict)
+ for state_item in flat_new_state_dict:
+ if torch.is_tensor(state_item) and state_item.dim() == 0:
+ self.assertEqual(state_item.device.type, "cpu")
for _i in range(20):
optimizer.step(fn)
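
The updated test above walks the entire state dict with tree_flatten instead of special-casing
the "step" key, so every scalar tensor anywhere in the state gets checked. For reference,
tree_flatten returns all leaf values of a nested container plus a spec describing its structure;
a small sketch with made-up values (the state dict below is illustrative, not taken from the test):

import torch
from torch.utils._pytree import tree_flatten

# Made-up state dict shaped like an optimizer's; only the leaves matter here.
state_dict = {
    "state": {0: {"step": torch.tensor(3.0), "exp_avg": torch.zeros(2)}},
    "param_groups": [{"lr": 0.01, "params": [0]}],
}
leaves, spec = tree_flatten(state_dict)
# leaves is a flat list of every leaf value (tensors, floats, ints).
scalar_tensors = [x for x in leaves if torch.is_tensor(x) and x.dim() == 0]
assert all(t.device.type == "cpu" for t in scalar_tensors)
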
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index b2935b4..2440d69 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -339,10 +339,13 @@
if isinstance(value, torch.Tensor):
# Floating-point types are a bit special here. They are the only ones
# that are assumed to always match the type of params.
- # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
- if (key != "step"):
- if param.is_floating_point():
- value = value.to(param.dtype)
+ if param.is_floating_point():
+ value = value.to(param.dtype)
+
+ # Make sure singleton tensors (e.g. state['step']) do not change device.
+ # See https://github.com/pytorch/pytorch/issues/74424
+ is_singleton: bool = value.dim() == 0
+ if not is_singleton:
value = value.to(param.device)
return value
elif isinstance(value, dict):
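
The optimizer.py hunk generalizes the old `key != "step"` special case: the device skip is now
keyed on the tensor being zero-dimensional, so any scalar state an optimizer keeps is left on its
original device while still being cast to the parameter's floating-point dtype. A standalone
restatement of that cast logic (the helper name is illustrative, not the actual load_state_dict
internals):

import torch

def cast_state_value(param: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Floating-point state follows the param's dtype...
    if param.is_floating_point():
        value = value.to(param.dtype)
    # ...but singleton (0-dim) tensors such as state["step"] keep their device.
    if value.dim() != 0:
        value = value.to(param.device)
    return value

For example, a CPU scalar state["step"] loaded against a CUDA parameter keeps its CPU device but
matches the parameter's dtype, whereas exp_avg-style buffers are moved to the parameter's device
as before.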