[FSDP] Use `all_gather_into_tensor()` (#87077)

Replace the deprecated `dist._all_gather_base()` with `dist.all_gather_into_tensor()` so we can silence some deprecation warnings 👍🏼
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87077
Approved by: https://github.com/rohan-varma
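
For context, `dist._all_gather_base()` was deprecated in favor of the public `dist.all_gather_into_tensor()`, which keeps the same output-first call pattern, so this is a mechanical rename. Below is a minimal sketch of how FSDP's unshard path uses the collective; it assumes a process group is already initialized (e.g. via `dist.init_process_group`), and the helper name `all_gather_flat_param` is illustrative, not part of FSDP's API.

```python
import torch
import torch.distributed as dist


def all_gather_flat_param(
    sharded_flat_param: torch.Tensor,
    process_group: dist.ProcessGroup,
) -> torch.Tensor:
    world_size = process_group.size()
    # The output buffer holds one equally sized shard per rank, concatenated
    # along dim 0, so it needs `world_size * shard_numel` elements.
    padded_unsharded = torch.empty(
        world_size * sharded_flat_param.numel(),
        dtype=sharded_flat_param.dtype,
        device=sharded_flat_param.device,
    )
    # `all_gather_into_tensor()` is the public replacement for the deprecated
    # `_all_gather_base()`; the (output, input, group) arguments are unchanged.
    dist.all_gather_into_tensor(padded_unsharded, sharded_flat_param, process_group)
    return padded_unsharded
```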
diff --git a/test/distributed/fsdp/test_fsdp_comm.py b/test/distributed/fsdp/test_fsdp_comm.py
index 590919f..c9946a9d 100644
--- a/test/distributed/fsdp/test_fsdp_comm.py
+++ b/test/distributed/fsdp/test_fsdp_comm.py
@@ -220,7 +220,7 @@
         # and if `use_no_sync=False`, we only run `num_iters` iterations
         # outside `no_sync()`
         num_iters = 3
-        with patch("torch.distributed._all_gather_base") as mock_all_gather, \
+        with patch("torch.distributed.all_gather_into_tensor") as mock_all_gather, \
                 patch("torch.distributed._reduce_scatter_base") as mock_reduce_scatter:
             def reset_mocks():
                 mock_all_gather.reset_mock()
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 818d52f..6e30a03 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -799,7 +799,7 @@
             padded_unsharded_flat_param.numel() == expected_numel,
             f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
         )
-        dist._all_gather_base(
+        dist.all_gather_into_tensor(
             padded_unsharded_flat_param,
             sharded_flat_param,
             self.process_group,
@@ -861,7 +861,7 @@
             self._check_sharded(flat_param.grad)
             flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
             sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
-        dist._all_gather_base(padded_unsharded_grad, sharded_grad, self.process_group)
+        dist.all_gather_into_tensor(padded_unsharded_grad, sharded_grad, self.process_group)
         unsharded_size = self.flat_param._unpadded_unsharded_size
         flat_param.grad = padded_unsharded_grad[:unsharded_size.numel()].view(unsharded_size)
         self._use_unsharded_grad_views()
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 76d662b..08b2d5f 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -577,7 +577,7 @@
             tensor_kwargs = {"dtype": torch.int32, "device": device}
             world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)
             local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs)
-            dist._all_gather_base(
+            dist.all_gather_into_tensor(
                 world_num_valid_indices,
                 local_num_valid_indices,
                 group=self.process_group,
@@ -602,7 +602,7 @@
                 self.world_size * num_valid_indices, **tensor_kwargs
             )
             local_indices = torch.tensor(local_indices, **tensor_kwargs)
-            dist._all_gather_base(
+            dist.all_gather_into_tensor(
                 world_indices, local_indices, group=self.process_group
             )
             # Check that all ranks plan to all-gather the same index parameters
@@ -2608,9 +2608,10 @@
             )
 
         nonsharded_tensors = []
-        # TODO: Reduce the communication by using only one _all_gather_base to
-        # gather all the parameters in this layer. This can be achieved by
-        # concatenated all the local shards and then append the padding.
+        # TODO: Reduce the communication by using only one
+        # `all_gather_into_tensor()` to gather all the parameters in this
+        # layer. This can be achieved by concatenating all the local shards and
+        # then appending the padding.
         # https://github.com/pytorch/pytorch/issues/77461
         shared_fqns = [fqn for fqn, _, _ in self._shared_param_fqns]
         for fqn, _, _ in self._param_fqns:
@@ -2640,7 +2641,7 @@
             tensor = torch.empty(
                 chunk_size * self.world_size, dtype=local_tensor.dtype
             ).cuda()
-            dist._all_gather_base(tensor, local_tensor, group=self.process_group)
+            dist.all_gather_into_tensor(tensor, local_tensor, group=self.process_group)
             tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
             nonsharded_tensors.append(tensor)
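
A rough sketch of the single-all-gather approach described by the TODO in the final hunk above: pack every local shard into one buffer, gather once, then carve each parameter back out. The helper name and the equal-shard-size-per-rank layout are assumptions for illustration; the actual change is tracked in https://github.com/pytorch/pytorch/issues/77461.

```python
import torch
import torch.distributed as dist


def gather_layer_params(local_shards, param_numels, param_shapes, process_group):
    world_size = process_group.size()
    # Each rank packs its (already padded, equally sized) shards into one
    # contiguous buffer so that a single collective suffices.
    packed = torch.cat([shard.flatten() for shard in local_shards])
    gathered = torch.empty(
        world_size * packed.numel(), dtype=packed.dtype, device=packed.device
    )
    dist.all_gather_into_tensor(gathered, packed, group=process_group)
    # View the result as one row per rank, then stitch each parameter back
    # together from its per-rank shards and trim the padding.
    rows = gathered.view(world_size, packed.numel())
    tensors, offset = [], 0
    for shard, numel, shape in zip(local_shards, param_numels, param_shapes):
        shard_numel = shard.numel()
        full = rows[:, offset : offset + shard_numel].reshape(-1)[:numel]
        tensors.append(full.reshape(shape))
        offset += shard_numel
    return tensors
```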