[dtensor] fix new_empty_strided op (#107835)

This PR fixes the new_empty_strided op so that its output becomes replicated
instead of sharded when necessary. This is a quick fix to resolve
https://github.com/pytorch/pytorch/issues/107661

We'll need to think more about the behavior of this op when it comes to
sharding. One possibility is to follow the input sharding, but given that
the output shape of this op might not be the same as the input's, it's hard
to say we should follow the input sharding. Further improvement is needed
once we figure out the op syntax.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107835
Approved by: https://github.com/fduwjj
diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py
index 323475b..728afd6 100644
--- a/test/distributed/_tensor/test_dtensor.py
+++ b/test/distributed/_tensor/test_dtensor.py
@@ -246,6 +246,29 @@
             self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
 
     @with_comms
+    def test_dtensor_new_empty_strided(self):
+        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
+        my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
+        new_strided_dtensor = my_dtensor.new_empty_strided(
+            (8, 8), (8, 1), requires_grad=True
+        )
+        # test the op produces new dtensor and autograd works
+        self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
+        new_strided_dtensor.sum().backward()
+        self.assertIsNotNone(new_strided_dtensor.grad)
+        self.assertIsInstance(new_strided_dtensor.grad, DTensor)
+
+        # test backward new_empty_strided with sharding works correctly
+        my_dtensor.to_local().sum().backward()
+        local_tensor.sum().backward()
+        self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
+        self.assertEqual(
+            my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
+            local_tensor.grad,
+        )
+
+    @with_comms
     def test_dtensor_async_output(self):
         # Tests that if the output of some dtensor operations  isn't used in any compute,
         # the output should be an AsyncCollectiveTensor (representing the fact that
diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py
index 1a8bbf1..e8cb30e 100644
--- a/torch/distributed/_tensor/ops/tensor_ops.py
+++ b/torch/distributed/_tensor/ops/tensor_ops.py
@@ -43,7 +43,6 @@
         aten.detach.default,
         aten.equal.default,
         aten.is_same_size.default,
-        aten.new_empty_strided.default,  # TODO: re-think new_empty_strided
     ]
 )
 def default_strategy(
@@ -109,6 +108,7 @@
         aten.new_full.default,
         aten.new_ones.default,
         aten.new_zeros.default,
+        aten.new_empty_strided.default,  # TODO: re-think new_empty_strided
     ]
 )
 def new_factory_strategy(