[tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431)

As titled, this PR refactors the PrepareModuleInput style to add a common
_prepare_input_arg method, so that both the args and kwargs paths reuse the same preparation logic.

This also fixes https://github.com/pytorch/pytorch/issues/128365: the kwargs path previously wrapped every value with DTensor.from_local, so keyword inputs that were already DTensors were not handled correctly; the shared helper now passes them through and only redistributes when the desired layout differs from the input layout.
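
For context, a minimal sketch of the now-supported scenario, modeled on the new test case; the module, tensor shapes, and "cuda" device string below are illustrative only, and the snippet assumes torch.distributed is already initialized with one GPU per rank:

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import PrepareModuleInput, parallelize_module


class KwOnlyMod(torch.nn.Module):
    # keyword-only inputs, prepared by PrepareModuleInput via a pre-hook
    def forward(self, *, x, z):
        return x + z


mesh = init_device_mesh("cuda", (dist.get_world_size(),))
mod = parallelize_module(
    KwOnlyMod(),
    mesh,
    PrepareModuleInput(
        input_kwarg_layouts={"x": Shard(0), "z": Shard(0)},
        desired_input_kwarg_layouts={"x": Replicate(), "z": Replicate()},
        use_local_output=True,
    ),
)

# x is passed as an already-constructed DTensor; z is a plain torch.Tensor.
# Before this change the kwarg pre-hook unconditionally re-wrapped x with
# DTensor.from_local; now x is taken as-is and only redistributed because the
# desired layout (Replicate) differs from its input layout (Shard(0)).
x_dt = DTensor.from_local(torch.randn(1, 8, device="cuda"), mesh, [Shard(0)])
out = mod(x=x_dt, z=torch.ones(1, 8, device="cuda"))
```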

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128431
Approved by: https://github.com/awgu
diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py
index e2a9a01..776bdc9 100644
--- a/test/distributed/tensor/parallel/test_tp_style.py
+++ b/test/distributed/tensor/parallel/test_tp_style.py
@@ -317,6 +317,18 @@
         self.assertEqual(comm_mode.get_total_counts(), 2)
         self.assertEqual(output.shape, (1 * self.world_size, 8))
 
+        # test the case where x is a DTensor
+        x_dt = DTensor.from_local(
+            torch.randn(1, 8, device=self.device_type), mesh, [Shard(0)]
+        )
+        with comm_mode:
+            output = test_kwonly_mod(
+                x=x_dt, z=torch.ones(1, 8, device=self.device_type)
+            )
+
+        self.assertEqual(comm_mode.get_total_counts(), 2)
+        self.assertEqual(output.shape, (1 * self.world_size, 8))
+
     @with_comms
     def test_prepare_module_output(self):
         mesh = init_device_mesh(self.device_type, (self.world_size,))
diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py
index f532b97..00d85bf 100644
--- a/torch/distributed/tensor/parallel/style.py
+++ b/torch/distributed/tensor/parallel/style.py
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from abc import ABC, abstractmethod
-from typing import Optional, Union, Tuple, Dict
+from typing import Optional, Union, Tuple, Dict, Any
 from functools import partial
 
 import torch
@@ -400,6 +400,23 @@
             assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \
                 "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
 
+    def _prepare_input_arg(self, input: Any, mesh: DeviceMesh, input_layout: Optional[Placement], desired_layout: Optional[Placement]):
+        if input_layout is not None:
+            if isinstance(input, DTensor):
+                # TODO: re-enable the check once we fix the compile path
+                # assert input.placements[0] == input_layout
+                dt_inp = input
+            else:
+                assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!"
+                dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False)
+
+            if desired_layout is not None and input_layout != desired_layout:
+                dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+            return dt_inp.to_local() if self.use_local_output else dt_inp
+        else:
+            return input
+
     def _prepare_input_fn(self, inputs, device_mesh):
         if self.input_layouts is None:
             return inputs
@@ -409,21 +426,8 @@
         if len(inputs) != len(self.input_layouts):
             raise ValueError("module inputs and input_layouts should have same length!")
 
-        assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
         for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
-            if input_layout is not None:
-                if isinstance(inp, DTensor):
-                    # TODO: re-enable the check once we fix the compile path
-                    # assert inp.placements[0] == input_layout
-                    dt_inp = inp
-                else:
-                    dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
-
-                if desired_layout is not None and input_layout != desired_layout:
-                    dt_inp = dt_inp.redistribute(placements=(desired_layout,))
-                prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
-            else:
-                prepared_inputs.append(inp)
+            prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout))
         return tuple(prepared_inputs)
 
     def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
@@ -431,20 +435,10 @@
         prepared_kwarg_inputs = {}
         for kwarg_key in kwarg_inputs.keys():
             kwarg_val = kwarg_inputs[kwarg_key]
-            input_layout = None
-            if kwarg_key in self.input_kwarg_layouts:
-                input_layout = self.input_kwarg_layouts[kwarg_key]
-                assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!"
-                kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False)
+            input_layout = self.input_kwarg_layouts.get(kwarg_key)
+            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)
 
-                if kwarg_key in self.desired_input_kwarg_layouts:
-                    desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
-                    if desired_layout != input_layout:
-                        kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))
-
-                prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val
-            else:
-                prepared_kwarg_inputs[kwarg_key] = kwarg_val
+            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout)
 
         return (prepared_arg_inputs, prepared_kwarg_inputs)