Prevent sum overflow in broadcast_object_list (#70336)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70336
broadcast_object_list cast the sum of all object lengths from long (int64) to int (int32), causing an overflow whenever the total payload exceeded 2^31 - 1 bytes.
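A minimal sketch of the failure mode, assuming a total byte count just past the int32 limit (the value below is chosen to mirror the error message in the Test Plan; it is illustrative, not the library code):
```
import torch

# Per-object byte counts are summed into an int64 tensor. Truncating that sum
# to int32 (the old behavior) wraps past 2**31 - 1 to a negative value, which
# torch.empty then rejects; keeping int64 (the fix) preserves the real size.
object_sizes_tensor = torch.tensor([2**31 + 1231], dtype=torch.long)

truncated = torch.sum(object_sizes_tensor).int().item()  # old: wraps to a negative int32
preserved = torch.sum(object_sizes_tensor).item()        # fixed: stays int64

print(truncated)  # -2147482417 on typical builds (int32 wrap-around)
print(preserved)  # 2147484879
```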
Test Plan:
Increased the size of the Tensor used in object transfers so that its storage requirement exceeds 2 GB (in distributed_test.py).
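Rough size arithmetic for the appended test object, Foo(torch.randn(3, 178956971)) (see the diff below); this is a back-of-the-envelope check that ignores pickling overhead:
```
# 3 x 178,956,971 float32 elements at 4 bytes each already exceed the int32 range.
elements = 3 * 178956971       # 536,870,913 elements
raw_bytes = elements * 4       # 2,147,483,652 bytes
print(raw_bytes > 2**31 - 1)   # True: the summed length no longer fits in 32 bits
```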
Without the fix, the summed length overflows and the program requests a Tensor with a negative dimension:
```
RuntimeError: Trying to create tensor with negative dimension -2147482417: [-2147482417]
```
With the fix, the test passes.
Test command used on a server with GPUs:
```
buck test mode/dev-nosan //caffe2/test/distributed:distributed_nccl_spawn --local -- broadcast_object
```
Differential Revision: D33281300
fbshipit-source-id: 1bc83e8624edc14e747eeced7bc8a7a10e443ee4
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 0661c66..db09489 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1864,7 +1864,7 @@
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.empty(
- torch.sum(object_sizes_tensor).int().item(), # type: ignore[arg-type]
+ torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type]
dtype=torch.uint8,
)
diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py
index 77cc21f..1487a0d 100644
--- a/torch/testing/_internal/distributed/distributed_test.py
+++ b/torch/testing/_internal/distributed/distributed_test.py
@@ -124,7 +124,6 @@
foo_cpu_tensor = Foo(torch.randn(3, 3))
-
COLLECTIVES_OBJECT_TEST_LIST = [
{"key1": 3, "key2": 4, "key3": {"nested": True}},
f,
@@ -5454,6 +5453,9 @@
def _test_allgather_object(self, subgroup=None):
# Only set device for NCCL backend since it must use GPUs.
+
+ gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
+
backend = os.environ["BACKEND"]
if backend == "nccl":
# Case where rank != GPU device.
@@ -5462,9 +5464,7 @@
# If GPU test, add object with GPU tensor
if backend == "nccl":
- COLLECTIVES_OBJECT_TEST_LIST.append(Foo(torch.randn(3, 3, device=0)))
-
- gather_objects = COLLECTIVES_OBJECT_TEST_LIST
+ gather_objects.append(Foo(torch.randn(3, 3, device=0)))
output_gathered = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(
@@ -5498,7 +5498,7 @@
def _test_gather_object(self, pg=None):
# Ensure stateful objects can be gathered
- gather_objects = COLLECTIVES_OBJECT_TEST_LIST
+ gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
output_gathered = [None for _ in range(dist.get_world_size(pg))]
gather_on_rank = 0
my_rank = dist.get_rank(pg)
@@ -6241,6 +6241,12 @@
loss.backward()
def _test_broadcast_object_list(self, group=None):
+ gather_objects = COLLECTIVES_OBJECT_TEST_LIST.copy()
+ # Create Tensor with > 2^31 Bytes storage requirements
+ gather_objects.append(Foo(torch.randn(3, 178956971)))
+
+
+
# Only set device for NCCL backend since it must use GPUs.
# Case where rank != GPU device.
next_rank = (self.rank + 1) % int(self.world_size)
@@ -6251,12 +6257,12 @@
src_rank = 0
# If GPU test, add object with GPU tensor
if backend == "nccl":
- COLLECTIVES_OBJECT_TEST_LIST.append(Foo(torch.randn(3, 3, device=0)))
+ gather_objects.append(Foo(torch.randn(3, 3, device=0)))
objects = (
- COLLECTIVES_OBJECT_TEST_LIST
+ gather_objects
if self.rank == src_rank
- else [None for _ in COLLECTIVES_OBJECT_TEST_LIST]
+ else [None for _ in gather_objects]
)
# Single object test with device specified. Backend="gloo", device=cpu
@@ -6264,12 +6270,12 @@
single_obj_list = [objects[0]]
if self.rank != src_rank:
self.assertNotEqual(
- single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0]
+ single_obj_list[0], gather_objects[0]
)
dist.broadcast_object_list(
single_obj_list, src=0, group=group, device=torch.device("cpu")
)
- self.assertEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
+ self.assertEqual(single_obj_list[0], gather_objects[0])
# Single object test with device specified. Backend="gloo", device=current_device+1
# The test is gated by the fact GPU count is the same as world size to avoid the case
@@ -6278,37 +6284,37 @@
single_obj_list = [objects[0]]
if self.rank != src_rank:
self.assertNotEqual(
- single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0]
+ single_obj_list[0], gather_objects[0]
)
dist.broadcast_object_list(
single_obj_list, src=0, group=group, device=torch.device(next_rank)
)
- self.assertEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
+ self.assertEqual(single_obj_list[0], gather_objects[0])
# Single object test with device specified. Backend="nccl", device=current_device+1
if backend == "nccl" and torch.cuda.device_count() == int(self.world_size):
single_obj_list = [objects[0]]
if self.rank != src_rank:
self.assertNotEqual(
- single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0]
+ single_obj_list[0], gather_objects[0]
)
dist.broadcast_object_list(
single_obj_list, src=0, group=group, device=torch.device(next_rank)
)
- self.assertEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
+ self.assertEqual(single_obj_list[0], gather_objects[0])
# Single object test: backward compatibility with device unspecified
single_obj_list = [objects[0]]
if self.rank != src_rank:
- self.assertNotEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
+ self.assertNotEqual(single_obj_list[0], gather_objects[0])
dist.broadcast_object_list(single_obj_list, src=0, group=group)
- self.assertEqual(single_obj_list[0], COLLECTIVES_OBJECT_TEST_LIST[0])
+ self.assertEqual(single_obj_list[0], gather_objects[0])
# Multiple input objects test
if self.rank != src_rank:
- self.assertNotEqual(objects, COLLECTIVES_OBJECT_TEST_LIST)
+ self.assertNotEqual(objects, gather_objects)
dist.broadcast_object_list(objects, src=0, group=group)
- self.assertEqual(objects, COLLECTIVES_OBJECT_TEST_LIST)
+ self.assertEqual(objects, gather_objects)
@require_backend(DistTestCases.backend_feature["gpu"])
@require_n_gpus_for_nccl_backend(