Refactor `move_constructor_to_cuda`. (#114626)
Follow-up: #114539
This PR makes a minor change to the `move_constructors_to_cuda` implementation while
refactoring the whole pass into a class. Here's a brief summary of the changes:
- Create a new `ConstructorMoverPass`
- Rephrase the condition:
```python
# Before:
if not isinstance(
    node.target, torch._ops.OpOverload
) or node.target.namespace not in ("prims", "aten"):
    ...

# After:
if not (
    isinstance(node.target, torch._ops.OpOverload)
    and node.target.namespace in ("prims", "aten")
):
    ...
```
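The two forms are logically equivalent; the new one is just the De Morgan dual of the old one. A quick standalone sanity check (illustrative only, not part of the PR):
```python
# Both predicates reject exactly the same combinations of
# "node.target is an OpOverload" and "namespace is prims/aten".
from itertools import product

def old_form(is_overload: bool, namespace_ok: bool) -> bool:
    # not isinstance(...) or namespace not in ("prims", "aten")
    return (not is_overload) or (not namespace_ok)

def new_form(is_overload: bool, namespace_ok: bool) -> bool:
    # not (isinstance(...) and namespace in ("prims", "aten"))
    return not (is_overload and namespace_ok)

assert all(
    old_form(a, b) == new_form(a, b) for a, b in product((True, False), repeat=2)
)
```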
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114626
Approved by: https://github.com/eellison
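For context, here is a hedged end-to-end illustration of what the pass is for. The function `f` and its inputs are made up for this example; only `torch.compile`, `torch.arange`, and the aten indexing op involved are real, and the exact rewrite is not guaranteed for this program:
```python
# Illustrative only: a cpu-constructed index tensor feeding a cuda indexing op.
# When inductor's post-grad passes run, the constructor-mover pass may rewrite
# the arange's device kwarg to cuda so the index never has to live on cpu.
# Requires a CUDA build.
import torch

def f(x):
    idx = torch.arange(4)  # constructed on cpu by default
    return x[idx] * 2      # aten.index.Tensor on a cuda tensor

compiled = torch.compile(f)
out = compiled(torch.randn(8, device="cuda"))
```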
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 03eae27..95d250f 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -987,169 +987,203 @@
return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
-def allows_mixed_devices(op):
- return op in (
- aten.index.Tensor,
- aten.index_put.default,
- aten.index_put_.default,
- aten.copy.default,
- aten.copy_.default,
- aten.slice_scatter.default,
- )
+class ConstructorMoverPass:
+ def __init__(self, target: str, allow_outputs: bool = False) -> None:
+ """
+ Move constructors from cpu to the target_device.
+
+ Sweeps through the module, looking for constructor nodes that can be moved
+ to the target_device.
+
+ A constructor node can be moved to the target_device iff all of its users
+ can also be moved (tested by cannot_be_moved). Otherwise, all dependent
+ constructor nodes won't be moved.
+
+ - target: target device type
+ - allow_outputs: allow outputs to be moved
+ """
+
+ self.target = target
+ self.allow_outputs = allow_outputs
+
+ assert isinstance(target, str), (
+ "target should be a string representing the device type. "
+ f"Got: {type(target).__name__}"
+ )
+
+ def allow_cpu_device(self, node: fx.Node) -> bool:
+ """
+ Returns whether a node that returns a tensor on the target device may have
+ cpu tensors as input.
+ """
+ return node.target in (
+ torch.ops.aten.index.Tensor,
+ torch.ops.aten.index_put.default,
+ torch.ops.aten.index_put_.default,
+ torch.ops.aten.copy.default,
+ torch.ops.aten.copy_.default,
+ torch.ops.aten.slice_scatter.default,
+ )
+
+ def cannot_be_moved(self, node: fx.Node) -> bool:
+ """
+ Returns whether a node cannot be moved to the target device.
+
+ If this function returns True, this node and the constructor nodes it
+ depends on won't be moved to the target device.
+ """
+ if node.target == "output":
+ return not self.allow_outputs
+
+ if not (
+ isinstance(node.target, torch._ops.OpOverload)
+ and node.target.namespace in ("prims", "aten")
+ ):
+ return True
+
+ return False
+
+ def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
+ """
+ Get the device of a node.
+ """
+ ten = node.meta.get("val")
+ return None if not isinstance(ten, torch.Tensor) else ten.device
+
+ def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
+ """
+ Get the number of cpu inputs to a node
+ """
+ cpu_indeg: Dict[fx.Node, int] = Counter()
+
+ for node in graph.nodes:
+ cpu_count = 0
+
+ def add_cpu_inp(node):
+ nonlocal cpu_count
+ device = self.get_node_device(node)
+ cpu_count += device is not None and device.type == "cpu"
+
+ pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))
+
+ if cpu_count:
+ cpu_indeg[node] = cpu_count
+
+ return cpu_indeg
+
+ def __call__(self, graph: fx.Graph) -> None:
+ target_devices = set()
+ constructors = []
+
+ for node in graph.nodes:
+ device = self.get_node_device(node)
+ if device and device.type == self.target:
+ target_devices.add(device)
+
+ if not (
+ isinstance(node.target, torch._ops.OpOverload)
+ and node.target.namespace in ("prims", "aten")
+ ):
+ continue
+
+ if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
+ continue
+
+ if not node.kwargs.get("device") == torch.device("cpu"):
+ continue
+
+ constructors.append(node)
+
+ # not handling multiple target devices initially
+ if not constructors or len(target_devices) != 1:
+ return
+
+ movable_constructors = self.find_movable_constructors(graph, constructors)
+
+ for node in movable_constructors:
+ kwargs = node.kwargs.copy()
+ kwargs["device"] = next(iter(target_devices))
+ node.kwargs = kwargs
+
+ def find_movable_constructors(
+ self, graph: fx.Graph, constructors: List[fx.Node]
+ ) -> Set[fx.Node]:
+ """
+ Starting from the cpu constructors, iterate through the graph and test that all of their
+ downstream uses can safely be converted to the target device.
+ """
+ cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)
+
+ # which constructors cannot be moved to cuda
+ cannot_move_to_cuda: Set[fx.Node] = set()
+
+ # For any node in the graph, which constructors does it have a dependency on
+ constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)
+
+ # if a cpu node has a dependency on two different cpu constructors,
+ # then if either constructor cannot be moved to cuda, the other cannot as well.
+ # In this case any node with a dependency on one will have a dependency on the other
+ equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
+ c: {c} for c in constructors
+ }
+
+ def make_dependencies_equivalent(
+ set1: Set[fx.Node], set2: Set[fx.Node]
+ ) -> Set[fx.Node]:
+ # could use union find but not worth complexity here
+ set1.update(set2)
+ for obj in set1:
+ equal_constructor_sets[obj] = set1
+ return set1
+
+ queue: List[fx.Node] = list(constructors)
+
+ for c in queue:
+ constructor_dependencies[c].add(c)
+
+ while queue:
+ node = queue.pop()
+ dependencies = constructor_dependencies[node]
+
+ for user in node.users:
+ if self.cannot_be_moved(user):
+ cannot_move_to_cuda.update(dependencies)
+ break
+
+ # this node was used in an op which takes in multiple devices and outputs a cuda
+ # tensor. we can convert its cpu input to cuda without making further changes
+ node_device = self.get_node_device(user)
+ if (
+ self.allow_cpu_device(user)
+ and node_device
+ and node_device.type == "cuda"
+ ):
+ del cpu_indeg[user]
+ else:
+ # otherwise, we should continue looking at its downstream uses
+ cpu_indeg[user] -= 1
+ if cpu_indeg[user] == 0:
+ del cpu_indeg[user]
+ queue.append(user)
+
+ unioned_set = make_dependencies_equivalent(
+ dependencies, constructor_dependencies[user]
+ )
+ constructor_dependencies[user] = unioned_set
+
+ for node in cpu_indeg:
+ if constructor_dependencies[node]:
+ cannot_move_to_cuda.update(constructor_dependencies[node])
+
+ all_cannot_move_to_cuda = cannot_move_to_cuda.copy()
+ for constructor in cannot_move_to_cuda:
+ all_cannot_move_to_cuda.update(equal_constructor_sets[constructor])
+
+ return set(constructors) - all_cannot_move_to_cuda
-def cannot_be_moved_to_cuda(node):
- if node.target == "output":
- return True
-
- if not isinstance(
- node.target, torch._ops.OpOverload
- ) and node.target.namespace not in ("prims", "aten"):
- return True
-
- # only move ops to inductor lowerings for now,
- # fallback ops may have weird cpu/cuda incompatibilities
- return (
- node.target not in torch._inductor.lowering.lowerings
- or node.target in torch._inductor.lowering.fallbacks
- )
-
-
-def get_node_device(node: fx.Node) -> Optional[torch.device]:
- ten = node.meta.get("val")
- return None if not isinstance(ten, torch.Tensor) else ten.device
-
-
-def get_cpu_indeg_count(graph) -> Dict[fx.Node, int]:
- """
- Get the number of cpu inputs to a node
- """
- cpu_indeg: Dict[fx.Node, int] = Counter()
-
- for node in graph.nodes:
- cpu_count = 0
-
- def add_cpu_inp(node):
- nonlocal cpu_count
- device = get_node_device(node)
- cpu_count += device is not None and device.type == "cpu"
-
- pytree.tree_map_only(torch.fx.Node, add_cpu_inp, (node.args, node.kwargs))
-
- if cpu_count:
- cpu_indeg[node] = cpu_count
-
- return cpu_indeg
-
-
-def move_constructors_to_cuda(graph):
+def move_constructors_to_cuda(graph: fx.Graph) -> None:
"""
Moves intermediary tensors which are constructed on the cpu to cuda when safe
"""
- if not torch.backends.cuda.is_built():
- return
-
- cuda_devices = set()
- constructors = []
-
- for node in graph.nodes:
- device = get_node_device(node)
- if device and device.type == "cuda":
- cuda_devices.add(device)
-
- if not isinstance(
- node.target, torch._ops.OpOverload
- ) or node.target.namespace not in ("prims", "aten"):
- continue
-
- if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
- continue
-
- if not node.kwargs.get("device") == torch.device("cpu"):
- continue
-
- constructors.append(node)
-
- # not handling multiple cuda devices initially
- if not constructors or len(cuda_devices) != 1:
- return
-
- movable_constructors = find_movable_constructors(graph, constructors)
-
- for node in movable_constructors:
- kwargs = node.kwargs.copy()
- kwargs["device"] = next(iter(cuda_devices))
- node.kwargs = kwargs
-
-
-def find_movable_constructors(graph, constructors: List[fx.Node]) -> Set[fx.Node]:
- """
- Starting from the cpu constructors, iterate through the graph and test that all of their
- downstream uses can safely be moved to cpu.
- """
- cpu_indeg: Dict[fx.Node, int] = get_cpu_indeg_count(graph)
-
- # which constructors cannot be moved to cuda
- cannot_move_to_cuda: Set[fx.Node] = set()
-
- # For any node in the graph, which constructors does it have a dependency on
- constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)
-
- # if a cpu node has a dependency on two different cpu constructors,
- # then if either constructor cannot be moved to cuda, the other cannot as well.
- # In this case any node with a dependency on one will have a dependency on the other
- equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {c: {c} for c in constructors}
-
- def make_dependencies_equivalent(
- set1: Set[fx.Node], set2: Set[fx.Node]
- ) -> Set[fx.Node]:
- # could use union find but not worth complexity here
- set1.update(set2)
- for obj in set1:
- equal_constructor_sets[obj] = set1
- return set1
-
- queue: List[fx.Node] = list(constructors)
-
- for c in queue:
- constructor_dependencies[c].add(c)
-
- while queue:
- node = queue.pop()
- dependencies = constructor_dependencies[node]
-
- for user in node.users:
- if cannot_be_moved_to_cuda(user):
- cannot_move_to_cuda.update(dependencies)
- break
-
- # this node was used on a op which takes in multiple devices and output a cuda
- # tensor. we can convert its cpu input to cuda without making further changes
- node_device = get_node_device(user)
- if (
- allows_mixed_devices(user.target)
- and node_device
- and node_device.type == "cuda"
- ):
- del cpu_indeg[user]
- else:
- # otherwise, we should continue look at its downstream uses
- cpu_indeg[user] -= 1
- if cpu_indeg[user] == 0:
- del cpu_indeg[user]
- queue.append(user)
-
- unioned_set = make_dependencies_equivalent(
- dependencies, constructor_dependencies[user]
- )
- constructor_dependencies[user] = unioned_set
-
- for node in cpu_indeg:
- if constructor_dependencies[node]:
- cannot_move_to_cuda.update(constructor_dependencies[node])
-
- all_cannot_move_to_cuda = cannot_move_to_cuda.copy()
- for constructor in cannot_move_to_cuda:
- all_cannot_move_to_cuda.update(equal_constructor_sets[constructor])
-
- return set(constructors) - all_cannot_move_to_cuda
+ ConstructorMoverPass("cuda")(graph)
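Since the pass is now parameterized on a device-type string, other backends could in principle reuse it. A minimal sketch; the "xpu" target here is purely hypothetical and not wired up by this PR:
```python
# Hypothetical reuse of the generic pass for another device type.
from torch import fx
from torch._inductor.fx_passes.post_grad import ConstructorMoverPass

def move_constructors_to_xpu(graph: fx.Graph) -> None:
    ConstructorMoverPass("xpu")(graph)
```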