[Easy][FSDP] Clarify `_use_unsharded_grad_views` comment (#100359)
This is an easy follow-up to the previous PR that (1) clarifies that `view` is the original parameter's gradient and (2) notes that after `reshard()`, the parameter is on CPU only if CPU offloading is enabled.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100359
Approved by: https://github.com/rohan-varma
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 072f2a9..ca01241 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -1779,9 +1779,11 @@
):
# NOTE: This is a hack using `.data` to side step the check
# that parameter/gradient sizes/dtypes/devices match. From
- # calling `reshard()`, `param` has the sharded size, the full
- # precision dtype, and is on CPU. Thus, one or more of the
- # following cases can hold when in `no_sync()`:
+ # calling `reshard()`, `param` has the sharded size, has the
+ # full precision dtype, and if CPU offloading is enabled, is on
+ # CPU. Thus, one or more of the following cases can hold when
+ # in `no_sync()`, where `view` is the original parameter's
+ # gradient:
# 1. `view` can have the unsharded size.
# 2. `view` can have the parameter low precision dtype.
# 3. `view` can be on GPU.
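
For reference, here is a minimal standalone sketch (not FSDP code; the sizes, dtypes, and variable names are illustrative) of the check that the `.data` hack side steps: assigning `param.grad` directly requires the gradient to match the parameter's size/dtype/device, while assigning through `.grad.data` replaces the gradient's storage without that check, which is what allows the `no_sync()` mismatch cases above.

```python
import torch

# Hypothetical stand-ins: a "sharded" full-precision parameter and an
# "unsharded" low-precision gradient view, mimicking the mismatch cases
# listed in the comment.
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float32))
view = torch.ones(8, dtype=torch.float16)

# Direct assignment enforces that the gradient's size/dtype/device match the
# parameter's, so it is rejected here.
try:
    param.grad = view
except RuntimeError as err:
    print(f"direct assignment rejected: {err}")

# Seed a matching gradient, then overwrite its `.data` to bypass the check.
param.grad = torch.zeros_like(param)
param.grad.data = view
print(param.grad.shape, param.grad.dtype)  # torch.Size([8]) torch.float16
```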