[tp] refactor and fix PrepareModuleInput for DTensor inputs (#128431)
As titled, this PR refactors the PrepareModuleInput style to factor the shared logic
into a common method, _prepare_input_arg, so that both the args and kwargs paths reuse it.
This also fixes https://github.com/pytorch/pytorch/issues/128365
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128431
Approved by: https://github.com/awgu
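For context, a minimal sketch of the scenario this fixes, under stated assumptions: the module ToyMod, the kwarg name x, and the shapes below are illustrative (not taken from the PR), the default process group is assumed to be initialized (e.g. via torchrun), and the import paths are the ones in use around the time of this change. Before this change the kwarg path unconditionally re-wrapped the incoming value with DTensor.from_local, so passing an already-constructed DTensor as a keyword argument was mishandled; with the shared _prepare_input_arg helper a DTensor kwarg is used as-is and only redistributed to the desired layout.

```python
# Illustrative sketch only; assumes an initialized process group and a 1-D mesh.
# ToyMod, the kwarg name "x", and the tensor shapes are made up for illustration.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard, Replicate
from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput


class ToyMod(nn.Module):
    def forward(self, *, x):
        # Receives a DTensor by default (use_local_output=False).
        return x * 2


device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device_type, (dist.get_world_size(),))

mod = parallelize_module(
    ToyMod(),
    mesh,
    PrepareModuleInput(
        input_kwarg_layouts={"x": Shard(0)},
        desired_input_kwarg_layouts={"x": Replicate()},
    ),
)

# Plain tensor kwarg: wrapped via DTensor.from_local, then redistributed.
out = mod(x=torch.randn(1, 8, device=device_type))

# DTensor kwarg: previously re-wrapped by from_local; now routed through
# _prepare_input_arg, which keeps the DTensor and only redistributes it.
x_dt = DTensor.from_local(torch.randn(1, 8, device=device_type), mesh, [Shard(0)])
out = mod(x=x_dt)
```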
diff --git a/test/distributed/tensor/parallel/test_tp_style.py b/test/distributed/tensor/parallel/test_tp_style.py
index e2a9a01..776bdc9 100644
--- a/test/distributed/tensor/parallel/test_tp_style.py
+++ b/test/distributed/tensor/parallel/test_tp_style.py
@@ -317,6 +317,18 @@
self.assertEqual(comm_mode.get_total_counts(), 2)
self.assertEqual(output.shape, (1 * self.world_size, 8))
+ # test the case where x is a DTensor
+ x_dt = DTensor.from_local(
+ torch.randn(1, 8, device=self.device_type), mesh, [Shard(0)]
+ )
+ with comm_mode:
+ output = test_kwonly_mod(
+ x=x_dt, z=torch.ones(1, 8, device=self.device_type)
+ )
+
+ self.assertEqual(comm_mode.get_total_counts(), 2)
+ self.assertEqual(output.shape, (1 * self.world_size, 8))
+
@with_comms
def test_prepare_module_output(self):
mesh = init_device_mesh(self.device_type, (self.world_size,))
diff --git a/torch/distributed/tensor/parallel/style.py b/torch/distributed/tensor/parallel/style.py
index f532b97..00d85bf 100644
--- a/torch/distributed/tensor/parallel/style.py
+++ b/torch/distributed/tensor/parallel/style.py
@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
from abc import ABC, abstractmethod
-from typing import Optional, Union, Tuple, Dict
+from typing import Optional, Union, Tuple, Dict, Any
from functools import partial
import torch
@@ -400,6 +400,23 @@
assert len(self.input_kwarg_layouts) == len(self.desired_input_kwarg_layouts), \
"input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
+ def _prepare_input_arg(self, input: Any, mesh: DeviceMesh, input_layout: Optional[Placement], desired_layout: Optional[Placement]):
+ if input_layout is not None:
+ if isinstance(input, DTensor):
+ # TODO: re-enable the check once we fix the compile path
+ # assert input.placements[0] == input_layout
+ dt_inp = input
+ else:
+ assert isinstance(input, torch.Tensor), "expecting input to be a torch.Tensor!"
+ dt_inp = DTensor.from_local(input, mesh, (input_layout,), run_check=False)
+
+ if desired_layout is not None and input_layout != desired_layout:
+ dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+ return dt_inp.to_local() if self.use_local_output else dt_inp
+ else:
+ return input
+
def _prepare_input_fn(self, inputs, device_mesh):
if self.input_layouts is None:
return inputs
@@ -409,21 +426,8 @@
if len(inputs) != len(self.input_layouts):
raise ValueError("module inputs and input_layouts should have same length!")
- assert self.desired_input_layouts is not None, "desired module inputs should not be None!"
for inp, input_layout, desired_layout in zip(inputs, self.input_layouts, self.desired_input_layouts):
- if input_layout is not None:
- if isinstance(inp, DTensor):
- # TODO: re-enable the check once we fix the compile path
- # assert inp.placements[0] == input_layout
- dt_inp = inp
- else:
- dt_inp = DTensor.from_local(inp, device_mesh, (input_layout,), run_check=False)
-
- if desired_layout is not None and input_layout != desired_layout:
- dt_inp = dt_inp.redistribute(placements=(desired_layout,))
- prepared_inputs.append(dt_inp.to_local() if self.use_local_output else dt_inp)
- else:
- prepared_inputs.append(inp)
+ prepared_inputs.append(self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout))
return tuple(prepared_inputs)
def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
@@ -431,20 +435,10 @@
prepared_kwarg_inputs = {}
for kwarg_key in kwarg_inputs.keys():
kwarg_val = kwarg_inputs[kwarg_key]
- input_layout = None
- if kwarg_key in self.input_kwarg_layouts:
- input_layout = self.input_kwarg_layouts[kwarg_key]
- assert isinstance(kwarg_val, torch.Tensor), f"input of key {kwarg_key} to the module should be a Tensor!"
- kwarg_val = DTensor.from_local(kwarg_val, device_mesh, (input_layout,), run_check=False)
+ input_layout = self.input_kwarg_layouts.get(kwarg_key)
+ desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)
- if kwarg_key in self.desired_input_kwarg_layouts:
- desired_layout = self.desired_input_kwarg_layouts[kwarg_key]
- if desired_layout != input_layout:
- kwarg_val = kwarg_val.redistribute(placements=(desired_layout,))
-
- prepared_kwarg_inputs[kwarg_key] = kwarg_val.to_local() if self.use_local_output else kwarg_val
- else:
- prepared_kwarg_inputs[kwarg_key] = kwarg_val
+ prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(kwarg_val, device_mesh, input_layout, desired_input_layout)
return (prepared_arg_inputs, prepared_kwarg_inputs)
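To make the consolidated behavior concrete, here is a small behavior sketch of the new helper, assuming an initialized process group and a 1-D device mesh; calling the private _prepare_input_arg directly is for illustration only, and the values are made up. A None layout passes the value through untouched, a plain torch.Tensor is wrapped with DTensor.from_local under the declared layout, and an existing DTensor is never re-wrapped, only redistributed when the desired layout differs.

```python
# Behavior sketch (assumes an initialized process group, e.g. via torchrun);
# calling the private _prepare_input_arg directly is for illustration only.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard, Replicate
from torch.distributed.tensor.parallel import PrepareModuleInput

device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device_type, (dist.get_world_size(),))

style = PrepareModuleInput(input_layouts=Shard(0), desired_input_layouts=Replicate())

# No layout declared for this position/key: the value is returned as-is,
# even if it is not a tensor.
assert style._prepare_input_arg("anything", mesh, None, None) == "anything"

# Plain tensor: wrapped via DTensor.from_local(..., (Shard(0),)) and then
# redistributed because the desired layout (Replicate) differs.
local = torch.randn(1, 8, device=device_type)
replicated = style._prepare_input_arg(local, mesh, Shard(0), Replicate())

# DTensor: used as-is (no double wrapping), then redistributed the same way.
x_dt = DTensor.from_local(local, mesh, (Shard(0),), run_check=False)
replicated_dt = style._prepare_input_arg(x_dt, mesh, Shard(0), Replicate())
```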