[dtensor][8/N] Introduce cost model for sharding (#109145)
This PR adds a basic communication cost model for sharding propagation and uses it to pick the placement strategy with the minimal redistribute cost.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109145
Approved by: https://github.com/fduwjj
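
For context, here is a minimal sketch of how the new cost query can be exercised. It mirrors the new test below and assumes a 4-rank process group is already initialized (e.g. via torchrun); the "cpu" device type is illustrative.

```python
import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor._collective_utils import redistribute_cost
from torch.distributed._tensor.placement_types import (
    _Partial,
    DTensorSpec,
    Replicate,
    Shard,
    TensorMeta,
)

# assumes torch.distributed is already initialized with 4 ranks
mesh = DeviceMesh("cpu", torch.arange(4))

t = torch.randn(10, 10)
meta = TensorMeta(t.shape, t.stride(), t.dtype)

shard = DTensorSpec(mesh, (Shard(0),), meta)
replica = DTensorSpec(mesh, (Replicate(),), meta)
partial = DTensorSpec(mesh, (_Partial(),), meta)

print(redistribute_cost(shard, shard))      # 0.0 -- same spec, no comm needed
print(redistribute_cost(shard, replica))    # allgather cost
print(redistribute_cost(partial, replica))  # allreduce cost (~2x the bandwidth term)
print(redistribute_cost(shard, partial))    # inf -- shard -> partial is banned
```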
diff --git a/test/distributed/_tensor/test_basic_strategy.py b/test/distributed/_tensor/test_basic_strategy.py
deleted file mode 100644
index 0a5a026..0000000
--- a/test/distributed/_tensor/test_basic_strategy.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Owner(s): ["oncall: distributed"]
-
-import torch
-from torch.distributed._tensor import DeviceMesh
-from torch.distributed._tensor.ops.basic_strategy import (
- EinsumDims,
- gen_einsum_strategies,
-)
-
-from torch.testing._internal.common_utils import run_tests, TestCase
-from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
-
-
-class TestEinsumDims(TestCase):
- def test_batch_dims(self):
- equation = "abc,abc->abc"
- input_dims, output_dim = EinsumDims.parse_equation(equation)
- edims = EinsumDims.parse_dims(input_dims, output_dim)
-
- self.assertEqual(edims.batch_dims, ["a", "b", "c"])
- self.assertEqual(edims.contracting_dims, [])
- self.assertEqual(edims.lhs_out_only_dims, [])
- self.assertEqual(edims.rhs_out_only_dims, [])
-
- def test_mm_dims(self):
- equation = "mk,kn->mn"
- input_dims, output_dim = EinsumDims.parse_equation(equation)
- edims = EinsumDims.parse_dims(input_dims, output_dim)
-
- self.assertEqual(edims.batch_dims, [])
- self.assertEqual(edims.contracting_dims, ["k"])
- self.assertEqual(edims.lhs_out_only_dims, ["m"])
- self.assertEqual(edims.rhs_out_only_dims, ["n"])
-
- def test_bmm_dims(self):
- equation = "bmk,bkn->bmn"
- input_dims, output_dim = EinsumDims.parse_equation(equation)
- edims = EinsumDims.parse_dims(input_dims, output_dim)
-
- self.assertEqual(edims.batch_dims, ["b"])
- self.assertEqual(edims.contracting_dims, ["k"])
- self.assertEqual(edims.lhs_out_only_dims, ["m"])
- self.assertEqual(edims.rhs_out_only_dims, ["n"])
-
- equation = "bcmk,bckn->bcmn"
- input_dims, output_dim = EinsumDims.parse_equation(equation)
- edims = EinsumDims.parse_dims(input_dims, output_dim)
-
- self.assertEqual(edims.batch_dims, ["b", "c"])
- self.assertEqual(edims.contracting_dims, ["k"])
- self.assertEqual(edims.lhs_out_only_dims, ["m"])
- self.assertEqual(edims.rhs_out_only_dims, ["n"])
-
- def test_free_dims(self):
- equation = "abc,ab->abc"
- input_dims, output_dim = EinsumDims.parse_equation(equation)
- edims = EinsumDims.parse_dims(input_dims, output_dim)
-
- self.assertEqual(edims.batch_dims, ["a", "b"])
- self.assertEqual(edims.contracting_dims, [])
- self.assertEqual(edims.lhs_out_only_dims, ["c"])
- self.assertEqual(edims.rhs_out_only_dims, [])
-
- equation = "abd,bf->abfd"
- input_dims, output_dim = EinsumDims.parse_equation(equation)
- edims = EinsumDims.parse_dims(input_dims, output_dim)
-
- self.assertEqual(edims.batch_dims, ["b"])
- self.assertEqual(edims.contracting_dims, [])
- self.assertEqual(edims.lhs_out_only_dims, ["a", "d"])
- self.assertEqual(edims.rhs_out_only_dims, ["f"])
-
-
-class TestEinsumStrategies(DTensorOpTestBase):
- @property
- def world_size(self) -> int:
- return 4
-
- def test_mm_1d_mesh(self):
- mesh = self.build_device_mesh()
-
- all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
- self.assertEqual(len(all_strats.strategies), 4)
-
- def test_mm_2d_mesh(self):
- mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
-
- all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
- self.assertEqual(len(all_strats.strategies), 16)
-
- def test_bmm_1d_mesh(self):
- mesh = self.build_device_mesh()
-
- all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
- self.assertEqual(len(all_strats.strategies), 5)
-
- def test_bmm_2d_mesh(self):
- mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
-
- all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
- self.assertEqual(len(all_strats.strategies), 25)
-
- def test_pointwise_1d_mesh(self):
- mesh = self.build_device_mesh()
-
- simple_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh)
- self.assertEqual(len(simple_strats.strategies), 5)
-
- broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh)
- self.assertEqual(len(broadcast_strats.strategies), 5)
-
- def test_linearity_1d_mesh(self):
- mesh = self.build_device_mesh()
-
- all_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh, linearity=True)
- self.assertEqual(len(all_strats.strategies), 6)
-
-
-if __name__ == "__main__":
- run_tests()
diff --git a/test/distributed/_tensor/test_op_strategy.py b/test/distributed/_tensor/test_op_strategy.py
new file mode 100644
index 0000000..c13f2c4
--- /dev/null
+++ b/test/distributed/_tensor/test_op_strategy.py
@@ -0,0 +1,202 @@
+# Owner(s): ["oncall: distributed"]
+
+import torch
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed._tensor._collective_utils import redistribute_cost
+from torch.distributed._tensor.ops.basic_strategy import (
+ EinsumDims,
+ gen_einsum_strategies,
+)
+from torch.distributed._tensor.placement_types import (
+ _Partial,
+ DTensorSpec,
+ Replicate,
+ Shard,
+ TensorMeta,
+)
+
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
+
+
+class TestEinsumDims(TestCase):
+ def test_batch_dims(self):
+ equation = "abc,abc->abc"
+ input_dims, output_dim = EinsumDims.parse_equation(equation)
+ edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+ self.assertEqual(edims.batch_dims, ["a", "b", "c"])
+ self.assertEqual(edims.contracting_dims, [])
+ self.assertEqual(edims.lhs_out_only_dims, [])
+ self.assertEqual(edims.rhs_out_only_dims, [])
+
+ def test_mm_dims(self):
+ equation = "mk,kn->mn"
+ input_dims, output_dim = EinsumDims.parse_equation(equation)
+ edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+ self.assertEqual(edims.batch_dims, [])
+ self.assertEqual(edims.contracting_dims, ["k"])
+ self.assertEqual(edims.lhs_out_only_dims, ["m"])
+ self.assertEqual(edims.rhs_out_only_dims, ["n"])
+
+ def test_bmm_dims(self):
+ equation = "bmk,bkn->bmn"
+ input_dims, output_dim = EinsumDims.parse_equation(equation)
+ edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+ self.assertEqual(edims.batch_dims, ["b"])
+ self.assertEqual(edims.contracting_dims, ["k"])
+ self.assertEqual(edims.lhs_out_only_dims, ["m"])
+ self.assertEqual(edims.rhs_out_only_dims, ["n"])
+
+ equation = "bcmk,bckn->bcmn"
+ input_dims, output_dim = EinsumDims.parse_equation(equation)
+ edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+ self.assertEqual(edims.batch_dims, ["b", "c"])
+ self.assertEqual(edims.contracting_dims, ["k"])
+ self.assertEqual(edims.lhs_out_only_dims, ["m"])
+ self.assertEqual(edims.rhs_out_only_dims, ["n"])
+
+ def test_free_dims(self):
+ equation = "abc,ab->abc"
+ input_dims, output_dim = EinsumDims.parse_equation(equation)
+ edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+ self.assertEqual(edims.batch_dims, ["a", "b"])
+ self.assertEqual(edims.contracting_dims, [])
+ self.assertEqual(edims.lhs_out_only_dims, ["c"])
+ self.assertEqual(edims.rhs_out_only_dims, [])
+
+ equation = "abd,bf->abfd"
+ input_dims, output_dim = EinsumDims.parse_equation(equation)
+ edims = EinsumDims.parse_dims(input_dims, output_dim)
+
+ self.assertEqual(edims.batch_dims, ["b"])
+ self.assertEqual(edims.contracting_dims, [])
+ self.assertEqual(edims.lhs_out_only_dims, ["a", "d"])
+ self.assertEqual(edims.rhs_out_only_dims, ["f"])
+
+
+class TestEinsumStrategies(DTensorOpTestBase):
+ @property
+ def world_size(self) -> int:
+ return 4
+
+ def test_mm_1d_mesh(self):
+ mesh = self.build_device_mesh()
+
+ all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
+ self.assertEqual(len(all_strats.strategies), 4)
+
+ def test_mm_2d_mesh(self):
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+
+ all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
+ self.assertEqual(len(all_strats.strategies), 16)
+
+ def test_bmm_1d_mesh(self):
+ mesh = self.build_device_mesh()
+
+ all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
+ self.assertEqual(len(all_strats.strategies), 5)
+
+ def test_bmm_2d_mesh(self):
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+
+ all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
+ self.assertEqual(len(all_strats.strategies), 25)
+
+ def test_pointwise_1d_mesh(self):
+ mesh = self.build_device_mesh()
+
+ simple_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh)
+ self.assertEqual(len(simple_strats.strategies), 5)
+
+ broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh)
+ self.assertEqual(len(broadcast_strats.strategies), 5)
+
+ def test_linearity_1d_mesh(self):
+ mesh = self.build_device_mesh()
+
+ all_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh, linearity=True)
+ self.assertEqual(len(all_strats.strategies), 6)
+
+
+class TestCostModel(DTensorOpTestBase):
+ def _extract_tensor_meta(self, t) -> TensorMeta:
+ return TensorMeta(t.shape, t.stride(), t.dtype)
+
+ @property
+ def world_size(self) -> int:
+ return 4
+
+ def test_redistribute_cost_mesh_1d(self):
+ mesh_1d = self.build_device_mesh()
+ shard_placement = (Shard(0),)
+ replica_placement = (Replicate(),)
+ partial_placement = (_Partial(),)
+
+ global_tensor = torch.randn(10, 10)
+ global_tensor_meta = self._extract_tensor_meta(global_tensor)
+
+ # shard spec
+ shard_spec = DTensorSpec(mesh_1d, shard_placement, global_tensor_meta)
+ # replica spec
+ replica_spec = DTensorSpec(mesh_1d, replica_placement, global_tensor_meta)
+ # partial spec
+ partial_spec = DTensorSpec(mesh_1d, partial_placement, global_tensor_meta)
+
+ # make sure reshard cost is 0 for the same spec redistribute
+ for spec in [shard_spec, replica_spec, partial_spec]:
+ cost = redistribute_cost(spec, spec)
+ self.assertEqual(cost, 0)
+
+ # shard -> replicate
+ allgather_cost = redistribute_cost(shard_spec, replica_spec)
+ # partial -> shard
+ reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
+ # partial -> replicate
+ allreduce_cost = redistribute_cost(partial_spec, replica_spec)
+ self.assertEqual(allgather_cost, reduce_scatter_cost)
+ self.assertEqual(allreduce_cost + 1, allgather_cost + reduce_scatter_cost)
+ # shard to partial
+ cost = redistribute_cost(shard_spec, partial_spec)
+ self.assertEqual(cost, float("inf"))
+
+ def test_redistribute_cost_mesh_2d(self):
+ mesh_2d = DeviceMesh(
+ self.device_type, torch.arange(self.world_size).reshape(2, 2)
+ )
+ shard_placement = (Shard(0), Shard(0))
+ replica_placement = (Replicate(), Replicate())
+ partial_placement = (_Partial(), _Partial())
+
+ global_tensor = torch.randn(8, 8)
+ global_tensor_meta = self._extract_tensor_meta(global_tensor)
+
+ # shard spec
+ shard_spec = DTensorSpec(mesh_2d, shard_placement, global_tensor_meta)
+ # replica spec
+ replica_spec = DTensorSpec(mesh_2d, replica_placement, global_tensor_meta)
+ # partial spec
+ partial_spec = DTensorSpec(mesh_2d, partial_placement, global_tensor_meta)
+
+ # make sure reshard cost is 0 for the same spec redistribute
+ for spec in [shard_spec, replica_spec, partial_spec]:
+ cost = redistribute_cost(spec, spec)
+ self.assertEqual(cost, 0)
+
+ # shard -> replicate
+ allgather_cost = redistribute_cost(shard_spec, replica_spec)
+ # partial -> replicate
+ allreduce_cost = redistribute_cost(partial_spec, replica_spec)
+ # partial -> shard
+ reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
+ self.assertTrue(allreduce_cost > allgather_cost)
+ self.assertTrue(allreduce_cost > reduce_scatter_cost)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/torch/distributed/_tensor/_collective_utils.py b/torch/distributed/_tensor/_collective_utils.py
index f5b99d8..f954021 100644
--- a/torch/distributed/_tensor/_collective_utils.py
+++ b/torch/distributed/_tensor/_collective_utils.py
@@ -1,9 +1,11 @@
import logging
+import math
from typing import List, Optional
import torch
-from torch.distributed._tensor.device_mesh import DeviceMesh
+import torch.distributed._tensor.placement_types as placement_types
+from torch.distributed._tensor.device_mesh import DeviceMesh, mesh_resources
from torch.distributed.distributed_c10d import (
all_to_all,
broadcast,
@@ -158,3 +160,128 @@
async_op=async_op,
)
return work
+
+
+def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int:
+ assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
+ return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)
+
+
+def get_bandwidth_factor(mesh: DeviceMesh) -> List[float]:
+ # generate bandwidth factor for intra-host/inter-host communication pattern
+ factors = [1.0] * mesh.ndim
+ num_devices_per_host = mesh_resources.num_devices_per_host(mesh.device_type)
+
+ num_devices = 1
+ for mesh_dim in reversed(range(mesh.ndim)):
+ num_devices *= mesh.size(mesh_dim)
+ if num_devices <= num_devices_per_host:
+ # magic number for the intra-host communication bandwidth factor
+ # TODO: see if we need to tweak this or offer a way for the user
+ # to specify the bandwidths
+ factors[mesh_dim] = 0.2
+
+ return factors
+
+
+def allgather_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float:
+ num_devices_on_mesh_dim = mesh.size(mesh_dim)
+ bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim]
+ # constant latency factor + bandwidth cost
+ return (
+ 1
+ + bandwidth_factor
+ * num_bytes
+ * (num_devices_on_mesh_dim - 1)
+ / num_devices_on_mesh_dim
+ )
+
+
+def allreduce_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float:
+ num_devices_on_mesh_dim = mesh.size(mesh_dim)
+ bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim]
+ # allreduce has 2x comm bytes compared to allgather/reduce_scatter
+ return (
+ 1
+ + 2
+ * bandwidth_factor
+ * num_bytes
+ * (num_devices_on_mesh_dim - 1)
+ / num_devices_on_mesh_dim
+ )
+
+
+def reduce_scatter_cost(
+ num_bytes: float,
+ mesh: DeviceMesh,
+ mesh_dim: int,
+) -> float:
+ num_devices_on_mesh_dim = mesh.size(mesh_dim)
+ bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim]
+ # constant latency factor + bandwidth cost
+ return (
+ 1
+ + bandwidth_factor
+ * num_bytes
+ * (num_devices_on_mesh_dim - 1)
+ / num_devices_on_mesh_dim
+ )
+
+
+def redistribute_cost(
+ current_spec: "placement_types.DTensorSpec",
+ target_spec: "placement_types.DTensorSpec",
+) -> float:
+ """
+ This function returns the cost of redistributing from the current to the target DTensorSpec.
+
+ NOTE:
+ 1. Only consider communication cost here, since the computation cost of a redistribute
+ is quite trivial (i.e. we only need a narrow or a simple division).
+ 2. Only consider redistribute cost on the same mesh; cross-mesh communication cost is
+ not needed for operator strategy estimation/selection.
+ """
+ if current_spec.mesh != target_spec.mesh:
+ # return infinite cost if the meshes are not the same
+ # TODO: see if we want to support this once there's cross-mesh communication
+ return float("inf")
+
+ if current_spec.is_replicated():
+ # short-cut:
+ # comm cost is 0 if the current spec is already fully replicated
+ return 0.0
+
+ mesh = current_spec.mesh
+ cost = 0.0
+ comm_bytes = spec_to_bytes(current_spec) / current_spec.num_shards
+ # Transformations considered for redistribute cost:
+ # 1. allgather 2. alltoall
+ # 3. allreduce 4. reduce_scatter
+ for i, (current, target) in enumerate(
+ zip(current_spec.placements, target_spec.placements)
+ ):
+ if current == target:
+ continue
+ if current.is_shard() and target.is_replicate():
+ # allgather scales comm bytes up by the mesh dim size
+ comm_bytes *= mesh.size(i)
+ # add up allgather comm cost
+ cost += allgather_cost(comm_bytes, current_spec.mesh, i)
+ elif current.is_shard() and target.is_shard():
+ # should be an alltoall comm; since we haven't implemented it yet, add a penalty
+ # to favor allgather instead
+ cost += allgather_cost(comm_bytes, current_spec.mesh, i) + 1.0
+ elif current.is_partial() and target.is_replicate():
+ # add up allreduce comm cost
+ cost += allreduce_cost(comm_bytes, current_spec.mesh, i)
+ elif current.is_partial() and target.is_shard():
+ # add up reduce_scatter comm cost
+ cost += reduce_scatter_cost(comm_bytes, current_spec.mesh, i)
+ # after reduce_scatter, the comm bytes for further collectives are divided by the mesh dim size.
+ comm_bytes /= mesh.size(i)
+ elif current.is_shard() and target.is_partial():
+ # ban shard -> partial as it does not make sense to perform
+ # this redistribute
+ return float("inf")
+
+ return cost
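
As a sanity check on the formulas above: allgather and reduce_scatter share the same latency + bandwidth form, while allreduce doubles the bandwidth term, which is exactly what the new 1D-mesh test above asserts (`allreduce_cost + 1 == allgather_cost + reduce_scatter_cost`). Below is a standalone sketch of the same arithmetic; the helper names and the fixed 0.2 bandwidth factor are illustrative only.

```python
# Standalone sketch of the latency + bandwidth cost model above.
# Helper names and the fixed 0.2 bandwidth factor are illustrative only.
def _bw_term(num_bytes: float, n: int, factor: float = 0.2) -> float:
    # bandwidth term shared by allgather / reduce_scatter on an n-device mesh dim
    return factor * num_bytes * (n - 1) / n


def allgather(num_bytes: float, n: int) -> float:
    return 1 + _bw_term(num_bytes, n)


def reduce_scatter(num_bytes: float, n: int) -> float:
    return 1 + _bw_term(num_bytes, n)


def allreduce(num_bytes: float, n: int) -> float:
    # 2x the bandwidth term compared to allgather/reduce_scatter
    return 1 + 2 * _bw_term(num_bytes, n)


n, nbytes = 4, 10 * 10 * 4  # fp32 10x10 tensor on a 1D mesh of 4 devices
print(allgather(nbytes, n), reduce_scatter(nbytes, n), allreduce(nbytes, n))
# 61.0 61.0 121.0 -> allreduce + 1 == allgather + reduce_scatter
```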
diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/_tensor/device_mesh.py
index 75a32b3..c61f28d 100644
--- a/torch/distributed/_tensor/device_mesh.py
+++ b/torch/distributed/_tensor/device_mesh.py
@@ -85,6 +85,16 @@
return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
return None
+ @staticmethod
+ def num_devices_per_host(device_type: str) -> int:
+ return _get_device_handle(device_type).device_count()
+
+ @staticmethod
+ def num_hosts(device_type: str) -> int:
+ # ProcessGroup can't tell us this info, so we have to infer it; assume
+ # homogeneous hardware for now
+ return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
+
mesh_resources: _MeshEnv = _MeshEnv()
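
To make the bandwidth heuristic concrete, here is a standalone walk-through of the same loop as `get_bandwidth_factor`, with a hypothetical 2-host x 4-GPU setup (`devices_per_host` is an assumed constant here; the real code queries it via `mesh_resources.num_devices_per_host`).

```python
from typing import List, Tuple


def bandwidth_factors(mesh_shape: Tuple[int, ...], devices_per_host: int = 4) -> List[float]:
    # mirror of get_bandwidth_factor: walk mesh dims from innermost to outermost,
    # accumulating the device count; dims that still fit on one host get 0.2
    factors = [1.0] * len(mesh_shape)
    num_devices = 1
    for dim in reversed(range(len(mesh_shape))):
        num_devices *= mesh_shape[dim]
        if num_devices <= devices_per_host:
            factors[dim] = 0.2
    return factors


print(bandwidth_factors((2, 4)))  # [1.0, 0.2]: inner dim stays intra-host, outer dim crosses hosts
print(bandwidth_factors((8,)))    # [1.0]: a single 8-device dim spans two hosts
```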
diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py
index 62f4614..ec3a759 100644
--- a/torch/distributed/_tensor/op_schema.py
+++ b/torch/distributed/_tensor/op_schema.py
@@ -58,6 +58,12 @@
output_spec: DTensorSpec
input_specs: Optional[Sequence[DTensorSpec]] = None
+ # redistribute costs for this op placement strategy
+ # we need a nested list to record the cost for each
+ # operand of this operator, since each operand may
+ # itself have multiple source placement strategies
+ redistribute_cost: Optional[List[List[float]]] = None
+
def pretty_print_placements(self, placements):
return "".join([str(p) for p in placements])
diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py
index 6ceb3d1..ea033af 100644
--- a/torch/distributed/_tensor/ops/embedding_ops.py
+++ b/torch/distributed/_tensor/ops/embedding_ops.py
@@ -24,16 +24,14 @@
"DTensor does not support row-wise sharded embedding operation yet!"
)
- if all(
- placement.is_replicate() for placement in weight_spec.placements
- ) and inp_spec.placements == [Shard(0)]:
+ if weight_spec.is_replicated() and inp_spec.placements == [Shard(0)]:
# Embedding table is replicated, input ids are sharded along batch
# dimension. Output lookups should match input sharding spec in this case.
return OutputSharding(
output_spec=DTensorSpec(mesh=inp_spec.mesh, placements=inp_spec.placements)
)
- if all(placement.is_replicate() for placement in inp_spec.placements):
+ if inp_spec.is_replicated():
weight_dim_map = weight_spec.dim_map
output_dim_map = inp_spec.dim_map
output_dim_map.append(weight_dim_map[1])
diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py
index d42426a..9c52e33 100644
--- a/torch/distributed/_tensor/ops/math_ops.py
+++ b/torch/distributed/_tensor/ops/math_ops.py
@@ -15,6 +15,7 @@
from torch.distributed._tensor.ops.common_rules import pointwise_rule
from torch.distributed._tensor.ops.utils import (
as_list,
+ generate_redistribute_costs,
normalize_dims,
register_op_strategy,
register_prop_rule,
@@ -139,6 +140,7 @@
out_placements = map_placements_after_reduction(
input_spec.placements, reduce_dims, reduce_dims_map, reduction_op
)
+ redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)]
reduction_strategy.strategies.append(
PlacementStrategy(
output_spec=DTensorSpec(
@@ -146,6 +148,7 @@
placements=out_placements,
),
input_specs=(input_spec,),
+ redistribute_cost=redistribute_cost,
)
)
diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py
index 720b24d..53582e9 100644
--- a/torch/distributed/_tensor/ops/utils.py
+++ b/torch/distributed/_tensor/ops/utils.py
@@ -4,7 +4,9 @@
from typing import cast, Iterable, List, Sequence, Union
import torch
+from torch.distributed._tensor._collective_utils import redistribute_cost
from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.op_schema import OpStrategy
from torch.distributed._tensor.placement_types import DTensorSpec, Shard
@@ -100,3 +102,13 @@
def is_tensor_partial(spec: DTensorSpec) -> bool:
"""Return True if tensor is partial on the mesh"""
return any(p.is_partial() for p in spec.placements)
+
+
+def generate_redistribute_costs(
+ src_strategy: OpStrategy, dst_spec: DTensorSpec
+) -> List[float]:
+ redistribute_costs: List[float] = []
+ for strat in src_strategy.strategies:
+ redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
+
+ return redistribute_costs
diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py
index f12827b..8d35066 100644
--- a/torch/distributed/_tensor/placement_types.py
+++ b/torch/distributed/_tensor/placement_types.py
@@ -533,3 +533,9 @@
placements[m] = Shard(i)
return cls(mesh, tuple(placements), tensor_meta=tensor_meta)
+
+ def is_replicated(self):
+ """
+ Return True if the current DTensorSpec is replicated on all mesh dims (devices).
+ """
+ return all(placement.is_replicate() for placement in self.placements)
diff --git a/torch/distributed/_tensor/sharding_prop.py b/torch/distributed/_tensor/sharding_prop.py
index 5416844..e47a833 100644
--- a/torch/distributed/_tensor/sharding_prop.py
+++ b/torch/distributed/_tensor/sharding_prop.py
@@ -1,5 +1,6 @@
from functools import lru_cache
-from typing import Callable, cast, Dict, Optional
+from itertools import chain
+from typing import Callable, cast, Dict, List, Optional
import torch
from torch._ops import OpOverload
@@ -169,9 +170,8 @@
op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)
assert isinstance(op_strategy, OpStrategy)
- # we take the first strategy for now
- # TODO: add a min cost selection logic
- output_strategy = op_strategy.strategies[0]
+ output_strategy = self._select_strategy(op_strategy)
+
needs_redistribute = False
expected_input_specs = []
for idx, input_spec in enumerate(op_schema.args_spec):
@@ -264,3 +264,19 @@
raise NotImplementedError(
f"Operator {op_schema.op} does not have a sharding strategy registered."
)
+
+ def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
+ if len(strategy.strategies) == 1:
+ # short-cut when there is only one possible strategy
+ return strategy.strategies[0]
+
+ strategy_costs: List[float] = []
+ for strtg in strategy.strategies:
+ assert (
+ strtg.redistribute_cost is not None
+ ), "must set redistribute cost each strategy!"
+ redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
+ strategy_costs.append(redistribute_cost)
+
+ # for eager execution, we just select the one with the minimal redistribute cost
+ return strategy.strategies[strategy_costs.index(min(strategy_costs))]
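
Finally, a small illustration (with made-up numbers) of how the nested per-operand `redistribute_cost` lists are flattened and summed by the selection logic above:

```python
from itertools import chain

# one inner list per operand; each entry is the cost of redistributing that
# operand from one of its possible source placements to the spec this
# strategy expects (numbers are made up for illustration)
redistribute_cost = [
    [0.0, 61.0],   # operand 0
    [121.0, 0.0],  # operand 1
]
total = sum(chain.from_iterable(redistribute_cost))
print(total)  # 182.0 -- the strategy with the smallest total wins
```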