[dtensor] implement distributed topk operator (#126711)

As titled: implement the topk operator in DTensor. The registered op strategy keeps sharding on every dim except the topk dim (plus full replication); an input sharded along the topk dim is redistributed to replicate first, which the new test verifies as a single all_gather.
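
Example usage (a minimal sketch, not part of this PR's tests; it assumes a 4-process CUDA world launched via torchrun, and the mesh and tensor shapes are illustrative):

    import torch
    from torch.distributed._tensor import Shard, distribute_tensor
    from torch.distributed.device_mesh import init_device_mesh

    # assumption: launched with torchrun --nproc-per-node=4 on a CUDA host
    mesh = init_device_mesh("cuda", (4,))
    x = torch.randn(12, 8, 8)

    # shard on dim 1; topk over dim 0 needs no communication for this placement
    dx = distribute_tensor(x, mesh, [Shard(1)])
    values, indices = dx.topk(3, dim=0)

    # outputs are DTensors with the same sharding as the input
    assert values.full_tensor().shape == torch.Size([3, 8, 8])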

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126711
Approved by: https://github.com/wz337
ghstack dependencies: #126710
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index d14d2d8..277a0cc 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -468,7 +468,6 @@
     xfail("take"),
     xfail("tensor_split"),
     xfail("to_sparse"),
-    xfail("topk"),
     xfail("trace"),
     xfail("trapezoid"),
     xfail("trapz"),
diff --git a/test/distributed/_tensor/test_math_ops.py b/test/distributed/_tensor/test_math_ops.py
index f810278..6469720 100644
--- a/test/distributed/_tensor/test_math_ops.py
+++ b/test/distributed/_tensor/test_math_ops.py
@@ -394,6 +394,33 @@
 
             self.assertEqual(x_local.grad, x_dist.grad.full_tensor())
 
+    @with_comms
+    def test_topk(self):
+        device_mesh = self.build_device_mesh()
+        placement_combs = [Shard(0), Shard(1), Shard(2), Replicate()]
+
+        comm_mode = CommDebugMode()
+
+        tensor = torch.randn(12, 8, 8, requires_grad=True)
+        global_topk = tensor.topk(3, dim=0)
+
+        for placement in placement_combs:
+            dtensor = distribute_tensor(tensor, device_mesh, (placement,))
+            with comm_mode:
+                out_dt = dtensor.topk(3, dim=0)
+            if placement.is_shard(0):
+                self.assertEqual(comm_mode.get_total_counts(), 1)
+                self.assertEqual(
+                    comm_mode.get_comm_counts()[funcol.all_gather_into_tensor],
+                    1,
+                )
+            out_full_values = out_dt.values.full_tensor()
+            self.assertEqual(global_topk.values, out_full_values)
+
+            # TODO: support backward scatter
+            # global_topk.values.sum().backward()
+            # out_full_values.sum().backward()
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py
index fa4066e..3b3b46c 100644
--- a/torch/distributed/_tensor/ops/math_ops.py
+++ b/torch/distributed/_tensor/ops/math_ops.py
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
+import itertools
 import math
 from dataclasses import dataclass
 from enum import Enum
@@ -17,6 +18,7 @@
     as_list,
     generate_redistribute_costs,
     is_tensor_evenly_shardable,
+    is_tensor_shardable,
     normalize_dim,
     normalize_dims,
     normalize_to_torch_size,
@@ -174,6 +176,31 @@
     return reduction_dims_map
 
 
+def _replicate_dims_start_at(
+    placements: Sequence[Placement], start_dim: int = 0
+) -> Tuple[Placement, ...]:
+    new_placements: List[Placement] = []
+    for p in placements:
+        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
+            new_placements.append(Replicate())  # make it replicate
+        else:
+            new_placements.append(p)  # keep the placement
+    return tuple(new_placements)
+
+
+# return new placements that align with `placements` after `skipped_dim` is removed (shard dims at or after it shift down by one)
+def _skip_dim(
+    placements: Tuple[Placement, ...], skipped_dim: int
+) -> Tuple[Placement, ...]:
+    new_placements: List[Placement] = []
+    for p in placements:
+        if isinstance(p, Shard) and p.dim >= skipped_dim:
+            new_placements.append(Shard(p.dim - 1))
+        else:
+            new_placements.append(p)
+    return tuple(new_placements)
+
+
 def replicate_reduction_dims(
     placements: Tuple[Placement, ...], reduction_dims: List[int]
 ) -> Tuple[Placement, ...]:
@@ -954,26 +981,57 @@
     return out_tuple_strategy
 
 
-def _replicate_dims_start_at(
-    placements: Sequence[Placement], start_dim: int = 0
-) -> Tuple[Placement, ...]:
-    new_placements: List[Placement] = []
-    for p in placements:
-        if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim):
-            new_placements.append(Replicate())  # make it replicate
-        else:
-            new_placements.append(p)  # keep the placement
-    return tuple(new_placements)
+@register_op_strategy(
+    [aten.topk.default],
+    schema_info=RuntimeSchemaInfo(2),
+)
+def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy:
+    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+    k = cast(int, op_schema.args_schema[1])
+    input_shape = input_strategy.shape
+    topk_dim = (
+        cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1
+    )
+    topk_dim = normalize_dim(topk_dim, input_strategy.ndim)
 
+    all_mesh_dim_strategies = []
 
-# return new_placements which align with placements but skip the skipped_dim
-def _skip_dim(
-    placements: Tuple[Placement, ...], skipped_dim: int
-) -> Tuple[Placement, ...]:
-    new_placements: List[Placement] = []
-    for p in placements:
-        if isinstance(p, Shard) and p.dim >= skipped_dim:
-            new_placements.append(Shard(p.dim - 1))
-        else:
-            new_placements.append(p)
-    return tuple(new_placements)
+    for mesh_dim in range(mesh.ndim):
+        single_mesh_dim_strategies = []
+
+        # two outputs (values, indices), 1 input
+        # replicate always works
+        all_replicate: List[Placement] = [Replicate()] * 3
+        single_mesh_dim_strategies.append(all_replicate)
+
+        # every dim except topk dim should work
+        for dim in range(input_strategy.ndim):
+            if dim != topk_dim:
+                dim_shardings: List[Placement] = [Shard(dim)] * 3
+                single_mesh_dim_strategies.append(dim_shardings)
+
+            # TODO: topk on the sharded dim requires a non-trivial reduction; address it later
+
+        all_mesh_dim_strategies.append(single_mesh_dim_strategies)
+
+    strategy_combs = itertools.product(*all_mesh_dim_strategies)
+
+    all_strategies = []
+    for strategy_comb in strategy_combs:
+        spec_list = []
+        for specs in zip(*strategy_comb):
+            spec_list.append(DTensorSpec(mesh, tuple(specs)))
+
+        input_spec = spec_list[2]
+        if is_tensor_shardable(input_shape, input_spec):
+            redistribute_cost = [
+                generate_redistribute_costs(input_strategy, input_spec)
+            ]
+            strategy = PlacementStrategy(
+                output_specs=tuple(spec_list[:2]),
+                input_specs=(input_spec,),
+                redistribute_cost=redistribute_cost,
+            )
+            all_strategies.append(strategy)
+
+    return OpStrategy(all_strategies)