[FSDP] Use `all_gather_into_tensor()` (#87077)
Let us silence some deprecation warnings 👍🏼: `dist._all_gather_base()` has been deprecated in favor of `dist.all_gather_into_tensor()`, so switch FSDP (and its tests) over to the new name.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87077
Approved by: https://github.com/rohan-varma
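
For context, here is a minimal sketch (not part of this PR) of the new-style collective that the diff below switches to. The buffer sizing mirrors how FSDP all-gathers a flat parameter shard; `gather_flat_shard` is a hypothetical helper name used only for illustration:

```python
# Minimal sketch, not part of this PR: the replacement collective used below.
# Assumes the default process group is already initialized (e.g. via torchrun).
import torch
import torch.distributed as dist

def gather_flat_shard(sharded_flat_param: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    # The output buffer must hold `world_size * shard_numel` elements.
    padded_unsharded_flat_param = torch.empty(
        world_size * sharded_flat_param.numel(),
        dtype=sharded_flat_param.dtype,
        device=sharded_flat_param.device,
    )
    # Drop-in replacement for the deprecated `dist._all_gather_base()`:
    # same (output, input, group) calling convention, new public name.
    dist.all_gather_into_tensor(padded_unsharded_flat_param, sharded_flat_param)
    return padded_unsharded_flat_param
```

The tests follow suit by patching `torch.distributed.all_gather_into_tensor` instead of the private `_all_gather_base` when counting collective calls.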
diff --git a/test/distributed/fsdp/test_fsdp_comm.py b/test/distributed/fsdp/test_fsdp_comm.py
index 590919f..c9946a9d 100644
--- a/test/distributed/fsdp/test_fsdp_comm.py
+++ b/test/distributed/fsdp/test_fsdp_comm.py
@@ -220,7 +220,7 @@
# and if `use_no_sync=False`, we only run `num_iters` iterations
# outside `no_sync()`
num_iters = 3
- with patch("torch.distributed._all_gather_base") as mock_all_gather, \
+ with patch("torch.distributed.all_gather_into_tensor") as mock_all_gather, \
patch("torch.distributed._reduce_scatter_base") as mock_reduce_scatter:
def reset_mocks():
mock_all_gather.reset_mock()
diff --git a/torch/distributed/fsdp/flat_param.py b/torch/distributed/fsdp/flat_param.py
index 818d52f..6e30a03 100644
--- a/torch/distributed/fsdp/flat_param.py
+++ b/torch/distributed/fsdp/flat_param.py
@@ -799,7 +799,7 @@
padded_unsharded_flat_param.numel() == expected_numel,
f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}",
)
- dist._all_gather_base(
+ dist.all_gather_into_tensor(
padded_unsharded_flat_param,
sharded_flat_param,
self.process_group,
@@ -861,7 +861,7 @@
self._check_sharded(flat_param.grad)
flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined]
sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
- dist._all_gather_base(padded_unsharded_grad, sharded_grad, self.process_group)
+ dist.all_gather_into_tensor(padded_unsharded_grad, sharded_grad, self.process_group)
unsharded_size = self.flat_param._unpadded_unsharded_size
flat_param.grad = padded_unsharded_grad[:unsharded_size.numel()].view(unsharded_size)
self._use_unsharded_grad_views()
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 76d662b..08b2d5f 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -577,7 +577,7 @@
tensor_kwargs = {"dtype": torch.int32, "device": device}
world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs)
local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs)
- dist._all_gather_base(
+ dist.all_gather_into_tensor(
world_num_valid_indices,
local_num_valid_indices,
group=self.process_group,
@@ -602,7 +602,7 @@
self.world_size * num_valid_indices, **tensor_kwargs
)
local_indices = torch.tensor(local_indices, **tensor_kwargs)
- dist._all_gather_base(
+ dist.all_gather_into_tensor(
world_indices, local_indices, group=self.process_group
)
# Check that all ranks plan to all-gather the same index parameters
@@ -2608,9 +2608,10 @@
)
nonsharded_tensors = []
- # TODO: Reduce the communication by using only one _all_gather_base to
- # gather all the parameters in this layer. This can be achieved by
- # concatenated all the local shards and then append the padding.
+ # TODO: Reduce the communication by using only one
+ # `all_gather_into_tensor()` to gather all the parameters in this
+ # layer. This can be achieved by concatenating all the local shards and
+ # then appending the padding.
# https://github.com/pytorch/pytorch/issues/77461
shared_fqns = [fqn for fqn, _, _ in self._shared_param_fqns]
for fqn, _, _ in self._param_fqns:
@@ -2640,7 +2641,7 @@
tensor = torch.empty(
chunk_size * self.world_size, dtype=local_tensor.dtype
).cuda()
- dist._all_gather_base(tensor, local_tensor, group=self.process_group)
+ dist.all_gather_into_tensor(tensor, local_tensor, group=self.process_group)
tensor = tensor.narrow(0, 0, param_numel).reshape(param.size())
nonsharded_tensors.append(tensor)