[DTensor][Test] Update implicit replication unit tests for tensor arg being the first in args list (#127803)
Change the operands order so we can have test coverage for when the first arg is a tensor arg instead of DTensor arg.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127803
Approved by: https://github.com/XilunWu
diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py
index e2758e3..c716d92 100644
--- a/test/distributed/_tensor/test_dtensor.py
+++ b/test/distributed/_tensor/test_dtensor.py
@@ -784,7 +784,9 @@
from torch.distributed._tensor.experimental import implicit_replication
with implicit_replication():
- out_dt = sharded_dtensor + torch.ones(3, device=self.device_type)
+ # We put the scalar tensor as the left operand so we can test out
+            # when a non-DTensor is the first arg in the args list.
+ out_dt = torch.ones(3, device=self.device_type) + sharded_dtensor
self.assertEqual(out_dt.placements, [Shard(0)])
self.assertEqual(out_dt.shape, (4 * self.world_size, 3))
local_shard = out_dt.to_local()
@@ -802,7 +804,7 @@
ndim_0_tensor = torch.tensor(1, device=self.device_type)
def add_scalar_tensor_with_dtensor():
- return sharded_dtensor + ndim_0_tensor
+ return ndim_0_tensor + sharded_dtensor
result = add_scalar_tensor_with_dtensor().to_local()
self.assertEqual(result, local_tensor + ndim_0_tensor)
@@ -814,7 +816,7 @@
# automatically turn tensor to DTensor replicate when ndim = 1 and numel = 1
numel_1_tensor = torch.tensor([1], device=self.device_type)
self.assertEqual(
- (sharded_dtensor + numel_1_tensor).to_local(), local_tensor + numel_1_tensor
+ (numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor
)