[BE][FSDP] Retire `_get_full_detached_param()` (#80871)
The tests did not actually require that the parameters be detached, so this coalesces `_get_full_detached_param()` with `get_full_params()`.
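For reference, the shared helper the tests now rely on gathers the full (unsharded) parameters and returns copies of them. The snippet below is a minimal sketch, not the exact body in `torch/testing/_internal/common_fsdp.py`; the `recurse` flag and the use of `deepcopy` are assumptions for illustration, while `summon_full_params` is the real FSDP context manager:

```python
# Hypothetical sketch of the unified helper; the actual implementation lives in
# torch/testing/_internal/common_fsdp.py and may differ in details.
from copy import deepcopy

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def get_full_params(model, recurse: bool = True):
    # summon_full_params materializes the unsharded parameters on each rank;
    # copying them inside the context keeps the full values after it exits.
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))
```

Because the tests only compare parameter values via `assertEqual` / `assertNotEqual`, copies that are not explicitly detached are sufficient, which is why the separate detached variant can be retired.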
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80871
Approved by: https://github.com/rohan-varma
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 367f5cb..0ef5c15 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -29,7 +29,6 @@
from torch.testing._internal.common_fsdp import (
FSDPTest,
get_full_params,
- _get_full_detached_param,
_get_state_dict,
SkipModel,
_zero_model,
@@ -350,7 +349,7 @@
)
model = self._get_simple_nested_model(mixed_precision=mixed_precision)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
- initial_params = _get_full_detached_param(model)
+ initial_params = get_full_params(model)
for _ in range(6):
inp = torch.randn(1, 10, device=torch.cuda.current_device())
output = model(*inp)
@@ -360,7 +359,7 @@
loss.backward()
optim.step()
- trained_params = _get_full_detached_param(model)
+ trained_params = get_full_params(model)
# Ensure some training occurred
self.assertNotEqual(initial_params, trained_params)
# Save a copy of the state_dict
@@ -392,7 +391,7 @@
with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
model.load_state_dict(state_dict)
- loaded_params = _get_full_detached_param(model)
+ loaded_params = get_full_params(model)
self.assertEqual(loaded_params, trained_params)
def _initialize_model(
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 15fa53b..b529c80 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -30,11 +30,6 @@
# Don't move model to CUDA at all.
CUDA_NEVER = 3
-def _get_full_detached_param(fsdp_model: FullyShardedDataParallel):
- with FullyShardedDataParallel.summon_full_params(fsdp_model):
- params = list(p.clone().detach_() for p in fsdp_model.parameters())
-
- return params
def _validate(model, process_group, assert_fn):
module_states = [param.detach().cpu() for param in model.parameters()]