[dtensor][8/N] Introduce cost model for sharding (#109145)

This PR adds a basic communication cost model for sharding propagation. It introduces per-collective cost estimates (allgather, allreduce, reduce_scatter) and a `redistribute_cost` helper in `_collective_utils.py`, records redistribute costs on each `PlacementStrategy`, and makes the sharding propagator select the strategy with the minimal redistribute cost instead of always taking the first one.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109145
Approved by: https://github.com/fduwjj
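
A minimal sketch of how the new `redistribute_cost` helper behaves, mirroring the added `TestCostModel` cases. The `"cuda"` device type and direct `DeviceMesh` construction here are illustrative; building a `DeviceMesh` requires an initialized process group, which the tests obtain via `DTensorOpTestBase.build_device_mesh()`.

```python
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor._collective_utils import redistribute_cost
from torch.distributed._tensor.placement_types import (
    _Partial,
    DTensorSpec,
    Replicate,
    Shard,
    TensorMeta,
)

# illustrative 1-D mesh over 4 ranks; assumes a process group is already initialized
mesh = DeviceMesh("cuda", torch.arange(4))

global_tensor = torch.randn(10, 10)
tensor_meta = TensorMeta(global_tensor.shape, global_tensor.stride(), global_tensor.dtype)

shard_spec = DTensorSpec(mesh, (Shard(0),), tensor_meta)
replica_spec = DTensorSpec(mesh, (Replicate(),), tensor_meta)
partial_spec = DTensorSpec(mesh, (_Partial(),), tensor_meta)

redistribute_cost(shard_spec, shard_spec)      # 0.0: same spec, no communication
redistribute_cost(shard_spec, replica_spec)    # allgather cost
redistribute_cost(partial_spec, replica_spec)  # allreduce cost (~2x comm bytes)
redistribute_cost(partial_spec, shard_spec)    # reduce_scatter cost
redistribute_cost(shard_spec, partial_spec)    # inf: shard -> partial is banned
```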
diff --git a/test/distributed/_tensor/test_basic_strategy.py b/test/distributed/_tensor/test_basic_strategy.py
deleted file mode 100644
index 0a5a026..0000000
--- a/test/distributed/_tensor/test_basic_strategy.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-import torch
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.ops.basic_strategy import (
-    EinsumDims,
-    gen_einsum_strategies,
-)
-
-from torch.testing._internal.common_utils import run_tests, TestCase
-from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
-
-
-class TestEinsumDims(TestCase):
-    def test_batch_dims(self):
-        equation = "abc,abc->abc"
-        input_dims, output_dim = EinsumDims.parse_equation(equation)
-        edims = EinsumDims.parse_dims(input_dims, output_dim)
-
-        self.assertEqual(edims.batch_dims, ["a", "b", "c"])
-        self.assertEqual(edims.contracting_dims, [])
-        self.assertEqual(edims.lhs_out_only_dims, [])
-        self.assertEqual(edims.rhs_out_only_dims, [])
-
-    def test_mm_dims(self):
-        equation = "mk,kn->mn"
-        input_dims, output_dim = EinsumDims.parse_equation(equation)
-        edims = EinsumDims.parse_dims(input_dims, output_dim)
-
-        self.assertEqual(edims.batch_dims, [])
-        self.assertEqual(edims.contracting_dims, ["k"])
-        self.assertEqual(edims.lhs_out_only_dims, ["m"])
-        self.assertEqual(edims.rhs_out_only_dims, ["n"])
-
-    def test_bmm_dims(self):
-        equation = "bmk,bkn->bmn"
-        input_dims, output_dim = EinsumDims.parse_equation(equation)
-        edims = EinsumDims.parse_dims(input_dims, output_dim)
-
-        self.assertEqual(edims.batch_dims, ["b"])
-        self.assertEqual(edims.contracting_dims, ["k"])
-        self.assertEqual(edims.lhs_out_only_dims, ["m"])
-        self.assertEqual(edims.rhs_out_only_dims, ["n"])
-
-        equation = "bcmk,bckn->bcmn"
-        input_dims, output_dim = EinsumDims.parse_equation(equation)
-        edims = EinsumDims.parse_dims(input_dims, output_dim)
-
-        self.assertEqual(edims.batch_dims, ["b", "c"])
-        self.assertEqual(edims.contracting_dims, ["k"])
-        self.assertEqual(edims.lhs_out_only_dims, ["m"])
-        self.assertEqual(edims.rhs_out_only_dims, ["n"])
-
-    def test_free_dims(self):
-        equation = "abc,ab->abc"
-        input_dims, output_dim = EinsumDims.parse_equation(equation)
-        edims = EinsumDims.parse_dims(input_dims, output_dim)
-
-        self.assertEqual(edims.batch_dims, ["a", "b"])
-        self.assertEqual(edims.contracting_dims, [])
-        self.assertEqual(edims.lhs_out_only_dims, ["c"])
-        self.assertEqual(edims.rhs_out_only_dims, [])
-
-        equation = "abd,bf->abfd"
-        input_dims, output_dim = EinsumDims.parse_equation(equation)
-        edims = EinsumDims.parse_dims(input_dims, output_dim)
-
-        self.assertEqual(edims.batch_dims, ["b"])
-        self.assertEqual(edims.contracting_dims, [])
-        self.assertEqual(edims.lhs_out_only_dims, ["a", "d"])
-        self.assertEqual(edims.rhs_out_only_dims, ["f"])
-
-
-class TestEinsumStrategies(DTensorOpTestBase):
-    @property
-    def world_size(self) -> int:
-        return 4
-
-    def test_mm_1d_mesh(self):
-        mesh = self.build_device_mesh()
-
-        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
-        self.assertEqual(len(all_strats.strategies), 4)
-
-    def test_mm_2d_mesh(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
-
-        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
-        self.assertEqual(len(all_strats.strategies), 16)
-
-    def test_bmm_1d_mesh(self):
-        mesh = self.build_device_mesh()
-
-        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
-        self.assertEqual(len(all_strats.strategies), 5)
-
-    def test_bmm_2d_mesh(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
-
-        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
-        self.assertEqual(len(all_strats.strategies), 25)
-
-    def test_pointwise_1d_mesh(self):
-        mesh = self.build_device_mesh()
-
-        simple_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh)
-        self.assertEqual(len(simple_strats.strategies), 5)
-
-        broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh)
-        self.assertEqual(len(broadcast_strats.strategies), 5)
-
-    def test_linearity_1d_mesh(self):
-        mesh = self.build_device_mesh()
-
-        all_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh, linearity=True)
-        self.assertEqual(len(all_strats.strategies), 6)
-
-
-if __name__ == "__main__":
-    run_tests()
diff --git a/test/distributed/_tensor/test_op_strategy.py b/test/distributed/_tensor/test_op_strategy.py
new file mode 100644
index 0000000..c13f2c4
--- /dev/null
+++ b/test/distributed/_tensor/test_op_strategy.py
@@ -0,0 +1,202 @@
+# Owner(s): ["oncall: distributed"]
+
+import torch
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed._tensor._collective_utils import redistribute_cost
+from torch.distributed._tensor.ops.basic_strategy import (
+    EinsumDims,
+    gen_einsum_strategies,
+)
+from torch.distributed._tensor.placement_types import (
+    _Partial,
+    DTensorSpec,
+    Replicate,
+    Shard,
+    TensorMeta,
+)
+
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
+
+
+class TestEinsumDims(TestCase):
+    def test_batch_dims(self):
+        equation = "abc,abc->abc"
+        input_dims, output_dim = EinsumDims.parse_equation(equation)
+        edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+        self.assertEqual(edims.batch_dims, ["a", "b", "c"])
+        self.assertEqual(edims.contracting_dims, [])
+        self.assertEqual(edims.lhs_out_only_dims, [])
+        self.assertEqual(edims.rhs_out_only_dims, [])
+
+    def test_mm_dims(self):
+        equation = "mk,kn->mn"
+        input_dims, output_dim = EinsumDims.parse_equation(equation)
+        edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+        self.assertEqual(edims.batch_dims, [])
+        self.assertEqual(edims.contracting_dims, ["k"])
+        self.assertEqual(edims.lhs_out_only_dims, ["m"])
+        self.assertEqual(edims.rhs_out_only_dims, ["n"])
+
+    def test_bmm_dims(self):
+        equation = "bmk,bkn->bmn"
+        input_dims, output_dim = EinsumDims.parse_equation(equation)
+        edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+        self.assertEqual(edims.batch_dims, ["b"])
+        self.assertEqual(edims.contracting_dims, ["k"])
+        self.assertEqual(edims.lhs_out_only_dims, ["m"])
+        self.assertEqual(edims.rhs_out_only_dims, ["n"])
+
+        equation = "bcmk,bckn->bcmn"
+        input_dims, output_dim = EinsumDims.parse_equation(equation)
+        edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+        self.assertEqual(edims.batch_dims, ["b", "c"])
+        self.assertEqual(edims.contracting_dims, ["k"])
+        self.assertEqual(edims.lhs_out_only_dims, ["m"])
+        self.assertEqual(edims.rhs_out_only_dims, ["n"])
+
+    def test_free_dims(self):
+        equation = "abc,ab->abc"
+        input_dims, output_dim = EinsumDims.parse_equation(equation)
+        edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+        self.assertEqual(edims.batch_dims, ["a", "b"])
+        self.assertEqual(edims.contracting_dims, [])
+        self.assertEqual(edims.lhs_out_only_dims, ["c"])
+        self.assertEqual(edims.rhs_out_only_dims, [])
+
+        equation = "abd,bf->abfd"
+        input_dims, output_dim = EinsumDims.parse_equation(equation)
+        edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+        self.assertEqual(edims.batch_dims, ["b"])
+        self.assertEqual(edims.contracting_dims, [])
+        self.assertEqual(edims.lhs_out_only_dims, ["a", "d"])
+        self.assertEqual(edims.rhs_out_only_dims, ["f"])
+
+
+class TestEinsumStrategies(DTensorOpTestBase):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def test_mm_1d_mesh(self):
+        mesh = self.build_device_mesh()
+
+        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
+        self.assertEqual(len(all_strats.strategies), 4)
+
+    def test_mm_2d_mesh(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+
+        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
+        self.assertEqual(len(all_strats.strategies), 16)
+
+    def test_bmm_1d_mesh(self):
+        mesh = self.build_device_mesh()
+
+        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
+        self.assertEqual(len(all_strats.strategies), 5)
+
+    def test_bmm_2d_mesh(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+
+        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
+        self.assertEqual(len(all_strats.strategies), 25)
+
+    def test_pointwise_1d_mesh(self):
+        mesh = self.build_device_mesh()
+
+        simple_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh)
+        self.assertEqual(len(simple_strats.strategies), 5)
+
+        broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh)
+        self.assertEqual(len(broadcast_strats.strategies), 5)
+
+    def test_linearity_1d_mesh(self):
+        mesh = self.build_device_mesh()
+
+        all_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh, linearity=True)
+        self.assertEqual(len(all_strats.strategies), 6)
+
+
+class TestCostModel(DTensorOpTestBase):
+    def _extract_tensor_meta(self, t) -> TensorMeta:
+        return TensorMeta(t.shape, t.stride(), t.dtype)
+
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    def test_redistribute_cost_mesh_1d(self):
+        mesh_1d = self.build_device_mesh()
+        shard_placement = (Shard(0),)
+        replica_placement = (Replicate(),)
+        partial_placement = (_Partial(),)
+
+        global_tensor = torch.randn(10, 10)
+        global_tensor_meta = self._extract_tensor_meta(global_tensor)
+
+        # shard spec
+        shard_spec = DTensorSpec(mesh_1d, shard_placement, global_tensor_meta)
+        # replica spec
+        replica_spec = DTensorSpec(mesh_1d, replica_placement, global_tensor_meta)
+        # partial spec
+        partial_spec = DTensorSpec(mesh_1d, partial_placement, global_tensor_meta)
+
+        # make sure reshard cost is 0 for the same spec redistribute
+        for spec in [shard_spec, replica_spec, partial_spec]:
+            cost = redistribute_cost(spec, spec)
+            self.assertEqual(cost, 0)
+
+        # shard -> replicate
+        allgather_cost = redistribute_cost(shard_spec, replica_spec)
+        # partial -> shard
+        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
+        # partial -> replicate
+        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
+        self.assertEqual(allgather_cost, reduce_scatter_cost)
+        self.assertEqual(allreduce_cost + 1, allgather_cost + reduce_scatter_cost)
+        # shard to partial
+        cost = redistribute_cost(shard_spec, partial_spec)
+        self.assertEqual(cost, float("inf"))
+
+    def test_redistribute_cost_mesh_2d(self):
+        mesh_2d = DeviceMesh(
+            self.device_type, torch.arange(self.world_size).reshape(2, 2)
+        )
+        shard_placement = (Shard(0), Shard(0))
+        replica_placement = (Replicate(), Replicate())
+        partial_placement = (_Partial(), _Partial())
+
+        global_tensor = torch.randn(8, 8)
+        global_tensor_meta = self._extract_tensor_meta(global_tensor)
+
+        # shard spec
+        shard_spec = DTensorSpec(mesh_2d, shard_placement, global_tensor_meta)
+        # replica spec
+        replica_spec = DTensorSpec(mesh_2d, replica_placement, global_tensor_meta)
+        # partial spec
+        partial_spec = DTensorSpec(mesh_2d, partial_placement, global_tensor_meta)
+
+        # make sure reshard cost is 0 for the same spec redistribute
+        for spec in [shard_spec, replica_spec, partial_spec]:
+            cost = redistribute_cost(spec, spec)
+            self.assertEqual(cost, 0)
+
+        # shard -> replicate
+        allgather_cost = redistribute_cost(shard_spec, replica_spec)
+        # partial -> replicate
+        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
+        # partial -> shard
+        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
+        self.assertTrue(allreduce_cost > allgather_cost)
+        self.assertTrue(allreduce_cost > reduce_scatter_cost)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py
index f5b99d8..f954021 100644
--- a/torch/distributed/_tensor/_collective_utils.py
+++ b/torch/distributed/_tensor/_collective_utils.py
@@ -1,9 +1,11 @@
 import logging
+import math
 
 from typing import List, Optional
 
 import torch
-from torch.distributed._tensor.device_mesh import DeviceMesh
+import torch.distributed._tensor.placement_types as placement_types
+from torch.distributed._tensor.device_mesh import DeviceMesh, mesh_resources
 from torch.distributed.distributed_c10d import (
     all_to_all,
     broadcast,
@@ -158,3 +160,128 @@
             async_op=async_op,
         )
     return work
+
+
+def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int:
+    assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
+    return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
+
+
+def get_bandwidth_factor(mesh: DeviceMesh) -> List[float]:
+    # generate per-mesh-dim bandwidth factors for intra-host/inter-host communication
+    factors = [1.0] * mesh.ndim
+    num_devices_per_host = mesh_resources.num_devices_per_host(mesh.device_type)
+
+    num_devices = 1
+    for mesh_dim in reversed(range(mesh.ndim)):
+        num_devices *= mesh.size(mesh_dim)
+        if num_devices <= num_devices_per_host:
+            # magic number for intra-host communication bandwidth factor
+            # TODO: see if we need to tweak this or offer a way for user
+            # to specify the bandwidths
+            factors[mesh_dim] = 0.2
+
+    return factors
+
+
+def allgather_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float:
+    num_devices_on_mesh_dim = mesh.size(mesh_dim)
+    bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim]
+    # constant latency factor + bandwidth cost
+    return (
+        1
+        + bandwidth_factor
+        * num_bytes
+        * (num_devices_on_mesh_dim - 1)
+        / num_devices_on_mesh_dim
+    )
+
+
+def allreduce_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float:
+    num_devices_on_mesh_dim = mesh.size(mesh_dim)
+    bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim]
+    # allreduce has 2x the comm bytes compared to allgather/reduce_scatter
+    return (
+        1
+        + 2
+        * bandwidth_factor
+        * num_bytes
+        * (num_devices_on_mesh_dim - 1)
+        / num_devices_on_mesh_dim
+    )
+
+
+def reduce_scatter_cost(
+    num_bytes: float,
+    mesh: DeviceMesh,
+    mesh_dim: int,
+) -> float:
+    num_devices_on_mesh_dim = mesh.size(mesh_dim)
+    bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim]
+    # constant latency factor + bandwidth cost
+    return (
+        1
+        + bandwidth_factor
+        * num_bytes
+        * (num_devices_on_mesh_dim - 1)
+        / num_devices_on_mesh_dim
+    )
+
+
+def redistribute_cost(
+    current_spec: "placement_types.DTensorSpec",
+    target_spec: "placement_types.DTensorSpec",
+) -> float:
+    """
+    This function returns the cost of redistributing from the current to the target DTensorSpec.
+
+    NOTE:
+    1. Only communication cost is considered here, since the computation cost of a
+       redistribute is trivial (i.e. we only need a narrow or a simple division).
+    2. Only redistribute on the same mesh is considered; cross-mesh communication cost
+       is not needed for operator strategy estimation/selection.
+    """
+    if current_spec.mesh != target_spec.mesh:
+        # return infinite cost if the meshes are not the same
+        # TODO: see if we want to support this once there's cross mesh communication
+        return float("inf")
+
+    if current_spec.is_replicated():
+        # short-cut:
+        # comm cost is 0 if the current spec is already fully replicated
+        return 0.0
+
+    mesh = current_spec.mesh
+    cost = 0.0
+    comm_bytes = spec_to_bytes(current_spec) / current_spec.num_shards
+    # Transformations considered for redistribute cost:
+    # 1. allgather 2. alltoall
+    # 3. allreduce 4. reduce_scatter
+    for i, (current, target) in enumerate(
+        zip(current_spec.placements, target_spec.placements)
+    ):
+        if current == target:
+            continue
+        if current.is_shard() and target.is_replicate():
+            # allgather scales the comm bytes up by the mesh dim size
+            comm_bytes *= mesh.size(i)
+            # add up allgather comm cost
+            cost += allgather_cost(comm_bytes, current_spec.mesh, i)
+        elif current.is_shard() and target.is_shard():
+            # should be an alltoall comm; since we haven't implemented it yet, add a penalty
+            # to favor allgather instead
+            cost += allgather_cost(comm_bytes, current_spec.mesh, i) + 1.0
+        elif current.is_partial() and target.is_replicate():
+            # add up allreduce comm cost
+            cost += allreduce_cost(comm_bytes, current_spec.mesh, i)
+        elif current.is_partial() and target.is_shard():
+            # add up reduce_scatter comm cost
+            cost += reduce_scatter_cost(comm_bytes, current_spec.mesh, i)
+            # after reduce_scatter, the comm bytes for further collectives are divided by the mesh dim size.
+            comm_bytes /= mesh.size(i)
+        elif current.is_shard() and target.is_partial():
+            # ban shard -> partial as it does not make sense to perform
+            # this redistribute
+            return float("inf")
+
+    return cost
diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/_tensor/device_mesh.py
index 75a32b3..c61f28d 100644
--- a/torch/distributed/_tensor/device_mesh.py
+++ b/torch/distributed/_tensor/device_mesh.py
@@ -85,6 +85,16 @@
                 return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
         return None
 
+    @staticmethod
+    def num_devices_per_host(device_type: str) -> int:
+        return _get_device_handle(device_type).device_count()
+
+    @staticmethod
+    def num_hosts(device_type: str) -> int:
+        # ProcessGroup can't tell us this info, so we have to infer it; assume
+        # homogeneous hardware for now
+        return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
+
 
 mesh_resources: _MeshEnv = _MeshEnv()
 
diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py
index 62f4614..ec3a759 100644
--- a/torch/distributed/_tensor/op_schema.py
+++ b/torch/distributed/_tensor/op_schema.py
@@ -58,6 +58,12 @@
     output_spec: DTensorSpec
     input_specs: Optional[Sequence[DTensorSpec]] = None
 
+    # redistribute costs for this op placement strategy
+    # we need a nested list to record the costs: one inner
+    # list per operand of this operator, since each operand
+    # may have multiple possible placement strategies
+    redistribute_cost: Optional[List[List[float]]] = None
+
     def pretty_print_placements(self, placements):
         return "".join([str(p) for p in placements])
 
diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py
index 6ceb3d1..ea033af 100644
--- a/torch/distributed/_tensor/ops/embedding_ops.py
+++ b/torch/distributed/_tensor/ops/embedding_ops.py
@@ -24,16 +24,14 @@
             "DTensor does not support row-wise sharded embedding operation yet!"
         )
 
-    if all(
-        placement.is_replicate() for placement in weight_spec.placements
-    ) and inp_spec.placements == [Shard(0)]:
+    if weight_spec.is_replicated() and inp_spec.placements == [Shard(0)]:
         # Embedding table is replicated, input ids are sharded along batch
         # dimension. Output lookups should match input sharding spec in this case.
         return OutputSharding(
             output_spec=DTensorSpec(mesh=inp_spec.mesh, placements=inp_spec.placements)
         )
 
-    if all(placement.is_replicate() for placement in inp_spec.placements):
+    if inp_spec.is_replicated():
         weight_dim_map = weight_spec.dim_map
         output_dim_map = inp_spec.dim_map
         output_dim_map.append(weight_dim_map[1])
diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py
index d42426a..9c52e33 100644
--- a/torch/distributed/_tensor/ops/math_ops.py
+++ b/torch/distributed/_tensor/ops/math_ops.py
@@ -15,6 +15,7 @@
 from torch.distributed._tensor.ops.common_rules import pointwise_rule
 from torch.distributed._tensor.ops.utils import (
     as_list,
+    generate_redistribute_costs,
     normalize_dims,
     register_op_strategy,
     register_prop_rule,
@@ -139,6 +140,7 @@
         out_placements = map_placements_after_reduction(
             input_spec.placements, reduce_dims, reduce_dims_map, reduction_op
         )
+        redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
         reduction_strategy.strategies.append(
             PlacementStrategy(
                 output_spec=DTensorSpec(
@@ -146,6 +148,7 @@
                     placements=out_placements,
                 ),
                 input_specs=(input_spec,),
+                redistribute_cost=redistribute_cost,
             )
         )
 
diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py
index 720b24d..53582e9 100644
--- a/torch/distributed/_tensor/ops/utils.py
+++ b/torch/distributed/_tensor/ops/utils.py
@@ -4,7 +4,9 @@
 from typing import cast, Iterable, List, Sequence, Union
 
 import torch
+from torch.distributed._tensor._collective_utils import redistribute_cost
 from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.op_schema import OpStrategy
 from torch.distributed._tensor.placement_types import DTensorSpec, Shard
 
 
@@ -100,3 +102,13 @@
 def is_tensor_partial(spec: DTensorSpec) -> bool:
     """Return True if tensor is partial on the mesh"""
     return any(p.is_partial() for p in spec.placements)
+
+
+def generate_redistribute_costs(
+    src_strategy: OpStrategy, dst_spec: DTensorSpec
+) -> List[float]:
+    redistribute_costs: List[float] = []
+    for strat in src_strategy.strategies:
+        redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
+
+    return redistribute_costs
diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py
index f12827b..8d35066 100644
--- a/torch/distributed/_tensor/placement_types.py
+++ b/torch/distributed/_tensor/placement_types.py
@@ -533,3 +533,9 @@
                 placements[m] = Shard(i)
 
         return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
+
+    def is_replicated(self):
+        """
+        Return True if the current DTensorSpec is replicated on all mesh dims (devices).
+        """
+        return all(placement.is_replicate() for placement in self.placements)
diff --git a/torch/distributed/_tensor/sharding_prop.py b/torch/distributed/_tensor/sharding_prop.py
index 5416844..e47a833 100644
--- a/torch/distributed/_tensor/sharding_prop.py
+++ b/torch/distributed/_tensor/sharding_prop.py
@@ -1,5 +1,6 @@
 from functools import lru_cache
-from typing import Callable, cast, Dict, Optional
+from itertools import chain
+from typing import Callable, cast, Dict, List, Optional
 
 import torch
 from torch._ops import OpOverload
@@ -169,9 +170,8 @@
             op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)
 
             assert isinstance(op_strategy, OpStrategy)
-            # we take the first strategy for now
-            # TODO: add a min cost selection logic
-            output_strategy = op_strategy.strategies[0]
+            output_strategy = self._select_strategy(op_strategy)
+
             needs_redistribute = False
             expected_input_specs = []
             for idx, input_spec in enumerate(op_schema.args_spec):
@@ -264,3 +264,19 @@
             raise NotImplementedError(
                 f"Operator {op_schema.op} does not have a sharding strategy registered."
             )
+
+    def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
+        if len(strategy.strategies) == 1:
+            # short cut with only one possible strategy
+            return strategy.strategies[0]
+
+        strategy_costs: List[float] = []
+        for strtg in strategy.strategies:
+            assert (
+                strtg.redistribute_cost is not None
+            ), "must set redistribute cost each strategy!"
+            redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
+            strategy_costs.append(redistribute_cost)
+
+        # for eager execution, we just select the one with the minimal redistribute cost
+        return strategy.strategies[strategy_costs.index(min(strategy_costs))]