[PT-D] Enable init ops for DTensor (#92651)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92651
Approved by: https://github.com/wanchaol
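
This registers uniform_, normal_, and kaiming_uniform_ through the same
non-deterministic prop rule previously used only for native_dropout, and adds
constant_ to the pointwise op list, so in-place init ops can be applied to a
DTensor's local shards. A rough usage sketch (illustrative only, not part of
this patch; assumes an initialized process group with one device per rank):

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DTensor, DeviceMesh, Shard

    # Hypothetical 1-D mesh over all ranks, one device each.
    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
    dtensor = DTensor.from_local(torch.empty(8, 4, device="cuda"), mesh, [Shard(0)])

    # After this change, in-place init ops run on each rank's local shard:
    torch.nn.init.uniform_(dtensor, a=0.0, b=1.2)
    torch.nn.init.kaiming_uniform_(dtensor, a=0, mode="fan_in", nonlinearity="leaky_relu")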
diff --git a/test/distributed/_tensor/test_init.py b/test/distributed/_tensor/test_init.py
new file mode 100644
index 0000000..c7d6395
--- /dev/null
+++ b/test/distributed/_tensor/test_init.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+from torch.distributed._tensor import (
+    DTensor,
+    Shard,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+
+
+class DTensorInitOpsTest(DTensorTestBase):
+    def _run_init_op(self, init_op, *args, **kwargs):
+        device_mesh = self.build_device_mesh()
+        shard_spec = [Shard(0)]
+        input_size = (8, 4)
+        input_tensor = torch.randn(*input_size, device=self.device_type)
+        dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        local_tensor_clone = torch.clone(input_tensor)
+        torch.manual_seed(self.rank)
+        local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
+        torch.manual_seed(self.rank)
+        dtensor = init_op(dtensor, *args, **kwargs)
+        self.assertEqual(local_tensor_clone, dtensor.to_local())
+
+    @with_comms
+    def test_init_ops(self):
+        self._run_init_op(
+            torch.nn.init.kaiming_uniform_,
+            a=0,
+            mode="fan_in",
+            nonlinearity="leaky_relu",
+        )
+        self._run_init_op(torch.nn.init.normal_, mean=1.5, std=0.8)
+        self._run_init_op(torch.nn.init.uniform_, a=0, b=1.2)
+        self._run_init_op(torch.nn.init.constant_, 2.4)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py
index 08f6c3a..0c75168 100644
--- a/torch/distributed/_tensor/ops/pointwise_ops.py
+++ b/torch/distributed/_tensor/ops/pointwise_ops.py
@@ -120,6 +120,7 @@
     "aten.conj_physical.default",
     "aten.conj_physical.out",
     "aten.conj_physical_.default",
+    "aten.constant_.default",
     "aten.copy_sign.Scalar",
     "aten.copy_sign.Scalar_out",
     "aten.copy_sign.Tensor",
@@ -376,24 +377,29 @@
     DTensor._op_to_rules[op] = pointwise_rule
 
 
-@register_prop_rule("aten.native_dropout.default")
-def dropout_rule(op_schema: OpSchema) -> OutputSharding:
-    self_spec = cast(DTensorSpec, op_schema.args_schema[0])
+def _register_non_deterministic_op(op):
+    @register_prop_rule(op)
+    def non_deterministic_rule(op_schema: OpSchema) -> OutputSharding:
+        self_spec = cast(DTensorSpec, op_schema.args_schema[0])
 
-    # TODO: We are specializing dropout_rule now because it's
-    # a non-deterministic algorithm, and replication does not
-    # not support non-deterministic op yet. We should remove
-    # this rule and make dropout to use pointwise rule instead
-    # once we support non-deterministic op.
-    replicate_or_partial = False
-    for placement in self_spec.placements:
-        if isinstance(placement, (Replicate, _Partial)):
-            replicate_or_partial = True
-            break
+        # TODO: We are specializing non_deterministic_rule for now because
+        # replication does not support non-deterministic ops yet. We should
+        # remove this rule once replication supports them.
+        replicate_or_partial = False
+        for placement in self_spec.placements:
+            if isinstance(placement, (Replicate, _Partial)):
+                replicate_or_partial = True
+                break
 
-    if replicate_or_partial:
-        return OutputSharding(
-            None, failed_reason="Dropout with replication is not supported yet!"
-        )
-    else:
-        return OutputSharding(self_spec)
+        if replicate_or_partial:
+            return OutputSharding(
+                None, failed_reason=f"{op} with replication is not supported yet!"
+            )
+        else:
+            return OutputSharding(self_spec)
+
+
+_register_non_deterministic_op("aten.native_dropout.default")
+_register_non_deterministic_op("aten.uniform_.default")
+_register_non_deterministic_op("aten.normal_.default")
+_register_non_deterministic_op("aten.kaiming_uniform_.default")