[FSDP] Reduce CPU overhead (#96958)
I experimented with 200 `nn.Linear`s with `bias=True` (400 `nn.Parameter`s in total), all wrapped into a single FSDP instance, with a world size of 2.
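For reference, a minimal sketch of that setup; the actual benchmark harness and profiling code are not part of this PR, and `use_orig_params=True` is my assumption based on the code paths being optimized:

```python
# Hypothetical reconstruction of the benchmark model: 200 `nn.Linear`s with
# `bias=True`, i.e. 400 `nn.Parameter`s, flattened into one FSDP instance.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_model(num_linears: int = 200, dim: int = 8) -> nn.Module:
    # No nested wrapping, so all 400 parameters land in one FlatParameter.
    return nn.Sequential(*(nn.Linear(dim, dim, bias=True) for _ in range(num_linears)))

# With a 2-rank process group already initialized (e.g. via torchrun):
#   model = FSDP(build_model().cuda(), use_orig_params=True)  # assumed flag
```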
**`unshard()` -> `_use_unsharded_views()`**
- (From previous PR) unsafe `setattr`: 6.112 ms -> 4.268 ms
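The "unsafe" `setattr` refers to bypassing `nn.Module.__setattr__`; a rough standalone illustration of the idea (not the exact FSDP helper, which lives in `flat_param.py` and handles more cases):

```python
# Illustrative sketch only: nn.Module.__setattr__ does per-call bookkeeping,
# so when the target attribute is already known to be a parameter slot,
# writing directly into the module's internals avoids that overhead.
import torch
import torch.nn as nn

def unsafe_setattr_param(module: nn.Module, name: str, param: nn.Parameter) -> None:
    module._parameters[name] = param          # register in the parameter dict
    object.__setattr__(module, name, param)   # bypass nn.Module.__setattr__ checks

lin = nn.Linear(4, 4)
new_weight = nn.Parameter(torch.zeros(4, 4))
unsafe_setattr_param(lin, "weight", new_weight)
assert lin.weight is new_weight
```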
**`pre_unshard()` -> `_writeback_orig_params()`**
- Factor out `flat_param` and `flat_param_grad` data pointers: ~1.8 ms -> 1.071 ms (see the sketch below)
- Now dominated by calling `_typed_storage()` on each original parameter and its gradient
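That change amounts to hoisting the `flat_param` and `flat_param_grad` storage data pointers out of the per-parameter loop and comparing against the cached integers via the new `_same_storage_as_data_ptr` helper (see the diff below). A minimal standalone illustration of the pattern:

```python
# Instead of re-deriving `flat_param`'s data pointer inside `_same_storage`
# for each of the 400 original parameters, compute it once per
# `_writeback_orig_params()` call and compare each parameter against the
# cached integer.
import torch

def _same_storage_as_data_ptr(x: torch.Tensor, data_ptr: int) -> bool:
    return x._typed_storage()._data_ptr() == data_ptr

flat_param = torch.randn(16)
orig_params = [flat_param[0:4], flat_param[4:8], torch.randn(4)]  # last has its own storage

flat_param_data_ptr = flat_param.untyped_storage().data_ptr()  # computed once
print([_same_storage_as_data_ptr(p, flat_param_data_ptr) for p in orig_params])
# -> [True, True, False]
```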
**`reshard()` -> `_use_sharded_views()`**
- Factor out `torch.empty(0, ...)`: ~4.6-4.7 ms -> ~2.7-2.8 ms (see the sketch below)
- Now dominated by `aten::slice()` and (unsafe) `setattr`, which are required
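The `torch.empty(0, ...)` factoring constructs the size-0 placeholder once per `_use_sharded_views()` call instead of once per parameter outside the local shard. The pattern in isolation:

```python
# Parameters whose data is not in the local shard get `.data` pointed at a
# size-0 tensor so their old storage can be garbage collected. Reusing one
# size-0 tensor avoids an extra `torch.empty` call per such parameter.
import torch
import torch.nn as nn

flat_param = torch.randn(16)
params_outside_shard = [nn.Parameter(torch.randn(4)) for _ in range(3)]

size_0_empty_tensor = torch.empty(   # constructed once and reused
    0,
    dtype=flat_param.dtype,          # in case `flat_param` changed dtype
    device=flat_param.device,
    requires_grad=False,
)
for param in params_outside_shard:
    param.data = size_0_empty_tensor  # frees each original storage for GC
    assert param.numel() == 0
```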
I removed some `assert` calls that were needed only for mypy or whose failure would be reported with the same error message by the subsequent call anyway. These asserts have negligible overhead, but I think it is still fine to remove them and skip the type checks (hence the added `@no_type_check` decorators). We need to address type checking more holistically anyway.
---
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96958
Approved by: https://github.com/rohan-varma
diff --git a/torch/distributed/fsdp/_utils.py b/torch/distributed/fsdp/_utils.py
index 45c8c45..378f547 100644
--- a/torch/distributed/fsdp/_utils.py
+++ b/torch/distributed/fsdp/_utils.py
@@ -21,6 +21,10 @@
return x._typed_storage()._data_ptr() == y._typed_storage()._data_ptr()
+def _same_storage_as_data_ptr(x: torch.Tensor, data_ptr: int) -> bool:
+ return x._typed_storage()._data_ptr() == data_ptr
+
+
def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
with no_dispatch():
tensor.record_stream(cast(torch._C.Stream, stream))
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 260cb60..a1b195c 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -32,7 +32,7 @@
from torch.distributed.utils import _alloc_storage, _free_storage, _p_assert
from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
-from ._utils import _no_dispatch_record_stream, _same_storage
+from ._utils import _no_dispatch_record_stream, _same_storage_as_data_ptr
__all__ = [
"FlatParameter",
@@ -1385,6 +1385,7 @@
)
return views
+ @no_type_check
def _use_unsharded_views(self, as_params: bool) -> None:
"""
Unflattens the unsharded flattened parameter by setting the original
@@ -1409,7 +1410,7 @@
# variable.
self._setattr_param(module, param_name, nn.Parameter(view))
continue
- param = self.flat_param._params[i] # type: ignore[index]
+ param = self.flat_param._params[i]
self._setattr_param(module, param_name, param)
param.data = view
elif as_params:
@@ -1418,42 +1419,29 @@
param_var: Tensor = view
if self._use_orig_params:
if self._training_state == HandleTrainingState.FORWARD:
- assert self.flat_param._tensors is not None
# Save the `Tensor` for the pre-backward
self.flat_param._tensors[i] = view # save for pre-backward
elif self._training_state == HandleTrainingState.BACKWARD_PRE:
# Use the saved `Tensor` variable from the forward to
# preserve the autograd graph so that the post-backward
# hook fires (e.g. for reentrant AC)
- assert self.flat_param._tensors is not None # mypy
tensor = self.flat_param._tensors[i]
- _p_assert(
- tensor is not None,
- "Expects `Tensor` to have been saved in forward",
- )
- tensor.data = view # type: ignore[union-attr]
- assert tensor is not None # mypy
+ tensor.data = view
param_var = tensor
self._setattr_tensor(module, param_name, param_var)
if (
self._use_orig_params
and self._training_state == HandleTrainingState.FORWARD
):
- module._parameters[param_name] = param_var # type: ignore[assignment]
+ module._parameters[param_name] = param_var
for i, (
param_name,
module,
_,
prim_param_name,
prim_module,
- prim_module_name,
+ _,
) in enumerate(self.flat_param._shared_param_infos):
- if hasattr(module, param_name):
- delattr(module, param_name)
- _p_assert(
- hasattr(prim_module, prim_param_name),
- f"Module {prim_module_name} is missing parameter {prim_param_name}",
- )
prim_param: Union[Tensor, nn.Parameter] = getattr(
prim_module, prim_param_name
)
@@ -1462,11 +1450,10 @@
f"as_params={as_params} type(prim_param)={type(prim_param)}",
)
if self._use_orig_params and as_params:
- shared_param = self.flat_param._shared_params[i] # type: ignore[index]
+ shared_param = self.flat_param._shared_params[i]
self._setattr_param(module, param_name, shared_param)
shared_param.data = prim_param
elif as_params:
- assert isinstance(prim_param, nn.Parameter)
self._setattr_param(module, param_name, prim_param)
else:
self._setattr_tensor(module, param_name, prim_param)
@@ -1474,7 +1461,7 @@
self._use_orig_params
and self._training_state == HandleTrainingState.FORWARD
):
- module._parameters[param_name] = prim_param # type: ignore[assignment]
+ module._parameters[param_name] = prim_param
def _use_unsharded_grad_views(self) -> None:
"""
@@ -1554,6 +1541,7 @@
finally:
self._use_unsharded_views(as_params=False)
+ @no_type_check
@torch.no_grad()
def _use_sharded_views(self) -> None:
"""
@@ -1573,31 +1561,30 @@
self._use_unsharded_views(as_params=True)
return
self._check_sharded(self.flat_param)
- start, end = self.flat_param._shard_indices # type: ignore[attr-defined]
+ start, end = self.flat_param._shard_indices
offset = 0
- assert self.flat_param._params is not None
+ # Construct once and reuse for all parameters not in the local shard
+ size_0_empty_tensor = torch.empty(
+ 0,
+ dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
+ device=self.flat_param.device,
+ requires_grad=False,
+ )
for i, (param, (param_name, module, _)) in enumerate(
zip(self.flat_param._params, self.flat_param._param_infos)
):
self._setattr_param(module, param_name, param)
in_sharded_flat_param = (
- i >= start
- and i <= end
- and self.flat_param._shard_param_offsets # type: ignore[attr-defined]
+ i >= start and i <= end and self.flat_param._shard_param_offsets
)
if in_sharded_flat_param:
- param_start, param_end = self.flat_param._shard_param_offsets[i - start] # type: ignore[attr-defined]
+ param_start, param_end = self.flat_param._shard_param_offsets[i - start]
numel_in_shard = param_end - param_start + 1
param.data = self.flat_param[offset : offset + numel_in_shard]
offset += numel_in_shard
else:
# Allow the original data to be freed via garbage collection
- param.data = torch.empty(
- 0,
- dtype=self.flat_param.dtype, # in case `flat_param` changed dtype
- device=self.flat_param.device,
- requires_grad=False,
- )
+ param.data = size_0_empty_tensor
assert self.flat_param._shared_params is not None
for i, (
param,
@@ -1609,10 +1596,9 @@
prim_param = getattr(prim_module, prim_param_name)
param.data = prim_param # could be both empty and non-empty
if self._training_state == HandleTrainingState.BACKWARD_POST:
- assert self.flat_param._tensors is not None # mypy
# Clear the saved `Tensor`s since they are unneeded now
for i in range(len(self.flat_param._tensors)):
- self.flat_param._tensors[i] = None # type: ignore[index]
+ self.flat_param._tensors[i] = None
@torch.no_grad()
def _use_sharded_grad_views(self) -> None:
@@ -1681,6 +1667,7 @@
else:
param.grad = None
+ @no_type_check
@torch.no_grad()
def _writeback_orig_params(self) -> bool:
"""
@@ -1698,10 +1685,25 @@
# For `NO_SHARD`, we may still need to writeback
return False
flat_param = self.flat_param
- start, end = flat_param._shard_indices # type: ignore[attr-defined]
+ start, end = flat_param._shard_indices
offset = 0
assert flat_param._params is not None
wroteback = False
+ flat_param_data_ptr = flat_param.untyped_storage().data_ptr()
+ # NOTE: Since this method is called in the pre-unshard, which is only
+ # called during computation in the pre-forward or pre-backward, the
+ # sharded gradient should be guaranteed to be in `.grad`, not in
+ # `._saved_grad_shard`.
+ flat_param_grad = (
+ flat_param.grad
+ if self.uses_sharded_strategy or not self._offload_params
+ else flat_param._cpu_grad
+ )
+ flat_param_grad_data_ptr = (
+ None
+ if flat_param_grad is None
+ else flat_param_grad.untyped_storage().data_ptr()
+ )
for i, (param, (param_name, module, _)) in enumerate(
zip(flat_param._params, flat_param._param_infos)
):
@@ -1710,20 +1712,20 @@
# (e.g. during model checkpointing)
continue
in_sharded_flat_param = (
- i >= start
- and i <= end
- and self.flat_param._shard_param_offsets # type: ignore[attr-defined]
+ i >= start and i <= end and self.flat_param._shard_param_offsets
)
if not in_sharded_flat_param:
continue
- param_start, param_end = flat_param._shard_param_offsets[i - start] # type: ignore[attr-defined]
+ param_start, param_end = flat_param._shard_param_offsets[i - start]
numel_in_shard = param_end - param_start + 1
# Check for parameter writeback
param_changed = getattr(module, param_name) is not param
needs_param_writeback = (
param_changed # changed parameter variable itself
- or not _same_storage(param, flat_param) # changed `.data`
+ or not _same_storage_as_data_ptr(
+ param, flat_param_data_ptr
+ ) # changed `.data`
)
if param_changed:
# NOTE: The gradient is not preserved after a parameter change.
@@ -1737,10 +1739,6 @@
wroteback = True
# Check for gradient writeback
- # NOTE: Since this method is called in the pre-unshard, which is
- # only called during computation in the pre-forward or
- # pre-backward, the sharded gradient should be guaranteed to be in
- # `.grad`, not in `._saved_grad_shard`.
if param.grad is None and flat_param.grad is not None:
expected_shape = torch.Size([numel_in_shard])
self._writeback_tensor(
@@ -1750,13 +1748,11 @@
# For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in
# memory and owns the gradient storage, so it will never
# require gradient writeback.
- flat_param_grad = (
- flat_param.grad
- if self.uses_sharded_strategy or not self._offload_params
- else flat_param._cpu_grad # type: ignore[attr-defined]
- )
- needs_grad_writeback = flat_param_grad is None or not _same_storage(
- param.grad, flat_param_grad
+ needs_grad_writeback = (
+ flat_param_grad is None
+ or not _same_storage_as_data_ptr(
+ param.grad, flat_param_grad_data_ptr
+ )
)
if needs_grad_writeback:
if flat_param_grad is None:
@@ -1766,6 +1762,10 @@
param.grad, flat_param_grad, i, expected_shape, offset, False
)
flat_param.grad = flat_param_grad
+ flat_param_grad = flat_param.grad
+ flat_param_grad_data_ptr = (
+ flat_param_grad.untyped_storage().data_ptr()
+ )
offset += numel_in_shard
# TODO (awgu): Handle shared parameters. We need to re-generate the
# shared parameter data structures in case sharedness changed.