[c10d] Remove deprecated use of torch.LongTensor, torch.ByteTensor (#55861)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55861
Legacy constructors such as torch.LongTensor and torch.ByteTensor are
deprecated; the recommended replacements are torch.tensor(data, dtype=...)
for data-backed tensors and torch.empty(size, dtype=...) for uninitialized
ones. Switch distributed_c10d over to these APIs.
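
For reference, a minimal sketch of the migration pattern applied throughout
this diff; the sizes and values (4 and 16) are arbitrary placeholders, not
taken from the code:

    import torch

    # Deprecated legacy constructors:
    #   local_size = torch.LongTensor([4])      # tensor built from data
    #   object_tensor = torch.ByteTensor(16)    # uninitialized tensor of size 16
    # Recommended replacements:
    local_size = torch.tensor([4], dtype=torch.long)     # build from data
    object_tensor = torch.empty(16, dtype=torch.uint8)   # uninitialized buffer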
ghstack-source-id: 126777875
Test Plan: CI
Reviewed By: pbelevich
Differential Revision: D27726600
fbshipit-source-id: 07eb8168d93697593589002c93c3903ce29431ef
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 4e36650..d3c3963 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1463,8 +1463,8 @@
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) # type: ignore[attr-defined]
- byte_tensor = torch.ByteTensor(byte_storage)
- local_size = torch.LongTensor([byte_tensor.numel()])
+ byte_tensor = torch.tensor(byte_storage, dtype=torch.uint8)
+ local_size = torch.tensor([byte_tensor.numel()], dtype=torch.long)
return byte_tensor, local_size
@@ -1556,7 +1556,9 @@
all_gather(output_tensors, input_tensor, group=group)
# Deserialize outputs back to object.
for i, tensor in enumerate(output_tensors):
- tensor = tensor.type(torch.ByteTensor) # type:ignore[call-overload]
+ tensor = tensor.type(torch.uint8) # type:ignore[call-overload]
+ if tensor.device != torch.device("cpu"):
+ tensor = tensor.cpu()
tensor_size = object_size_list[i]
object_list[i] = _tensor_to_object(tensor, tensor_size)
@@ -1656,7 +1658,7 @@
if my_rank != dst:
return
for i, tensor in enumerate(output_tensors):
- tensor = tensor.type(torch.ByteTensor) # type: ignore[call-overload]
+ tensor = tensor.type(torch.uint8) # type: ignore[call-overload]
tensor_size = object_size_list[i]
object_gather_list[i] = _tensor_to_object(tensor, tensor_size)
@@ -1718,7 +1720,7 @@
tensor_list, size_list = zip(*[_object_to_tensor(obj) for obj in object_list])
object_sizes_tensor = torch.cat(size_list)
else:
- object_sizes_tensor = torch.LongTensor(len(object_list))
+ object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
group_backend = get_backend(group)
is_nccl_backend = group_backend == Backend.NCCL
@@ -1738,7 +1740,10 @@
if my_rank == src:
object_tensor = torch.cat(tensor_list)
else:
- object_tensor = torch.ByteTensor(torch.sum(object_sizes_tensor).item())
+ object_tensor = torch.empty(
+ torch.sum(object_sizes_tensor).int().item(), # type: ignore[arg-type]
+ dtype=torch.uint8
+ )
if is_nccl_backend:
object_tensor = object_tensor.to(current_device)
@@ -1748,7 +1753,9 @@
if my_rank != src:
for i, obj_size in enumerate(object_sizes_tensor):
obj_view = object_tensor[offset : offset + obj_size]
- obj_view = obj_view.type(torch.ByteTensor) # type: ignore[call-overload]
+ obj_view = obj_view.type(torch.uint8) # type: ignore[call-overload]
+ if obj_view.device != torch.device("cpu"):
+ obj_view = obj_view.cpu()
offset += obj_size
object_list[i] = _tensor_to_object(obj_view, obj_size)
@@ -1821,7 +1828,6 @@
)
tensor_list, tensor_sizes = list(tensor_list), list(tensor_sizes)
- obj_tensor_size = torch.LongTensor([0])
# Src rank broadcasts the maximum tensor size. This is because all ranks are
# expected to call into scatter() with equal-sized tensors.
if my_rank == src:
@@ -1829,11 +1835,11 @@
for tensor in tensor_list:
tensor.resize_(max_tensor_size)
else:
- max_tensor_size = torch.LongTensor([0])
+ max_tensor_size = torch.tensor([0], dtype=torch.long)
broadcast(max_tensor_size, src=src, group=group)
# Scatter actual serialized objects
- output_tensor = torch.ByteTensor(max_tensor_size.item())
+ output_tensor = torch.empty(max_tensor_size.item(), dtype=torch.uint8)
scatter(
output_tensor,
scatter_list=None if my_rank != src else tensor_list,
@@ -1842,6 +1848,7 @@
)
# Scatter per-object sizes to trim tensors when deserializing back to object
+ obj_tensor_size = torch.tensor([0], dtype=torch.long)
scatter(
obj_tensor_size,
scatter_list=None if my_rank != src else tensor_sizes,