[dtensor] implement distributed topk operator (#126711)
As titled: implement the topk operator in DTensor. The registered sharding strategy accepts sharding on any tensor dimension except the topk dimension; an input sharded along the topk dimension is redistributed (replicated) before the local topk runs.
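For context, a minimal usage sketch (the device type, mesh size, and tensor shape below are illustrative only, not part of this PR):

```python
import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

# illustrative setup: assumes torch.distributed is initialized with 4 ranks
mesh = init_device_mesh("cuda", (4,))

x = torch.randn(12, 8, 8)
# shard on dim 1; the topk below runs along dim 0, so no redistribution is needed
dx = distribute_tensor(x, mesh, [Shard(1)])

out = dx.topk(3, dim=0)                 # values/indices stay sharded on dim 1
print(out.values.full_tensor().shape)   # torch.Size([3, 8, 8])
```

If the input is instead sharded along the topk dimension (e.g. `Shard(0)` above), the strategy redistributes it first (a single allgather, as asserted in the test below).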
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126711
Approved by: https://github.com/wz337
ghstack dependencies: #126710
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index d14d2d8..277a0cc 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -468,7 +468,6 @@
xfail("take"),
xfail("tensor_split"),
xfail("to_sparse"),
- xfail("topk"),
xfail("trace"),
xfail("trapezoid"),
xfail("trapz"),
diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py
index f810278..6469720 100644
--- a/test/distributed/_tensor/test_math_ops.py
+++ b/test/distributed/_tensor/test_math_ops.py
@@ -394,6 +394,33 @@
         self.assertEqual(x_local.grad, x_dist.grad.full_tensor())
+    @with_comms
+    def test_topk(self):
+        device_mesh = self.build_device_mesh()
+        placement_combs = [Shard(0), Shard(1), Shard(2), Replicate()]
+
+        comm_mode = CommDebugMode()
+
+        tensor = torch.randn(12, 8, 8, requires_grad=True)
+        global_topk = tensor.topk(3, dim=0)
+
+        for placement in placement_combs:
+            dtensor = distribute_tensor(tensor, device_mesh, (placement,))
+            with comm_mode:
+                out_dt = dtensor.topk(3, dim=0)
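+            # sharding along the topk dim (dim 0) forces a redistribute:
+            # the input is allgathered before the local topk runs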
+            if placement.is_shard(0):
+                self.assertEqual(comm_mode.get_total_counts(), 1)
+                self.assertEqual(
+                    comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
+                    1,
+                )
+            out_full_values = out_dt.values.full_tensor()
+            self.assertEqual(global_topk.values, out_full_values)
+
+            # TODO: support backward scatter
+            # global_topk.values.sum().backward()
+            # out_full_values.sum().backward()
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py
index fa4066e..3b3b46c 100644
--- a/torch/distributed/_tensor/ops/math_ops.py
+++ b/torch/distributed/_tensor/ops/math_ops.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import itertools
 import math
 from dataclasses import dataclass
 from enum import Enum
@@ -17,6 +18,7 @@
     as_list,
     generate_redistribute_costs,
     is_tensor_evenly_shardable,
+    is_tensor_shardable,
     normalize_dim,
     normalize_dims,
     normalize_to_torch_size,
@@ -174,6 +176,31 @@
     return reduction_dims_map
+def _replicate_dims_start_at(
+    placements: Sequence[Placement], start_dim: int = 0
+) -> Tuple[Placement, ...]:
+    new_placements: List[Placement] = []
+    for p in placements:
+        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
+            new_placements.append(Replicate())  # make it replicate
+        else:
+            new_placements.append(p)  # keep the placement
+    return tuple(new_placements)
+
+
+# return new_placements which align with placements but skip the skipped_dim
+def _skip_dim(
+    placements: Tuple[Placement, ...], skipped_dim: int
+) -> Tuple[Placement, ...]:
+    new_placements: List[Placement] = []
+    for p in placements:
+        if isinstance(p, Shard) and p.dim >= skipped_dim:
+            new_placements.append(Shard(p.dim - 1))
+        else:
+            new_placements.append(p)
+    return tuple(new_placements)
+
+
 def replicate_reduction_dims(
     placements: Tuple[Placement, ...], reduction_dims: List[int]
 ) -> Tuple[Placement, ...]:
@@ -954,26 +981,57 @@
     return out_tuple_strategy
-def _replicate_dims_start_at(
-    placements: Sequence[Placement], start_dim: int = 0
-) -> Tuple[Placement, ...]:
-    new_placements: List[Placement] = []
-    for p in placements:
-        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
-            new_placements.append(Replicate())  # make it replicate
-        else:
-            new_placements.append(p)  # keep the placement
-    return tuple(new_placements)
+@register_op_strategy(
+    [aten.topk.default],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    k = cast(int, op_schema.args_schema[1])
+    input_shape = input_strategy.shape
+    topk_dim = (
+        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1
+    )
+    topk_dim = normalize_dim(topk_dim, input_strategy.ndim)
+    all_mesh_dim_strategies = []
-# return new_placements which align with placements but skip the skipped_dim
-def _skip_dim(
-    placements: Tuple[Placement, ...], skipped_dim: int
-) -> Tuple[Placement, ...]:
-    new_placements: List[Placement] = []
-    for p in placements:
-        if isinstance(p, Shard) and p.dim >= skipped_dim:
-            new_placements.append(Shard(p.dim - 1))
-        else:
-            new_placements.append(p)
-    return tuple(new_placements)
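+    # for each mesh dim, enumerate the placements this op accepts;
+    # itertools.product below combines the choices across mesh dims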
+    for mesh_dim in range(mesh.ndim):
+        single_mesh_dim_strategies = []
+
+        # two outputs (values, indices), 1 input
+        # replicate always works
+        all_replicate: List[Placement] = [Replicate()] * 3
+        single_mesh_dim_strategies.append(all_replicate)
+
+        # every dim except topk dim should work
+        for dim in range(input_strategy.ndim):
+            if dim != topk_dim:
+                dim_shardings: List[Placement] = [Shard(dim)] * 3
+                single_mesh_dim_strategies.append(dim_shardings)
+
+        # TODO: topk on sharded dim requires non-trivial reduction, address it later
+
+        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
+
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+
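+        # spec_list follows the per-strategy ordering above: (values, indices, input)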
+        input_spec = spec_list[2]
+        if is_tensor_shardable(input_shape, input_spec):
+            redistribute_cost = [
+                generate_redistribute_costs(input_strategy, input_spec)
+            ]
+            strategy = PlacementStrategy(
+                output_specs=tuple(spec_list[:2]),
+                input_specs=(input_spec,),
+                redistribute_cost=redistribute_cost,
+            )
+            all_strategies.append(strategy)
+
+    return OpStrategy(all_strategies)