[FSDP] Default `limit_all_gathers=True` (#104900)
This PR defaults to `limit_all_gathers=True`.
I also wrapped the rate limiter synchronization in a `record_function()` to clarify the otherwise-confusing gap that users may see in the pre-forward:
<img width="874" alt="Screenshot 2023-07-10 at 3 28 18 PM" src="https://github.com/pytorch/pytorch/assets/31054793/61f55e0e-58d7-4162-9395-bea06d3e8d8a">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104900
Approved by: https://github.com/fegin
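For context, a minimal sketch of the user-facing effect, assuming a toy model and an already-initialized process group (the model and setup here are illustrative, not from this PR):

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hypothetical two-layer model; assumes the default process group is already
# initialized (e.g. torch.distributed.init_process_group("nccl")) and a CUDA
# device is available.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

# As of this PR, `limit_all_gathers` defaults to True (the "rate limiter").
fsdp_model = FSDP(model)

# Workloads that want the previous behavior must now opt out explicitly:
# fsdp_model = FSDP(model, limit_all_gathers=False)
```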
diff --git a/test/distributed/fsdp/test_fsdp_overlap.py b/test/distributed/fsdp/test_fsdp_overlap.py
index 86db213..d5dd6ef 100644
--- a/test/distributed/fsdp/test_fsdp_overlap.py
+++ b/test/distributed/fsdp/test_fsdp_overlap.py
@@ -58,13 +58,16 @@
def _create_model(compute_cycles, has_params: bool):
+ # Use `limit_all_gathers=False` since the timing being tested relies on the
+ # CPU running ahead of the GPU
model = FSDP(
nn.Sequential(
- FSDP(Layer(compute_cycles, has_params)),
- FSDP(Layer(compute_cycles, has_params)),
- FSDP(Layer(compute_cycles, has_params)),
- FSDP(Layer(compute_cycles, has_params)),
- )
+ FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+ FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+ FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+ FSDP(Layer(compute_cycles, has_params), limit_all_gathers=False),
+ ),
+ limit_all_gathers=False,
).cuda()
return model
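The comment added above points at why this test opts out: the overlap timing it measures depends on the CPU thread queuing GPU work far ahead of execution. A rough standalone sketch of that property (not part of the test), assuming a single CUDA device:

```python
import time

import torch

# With no extra synchronization, the CPU thread enqueues GPU kernels and
# returns long before the GPU finishes them; the rate limiter would instead
# block the CPU on a CUDA event and shrink this gap.
x = torch.randn(4096, 4096, device="cuda")
torch.cuda.synchronize()

cpu_start = time.time()
for _ in range(20):
    x = x @ x  # queued asynchronously on the current CUDA stream
cpu_ahead_time = time.time() - cpu_start  # small: the CPU ran ahead

torch.cuda.synchronize()
total_time = time.time() - cpu_start  # includes actual GPU execution
print(f"CPU enqueue: {cpu_ahead_time:.4f}s, total: {total_time:.4f}s")
```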
diff --git a/torch/distributed/fsdp/_runtime_utils.py b/torch/distributed/fsdp/_runtime_utils.py
index ee807e4..7be5347 100644
--- a/torch/distributed/fsdp/_runtime_utils.py
+++ b/torch/distributed/fsdp/_runtime_utils.py
@@ -358,7 +358,10 @@
if state.limit_all_gathers:
event = state._free_event_queue.dequeue_if_needed()
if event:
- event.synchronize()
+ with torch.profiler.record_function(
+ "FullyShardedDataParallel.rate_limiter"
+ ):
+ event.synchronize()
with state._device_handle.stream(unshard_stream):
for handle in handles:
handle.unshard()
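For readers unfamiliar with the rate limiter, here is a standalone sketch of the pattern this hunk annotates, using a toy queue and limit rather than FSDP's actual internals:

```python
from collections import deque

import torch

# Toy stand-in for FSDP's free-event queue: a CUDA event is enqueued once a
# prior all-gather's memory can be reused, and dequeued before issuing a new
# all-gather once too many are in flight.
_free_event_queue = deque()
_MAX_INFLIGHT = 2  # mirrors "at most two FSDP instances' memory at a time"


def _rate_limit() -> None:
    if len(_free_event_queue) < _MAX_INFLIGHT:
        return  # under the limit; no need to block the CPU thread
    event = _free_event_queue.popleft()
    # Name the CPU-side wait so it shows up clearly in profiler traces,
    # matching the `record_function` added in the hunk above.
    with torch.profiler.record_function("FullyShardedDataParallel.rate_limiter"):
        event.synchronize()  # block the CPU until that earlier event completes
```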
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 9e67441..13da17d 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -214,6 +214,13 @@
instead of reacquiring the references each iteration, then it will not
see FSDP's newly created views, and autograd will not work correctly.
+ .. note::
+ With ``limit_all_gathers=True``, you may see a gap in the FSDP
+ pre-forward where the CPU thread is not issuing any kernels. This is
+ intentional and shows the rate limiter in effect. Synchronizing the CPU
+ thread in that way prevents over-allocating memory for subsequent
+ all-gathers, and it should not actually delay GPU kernel execution.
+
Args:
module (nn.Module):
This is the module to be wrapped with FSDP.
@@ -334,12 +341,16 @@
bound workloads. This should only be used for static graph models
since the forward order is fixed based on the first iteration's
execution. (Default: ``False``)
- limit_all_gathers (bool): If ``False``, then FSDP allows the CPU
- thread to schedule all-gathers without any extra synchronization.
- If ``True``, then FSDP explicitly synchronizes the CPU thread to
- prevent too many in-flight all-gathers. This ``bool`` only affects
- the sharded strategies that schedule all-gathers. Enabling this can
- help lower the number of CUDA malloc retries.
+ limit_all_gathers (bool): If ``True``, then FSDP explicitly
+ synchronizes the CPU thread so that GPU memory usage comes from only
+ *two* consecutive FSDP instances at a time (the current instance
+ running computation and the next instance whose all-gather is
+ prefetched). If ``False``, then FSDP allows the CPU thread to issue
+ all-gathers without any extra synchronization. (Default: ``True``)
+ We often refer to this feature as the "rate limiter". This flag
+ should only be set to ``False`` for specific CPU-bound workloads with
+ low memory pressure, in which case the CPU thread can aggressively
+ issue all kernels without concern for GPU memory usage.
use_orig_params (bool): Setting this to ``True`` has FSDP use
``module`` 's original parameters. FSDP exposes those original
parameters to the user via :meth:`nn.Module.named_parameters`
@@ -382,7 +393,7 @@
device_id: Optional[Union[int, torch.device]] = None,
sync_module_states: bool = False,
forward_prefetch: bool = False,
- limit_all_gathers: bool = False,
+ limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states: Union[
Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]