[FSDP] Backward prefetch in recursive call (#71804)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71804
Pass the backward_prefetch argument through the recursive auto-wrapping call so that, when auto_wrap_policy is used, nested FSDP instances inherit the backward prefetch setting. Unit tests are updated accordingly.
ghstack-source-id: 147753214
Test Plan: CI
Reviewed By: zhaojuanmao
Differential Revision: D33782346
fbshipit-source-id: c0176b48db29c3756a8873e809610ed53480102b
(cherry picked from commit 764acb3f1c8fb9879b6c92a934df1a7d2c9e3f3d)
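For context, a minimal usage sketch of the behavior this PR enables, assuming the private torch.distributed._fsdp API touched by the diff. The toy model, the helper-free layout, the policy keyword name (auto_wrap_policy), and the min_num_params value are illustrative, not part of this PR; it also assumes an initialized process group and an available CUDA device.

    import functools

    import torch.nn as nn

    from torch.distributed._fsdp.fully_sharded_data_parallel import (
        FullyShardedDataParallel as FSDP,
        BackwardPrefetch_,
    )
    from torch.distributed._fsdp.wrap import default_auto_wrap_policy

    # Assumes torch.distributed.init_process_group(...) has already been called.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
    wrapped = FSDP(
        model,
        # The exact keyword for the policy is an assumption here; it may differ
        # across versions (the commit title refers to it as auto_wrap_policy).
        auto_wrap_policy=functools.partial(
            default_auto_wrap_policy,
            min_num_params=0,  # wrap every submodule, as in the unit test
        ),
        backward_prefetch=BackwardPrefetch_.BACKWARD_PRE,  # or BACKWARD_POST
    )

With min_num_params=0 every submodule is wrapped, and after this change each nested FSDP instance carries the same backward_prefetch setting, which is what the updated test_main_wrap_api asserts via _check_backward_prefetch.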
diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py
index 1b23341..f1b1680 100644
--- a/test/distributed/fsdp/test_wrap.py
+++ b/test/distributed/fsdp/test_wrap.py
@@ -12,6 +12,7 @@
from torch.distributed._fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
CPUOffload,
+ BackwardPrefetch_,
)
from torch.distributed._fsdp.wrap import (
default_auto_wrap_policy,
@@ -130,10 +131,14 @@
[CPUOffload(offload_params=False), CPUOffload(offload_params=True)]
)
@parametrize(
+ "backward_prefetch",
+ [BackwardPrefetch_.BACKWARD_POST, BackwardPrefetch_.BACKWARD_PRE]
+ )
+ @parametrize(
"fsdp_init_mode",
[FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE]
)
- def test_main_wrap_api(self, cpu_offload, fsdp_init_mode):
+ def test_main_wrap_api(self, cpu_offload, backward_prefetch, fsdp_init_mode):
if fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
# they don't work together, expected
@@ -168,22 +173,25 @@
min_num_params=0, # wrap all modules
),
cpu_offload=cpu_offload,
+ backward_prefetch=backward_prefetch,
)
if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
wrapped_model = wrapped_model.cuda()
- modules = [
- wrapped_model,
+ modules_in_fsdp_graph_order = [
wrapped_model.module.lin1,
wrapped_model.module.lin2,
wrapped_model.module.lin3,
- wrapped_model.module.lin4,
- # Nested FSDP
wrapped_model.module.lin4.module.nested_lin,
+ wrapped_model.module.lin4,
+ wrapped_model
]
- for module in modules:
+
+ for module in modules_in_fsdp_graph_order:
self.assertTrue(isinstance(module, FSDP))
self._check_cpu_offload(module, cpu_offload)
+ self._check_backward_prefetch(module, backward_prefetch)
+
# Run model a few times for sanity check.
optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
inp = torch.ones(1).cuda()
@@ -193,6 +201,14 @@
loss.backward()
optim.step()
+ # Since we ran with backward prefetch, verify backward prefetch related
+ # data.
+ for i, module in enumerate(modules_in_fsdp_graph_order):
+ self.assertEqual(i, module._my_fsdp_idx_in_graph)
+ self.assertTrue(
+ module._fsdp_graph_order == modules_in_fsdp_graph_order
+ )
+
class TestAutoWrap(TestCase):
def setUp(self) -> None:
diff --git a/torch/distributed/_fsdp/fully_sharded_data_parallel.py b/torch/distributed/_fsdp/fully_sharded_data_parallel.py
index 614800e..2b0da86 100644
--- a/torch/distributed/_fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/_fsdp/fully_sharded_data_parallel.py
@@ -165,6 +165,7 @@
# FSDP arguments follow.
process_group=process_group,
cpu_offload=cpu_offload,
+ backward_prefetch=backward_prefetch,
# Note that recursive_wrap should not call FSDP with wrapping
# enabled, as this recursive call handles all wrapping,
# including for nested children.
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 244dc96..e617b11 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -333,6 +333,9 @@
def _check_cpu_offload(self, fsdp_model, cpu_offload):
self.assertEqual(cpu_offload, fsdp_model.cpu_offload)
+ def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
+ self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)
+
@classmethod
def _run(cls, rank, test_name, file_name, pipe):
self = cls(test_name)