[FSDP] Backward prefetch in recursive call (#71804)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71804

Forward the backward_prefetch arg in the recursive wrapping call made when
auto_wrap_policy is used, so that nested FSDP instances inherit the setting
rather than falling back to the default. Unit tests are updated accordingly.
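
For illustration, a minimal sketch of the user-facing effect (assumes a
default process group is already initialized; the auto wrap kwarg is shown
here as fsdp_auto_wrap_policy, which may be named differently depending on
the version of this private _fsdp package):

    import functools
    import torch.nn as nn
    from torch.distributed._fsdp.fully_sharded_data_parallel import (
        FullyShardedDataParallel as FSDP,
        BackwardPrefetch_,
    )
    from torch.distributed._fsdp.wrap import default_auto_wrap_policy

    # Toy model for illustration; any nn.Module works.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()

    wrapped = FSDP(
        model,
        # Kwarg name assumed for this version of the private _fsdp package.
        fsdp_auto_wrap_policy=functools.partial(
            default_auto_wrap_policy, min_num_params=0  # wrap all modules
        ),
        # With this fix, the setting reaches the nested FSDP instances
        # created by the recursive auto wrap, not only the root.
        backward_prefetch=BackwardPrefetch_.BACKWARD_PRE,
    )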
ghstack-source-id: 147753214

Test Plan: CI
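
Locally, the updated parametrized test can also be run on a multi-GPU
machine (a sketch of the usual invocation; the -k filter is standard
unittest behavior):

    python test/distributed/fsdp/test_wrap.py -k test_main_wrap_api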

Reviewed By: zhaojuanmao

Differential Revision: D33782346

fbshipit-source-id: c0176b48db29c3756a8873e809610ed53480102b
(cherry picked from commit 764acb3f1c8fb9879b6c92a934df1a7d2c9e3f3d)
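
For context, a simplified sketch of the code path being fixed (illustrative
names only, not the actual FSDP source): when an auto wrap policy is given,
FSDP.__init__ recursively wraps qualifying children, forwarding its own
constructor kwargs; backward_prefetch was previously missing from those
forwarded kwargs, so nested instances fell back to the default.

    from torch.distributed._fsdp.fully_sharded_data_parallel import (
        FullyShardedDataParallel as FSDP,
    )

    def _auto_wrap_sketch(module, should_wrap, process_group, cpu_offload,
                          backward_prefetch):
        # Wrap bottom-up so nested children become FSDP instances first.
        for name, child in module.named_children():
            _auto_wrap_sketch(child, should_wrap, process_group, cpu_offload,
                              backward_prefetch)
            if should_wrap(child):
                setattr(module, name, FSDP(
                    child,
                    process_group=process_group,
                    cpu_offload=cpu_offload,
                    # The kwarg this PR adds to the recursive call:
                    backward_prefetch=backward_prefetch,
                ))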
diff --git a/test/distributed/fsdp/test_wrap.py b/test/distributed/fsdp/test_wrap.py
index 1b23341..f1b1680 100644
--- a/test/distributed/fsdp/test_wrap.py
+++ b/test/distributed/fsdp/test_wrap.py
@@ -12,6 +12,7 @@
 from torch.distributed._fsdp.fully_sharded_data_parallel import (
     FullyShardedDataParallel as FSDP,
     CPUOffload,
+    BackwardPrefetch_,
 )
 from torch.distributed._fsdp.wrap import (
     default_auto_wrap_policy,
@@ -130,10 +131,14 @@
         [CPUOffload(offload_params=False), CPUOffload(offload_params=True)]
     )
     @parametrize(
+        "backward_prefetch",
+        [BackwardPrefetch_.BACKWARD_POST, BackwardPrefetch_.BACKWARD_PRE]
+    )
+    @parametrize(
         "fsdp_init_mode",
         [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE]
     )
-    def test_main_wrap_api(self, cpu_offload, fsdp_init_mode):
+    def test_main_wrap_api(self, cpu_offload, backward_prefetch, fsdp_init_mode):
 
         if fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
             # they don't work together, expected
@@ -168,22 +173,25 @@
                 min_num_params=0,  # wrap all modules
             ),
             cpu_offload=cpu_offload,
+            backward_prefetch=backward_prefetch,
         )
         if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
             wrapped_model = wrapped_model.cuda()
 
-        modules = [
-            wrapped_model,
+        modules_in_fsdp_graph_order = [
             wrapped_model.module.lin1,
             wrapped_model.module.lin2,
             wrapped_model.module.lin3,
-            wrapped_model.module.lin4,
-            # Nested FSDP
             wrapped_model.module.lin4.module.nested_lin,
+            wrapped_model.module.lin4,
+            wrapped_model
         ]
-        for module in modules:
+
+        for module in modules_in_fsdp_graph_order:
             self.assertTrue(isinstance(module, FSDP))
             self._check_cpu_offload(module, cpu_offload)
+            self._check_backward_prefetch(module, backward_prefetch)
+
         # Run model a few times for sanity check.
         optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
         inp = torch.ones(1).cuda()
@@ -193,6 +201,14 @@
             loss.backward()
             optim.step()
 
+        # Since we ran with backward prefetch, verify that the backward
+        # prefetch bookkeeping (FSDP graph order) was recorded correctly.
+        for i, module in enumerate(modules_in_fsdp_graph_order):
+            self.assertEqual(i, module._my_fsdp_idx_in_graph)
+            self.assertTrue(
+                module._fsdp_graph_order == modules_in_fsdp_graph_order
+            )
+
 
 class TestAutoWrap(TestCase):
     def setUp(self) -> None:
diff --git a/torch/distributed/_fsdp/fully_sharded_data_parallel.py b/torch/distributed/_fsdp/fully_sharded_data_parallel.py
index 614800e..2b0da86 100644
--- a/torch/distributed/_fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/_fsdp/fully_sharded_data_parallel.py
@@ -165,6 +165,7 @@
                 # FSDP arguments follow.
                 process_group=process_group,
                 cpu_offload=cpu_offload,
+                backward_prefetch=backward_prefetch,
                 # Note that recursive_wrap should not call FSDP with wrapping
                 # enabled, as this recursive call handles all wrapping,
                 # including for nested children.
diff --git a/torch/testing/_internal/common_fsdp.py b/torch/testing/_internal/common_fsdp.py
index 244dc96..e617b11 100644
--- a/torch/testing/_internal/common_fsdp.py
+++ b/torch/testing/_internal/common_fsdp.py
@@ -333,6 +333,9 @@
     def _check_cpu_offload(self, fsdp_model, cpu_offload):
         self.assertEqual(cpu_offload, fsdp_model.cpu_offload)
 
+    def _check_backward_prefetch(self, fsdp_model, backward_prefetch):
+        self.assertEqual(backward_prefetch, fsdp_model.backward_prefetch)
+
     @classmethod
     def _run(cls, rank, test_name, file_name, pipe):
         self = cls(test_name)