[PT-D] Enable init ops for DTensor (#92651)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92651
Approved by: https://github.com/wanchaol
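Summary: this change lets `torch.nn.init` ops (`uniform_`, `normal_`, `kaiming_uniform_`, `constant_`) run directly on a `DTensor`. It generalizes the dropout-specific sharding propagation rule into a shared rule for non-deterministic ops, adds `aten.constant_.default` to the pointwise op list, and adds a test that exercises each init op on a sharded DTensor.

The sketch below illustrates the intended usage; it mirrors the new test rather than quoting it. It is a minimal sketch, not part of this PR: the single-process gloo group, the `MASTER_ADDR`/`MASTER_PORT` values, and importing `DeviceMesh` from `torch.distributed._tensor` are assumptions for illustration and may differ across PyTorch versions.

```python
# Minimal sketch (illustrative, not part of this PR): apply an nn.init op
# to a sharded DTensor. Uses a single-process gloo group for simplicity.
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])                      # 1-D mesh; placeholder layout
local = torch.randn(8, 4)
dt = DTensor.from_local(local, mesh, [Shard(0)])   # shard dim 0 across the mesh

# With this PR, the in-place init op is applied to the local shard.
torch.nn.init.uniform_(dt, a=0, b=1.2)
print(dt.to_local())

dist.destroy_process_group()
```

On a multi-rank mesh each rank initializes its own local shard; the new test checks this by seeding every rank with its rank id and comparing the DTensor's local shard against a plain tensor initialized the same way.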
diff --git a/test/distributed/_tensor/test_init.py b/test/distributed/_tensor/test_init.py
new file mode 100644
index 0000000..c7d6395
--- /dev/null
+++ b/test/distributed/_tensor/test_init.py
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+from torch.distributed._tensor import (
+    DTensor,
+    Shard,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+
+
+class DTensorInitOpsTest(DTensorTestBase):
+    def _run_init_op(self, init_op, *args, **kwargs):
+        device_mesh = self.build_device_mesh()
+        shard_spec = [Shard(0)]
+        input_size = (8, 4)
+        input_tensor = torch.randn(*input_size, device=self.device_type)
+        dtensor = DTensor.from_local(input_tensor, device_mesh, shard_spec)
+        local_tensor_clone = torch.clone(input_tensor)
+        torch.manual_seed(self.rank)
+        local_tensor_clone = init_op(local_tensor_clone, *args, **kwargs)
+        torch.manual_seed(self.rank)
+        dtensor = init_op(dtensor, *args, **kwargs)
+        self.assertEqual(local_tensor_clone, dtensor.to_local())
+
+    @with_comms
+    def test_init_ops(self):
+        self._run_init_op(
+            torch.nn.init.kaiming_uniform_,
+            a=0,
+            mode="fan_in",
+            nonlinearity="leaky_relu",
+        )
+        self._run_init_op(torch.nn.init.normal_, mean=1.5, std=0.8)
+        self._run_init_op(torch.nn.init.uniform_, a=0, b=1.2)
+        self._run_init_op(torch.nn.init.constant_, 2.4)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py
index 08f6c3a..0c75168 100644
--- a/torch/distributed/_tensor/ops/pointwise_ops.py
+++ b/torch/distributed/_tensor/ops/pointwise_ops.py
@@ -120,6 +120,7 @@
"aten.conj_physical.default",
"aten.conj_physical.out",
"aten.conj_physical_.default",
+ "aten.constant_.default",
"aten.copy_sign.Scalar",
"aten.copy_sign.Scalar_out",
"aten.copy_sign.Tensor",
@@ -376,24 +377,29 @@
     DTensor._op_to_rules[op] = pointwise_rule
 
 
-@register_prop_rule("aten.native_dropout.default")
-def dropout_rule(op_schema: OpSchema) -> OutputSharding:
-    self_spec = cast(DTensorSpec, op_schema.args_schema[0])
+def _register_non_deterministic_op(op):
+    @register_prop_rule(op)
+    def non_deterministic_rule(op_schema: OpSchema) -> OutputSharding:
+        self_spec = cast(DTensorSpec, op_schema.args_schema[0])
 
-    # TODO: We are specializing dropout_rule now because it's
-    # a non-deterministic algorithm, and replication does not
-    # not support non-deterministic op yet. We should remove
-    # this rule and make dropout to use pointwise rule instead
-    # once we support non-deterministic op.
-    replicate_or_partial = False
-    for placement in self_spec.placements:
-        if isinstance(placement, (Replicate, _Partial)):
-            replicate_or_partial = True
-            break
+        # TODO: We are specializing non_deterministic_rule now because
+        # replicate does not support this op yet. We should remove
+        # this rule once we support non-deterministic op for replicate.
+        replicate_or_partial = False
+        for placement in self_spec.placements:
+            if isinstance(placement, (Replicate, _Partial)):
+                replicate_or_partial = True
+                break
 
-    if replicate_or_partial:
-        return OutputSharding(
-            None, failed_reason="Dropout with replication is not supported yet!"
-        )
-    else:
-        return OutputSharding(self_spec)
+        if replicate_or_partial:
+            return OutputSharding(
+                None, failed_reason=f"{op} with replication is not supported yet!"
+            )
+        else:
+            return OutputSharding(self_spec)
+
+
+_register_non_deterministic_op("aten.native_dropout.default")
+_register_non_deterministic_op("aten.uniform_.default")
+_register_non_deterministic_op("aten.normal_.default")
+_register_non_deterministic_op("aten.kaiming_uniform_.default")
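A note on the refactor in this hunk: the old `dropout_rule` applied only to `aten.native_dropout.default`, while `_register_non_deterministic_op` registers the same placement check for any non-deterministic op, keeping the input spec unchanged for sharded placements and rejecting `Replicate`/`_Partial` placements, where every replica would have to draw identical random values. The snippet below is a simplified, standalone model of this factory/registration pattern; the toy `register_rule` decorator, `PROP_RULES` dict, and string placements are illustrative stand-ins, not DTensor internals.

```python
# Simplified standalone model of the pattern above (illustrative only:
# PROP_RULES, register_rule, and string placements are toy stand-ins).
from typing import Callable, Dict, List, Optional, Tuple

Rule = Callable[[List[str]], Tuple[Optional[List[str]], Optional[str]]]
PROP_RULES: Dict[str, Rule] = {}


def register_rule(op: str):
    """Toy counterpart of register_prop_rule: store a rule under the op name."""
    def wrapper(fn: Rule) -> Rule:
        PROP_RULES[op] = fn
        return fn
    return wrapper


def _register_non_deterministic_op(op: str) -> None:
    @register_rule(op)
    def non_deterministic_rule(placements: List[str]):
        # Reject replicated/partial placements: every replica would need to
        # draw identical random values, which is not supported yet.
        if any(p in ("replicate", "partial") for p in placements):
            return None, f"{op} with replication is not supported yet!"
        # Otherwise the output keeps the input sharding.
        return placements, None


for op in ("aten.native_dropout.default", "aten.uniform_.default"):
    _register_non_deterministic_op(op)

print(PROP_RULES["aten.uniform_.default"](["shard(0)"]))   # accepted: keeps sharding
print(PROP_RULES["aten.uniform_.default"](["replicate"]))  # rejected with a reason
```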