[FSDP] Removed `.detach` in `clip_grad_norm_` (#120612)
The `.detach()` call seems unnecessary since this code already runs under a `no_grad()` context, so the in-place scaling is not recorded by autograd either way. The unit tests still pass.
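
A minimal standalone sketch of the reasoning (not FSDP-specific; the tensors and the clip coefficient below are illustrative stand-ins, not code from this PR): inside `torch.no_grad()`, an in-place `mul_` on a gradient tensor is not tracked by autograd, so the extra `.detach()` adds nothing.

```python
import torch

# Hypothetical parameter just to obtain a populated .grad tensor.
p = torch.randn(4, requires_grad=True)
p.sum().backward()

# Stand-in for the clamped clip coefficient computed by clip_grad_norm_.
clip_coef_clamped = torch.tensor(0.5)

with torch.no_grad():
    # Equivalent to the old `grad.detach().mul_(...)`: no graph is recorded
    # under no_grad, so scaling the gradient in place directly is safe.
    p.grad.mul_(clip_coef_clamped.to(p.grad.device, p.grad.dtype))

print(p.grad)  # gradient scaled by 0.5, with no autograd history attached
```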
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120612
Approved by: https://github.com/Skylion007
ghstack dependencies: #120231
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 7c71120..0a997d7 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1170,7 +1170,7 @@
# `if clip_coef < 1`
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for grad in grads:
- grad.detach().mul_(clip_coef_clamped.to(grad.device, grad.dtype))
+ grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype))
# Use the "largest" dtype by type promotion semantics to use the same
# dtype as if we did not force local norm computation to be in FP32
if len(grads) == 0: