Refactor `move_constructors_to_cuda`. (#114626)

Follow-up: #114539

This PR introduces a minor change to the `move_constructors_to_cuda` implementation while
refactoring the whole pass into a class. Here's a brief summary of the changes:

- Create a new `ConstructorMoverPass` class, parameterized by the target device type (see the usage sketch after the snippet below)
- Rephrase the condition:

```python
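# before (original condition):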
if not isinstance(
    node.target, torch._ops.OpOverload
) or node.target.namespace not in ("prims", "aten"):
    ...

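# after (equivalent refactored condition):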
if not (
    isinstance(node.target, torch._ops.OpOverload)
    and node.target.namespace in ("prims", "aten")
):
    ...
```
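
With this refactor, the cuda pass reduces to a single instantiation of the generic class, and other device backends could reuse the same machinery. A minimal usage sketch, given an `fx.Graph` named `graph`; the `"xpu"` target below is purely illustrative and not part of this PR:

```python
# The existing entry point now just delegates to the class.
ConstructorMoverPass("cuda")(graph)

# Hypothetical: another backend could target a different device type and,
# if needed, also allow graph outputs to be moved.
ConstructorMoverPass("xpu", allow_outputs=True)(graph)
```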

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114626
Approved by: https://github.com/eellison
diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py
index 03eae27..95d250f 100644
--- a/torch/_inductor/fx_passes/post_grad.py
+++ b/torch/_inductor/fx_passes/post_grad.py
@@ -987,169 +987,203 @@
     return inductor.kernel.mm.tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype)
 
 
-def allows_mixed_devices(op):
-    return op in (
-        aten.index.Tensor,
-        aten.index_put.default,
-        aten.index_put_.default,
-        aten.copy.default,
-        aten.copy_.default,
-        aten.slice_scatter.default,
-    )
+class ConstructorMoverPass:
+    def __init__(self, target: str, allow_outputs: bool = False) -> None:
+        """
+        Move constructors from cpu to the target device.
+
+        Sweeps through the module, looking for constructor nodes that can be moved
+        to the target device.
+
+        A constructor node can be moved to the target device iff all of its users
+        can also be moved (tested by cannot_be_moved). Otherwise, all dependent
+        constructor nodes won't be moved.
+
+        - target: target device type
+        - allow_outputs: allow outputs to be moved
+        """
+
+        self.target = target
+        self.allow_outputs = allow_outputs
+
+        assert isinstance(target, str), (
+            "target should be a string representing the device type. "
+            f"Got: {type(target).__name__}"
+        )
+
+    def allow_cpu_device(self, node: fx.Node) -> bool:
+        """
+        Returns whether a node that returns a tensor on the target device may have
+        cpu tensors as input.
+        """
+        return node.target in (
+            torch.ops.aten.index.Tensor,
+            torch.ops.aten.index_put.default,
+            torch.ops.aten.index_put_.default,
+            torch.ops.aten.copy.default,
+            torch.ops.aten.copy_.default,
+            torch.ops.aten.slice_scatter.default,
+        )
+
+    def cannot_be_moved(self, node: fx.Node) -> bool:
+        """
+        Returns whether a node cannot be moved to the target device.
+
+        If this function returns True, the node stays on cpu and the cpu constructors
+        it depends on won't be moved to the target device.
+        """
+        if node.target == "output":
+            return not self.allow_outputs
+
+        if not (
+            isinstance(node.target, torch._ops.OpOverload)
+            and node.target.namespace in ("prims", "aten")
+        ):
+            return True
+
+        return False
+
+    def get_node_device(self, node: fx.Node) -> Optional[torch.device]:
+        """
+        Get the device of a node.
+        """
+        ten = node.meta.get("val")
+        return None if not isinstance(ten, torch.Tensor) else ten.device
+
+    def get_cpu_indeg_count(self, graph: fx.Graph) -> Dict[fx.Node, int]:
+        """
+        Count the number of cpu tensor inputs of each node in the graph.
+        """
+        cpu_indeg: Dict[fx.Node, int] = Counter()
+
+        for node in graph.nodes:
+            cpu_count = 0
+
+            def add_cpu_inp(node):
+                nonlocal cpu_count
+                device = self.get_node_device(node)
+                cpu_count += device is not None and device.type == "cpu"
+
+            pytree.tree_map_only(fx.Node, add_cpu_inp, (node.args, node.kwargs))
+
+            if cpu_count:
+                cpu_indeg[node] = cpu_count
+
+        return cpu_indeg
+
+    def __call__(self, graph: fx.Graph) -> None:
+        target_devices = set()
+        constructors = []
+
+        for node in graph.nodes:
+            device = self.get_node_device(node)
+            if device and device.type == self.target:
+                target_devices.add(device)
+
+            if not (
+                isinstance(node.target, torch._ops.OpOverload)
+                and node.target.namespace in ("prims", "aten")
+            ):
+                continue
+
+            if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
+                continue
+
+            if not node.kwargs.get("device") == torch.device("cpu"):
+                continue
+
+            constructors.append(node)
+
+        # not handling multiple target devices initially
+        if not constructors or len(target_devices) != 1:
+            return
+
+        movable_constructors = self.find_movable_constructors(graph, constructors)
+
+        for node in movable_constructors:
+            kwargs = node.kwargs.copy()
+            kwargs["device"] = next(iter(target_devices))
+            node.kwargs = kwargs
+
+    def find_movable_constructors(
+        self, graph: fx.Graph, constructors: List[fx.Node]
+    ) -> Set[fx.Node]:
+        """
+        Starting from the cpu constructors, iterate through the graph and test that all of their
+        downstream uses can safely be moved to the target device.
+        """
+        cpu_indeg: Dict[fx.Node, int] = self.get_cpu_indeg_count(graph)
+
+        # which constructors cannot be moved to cuda
+        cannot_move_to_cuda: Set[fx.Node] = set()
+
+        # For any node in the graph, which constructors does it have a dependency on
+        constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)
+
+        # if a cpu node has a dependency on two different cpu constructors,
+        # then if either constructor cannot be moved to cuda, the other cannot as well.
+        # In this case any node with a dependency on one will have a dependency on the other
+        equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {
+            c: {c} for c in constructors
+        }
+
+        def make_dependencies_equivalent(
+            set1: Set[fx.Node], set2: Set[fx.Node]
+        ) -> Set[fx.Node]:
+            # could use union find but not worth complexity here
+            set1.update(set2)
+            for obj in set1:
+                equal_constructor_sets[obj] = set1
+            return set1
+
+        queue: List[fx.Node] = list(constructors)
+
+        for c in queue:
+            constructor_dependencies[c].add(c)
+
+        while queue:
+            node = queue.pop()
+            dependencies = constructor_dependencies[node]
+
+            for user in node.users:
+                if self.cannot_be_moved(user):
+                    cannot_move_to_cuda.update(dependencies)
+                    break
+
+                # this node was used in an op that takes in multiple devices and outputs a
+                # tensor on the target device; its cpu input can be converted without further changes
+                node_device = self.get_node_device(user)
+                if (
+                    self.allow_cpu_device(user)
+                    and node_device
+                    and node_device.type == self.target
+                ):
+                    del cpu_indeg[user]
+                else:
+                    # otherwise, we should continue to look at its downstream uses
+                    cpu_indeg[user] -= 1
+                    if cpu_indeg[user] == 0:
+                        del cpu_indeg[user]
+                        queue.append(user)
+
+                unioned_set = make_dependencies_equivalent(
+                    dependencies, constructor_dependencies[user]
+                )
+                constructor_dependencies[user] = unioned_set
+
+        for node in cpu_indeg:
+            if constructor_dependencies[node]:
+                cannot_move_to_cuda.update(constructor_dependencies[node])
+
+        all_cannot_move_to_cuda = cannot_move_to_cuda.copy()
+        for constructor in cannot_move_to_cuda:
+            all_cannot_move_to_cuda.update(equal_constructor_sets[constructor])
+
+        return set(constructors) - all_cannot_move_to_cuda
 
 
-def cannot_be_moved_to_cuda(node):
-    if node.target == "output":
-        return True
-
-    if not isinstance(
-        node.target, torch._ops.OpOverload
-    ) and node.target.namespace not in ("prims", "aten"):
-        return True
-
-    # only move ops to inductor lowerings for now,
-    # fallback ops may have weird cpu/cuda incompatibilities
-    return (
-        node.target not in torch._inductor.lowering.lowerings
-        or node.target in torch._inductor.lowering.fallbacks
-    )
-
-
-def get_node_device(node: fx.Node) -> Optional[torch.device]:
-    ten = node.meta.get("val")
-    return None if not isinstance(ten, torch.Tensor) else ten.device
-
-
-def get_cpu_indeg_count(graph) -> Dict[fx.Node, int]:
-    """
-    Get the number of cpu inputs to a node
-    """
-    cpu_indeg: Dict[fx.Node, int] = Counter()
-
-    for node in graph.nodes:
-        cpu_count = 0
-
-        def add_cpu_inp(node):
-            nonlocal cpu_count
-            device = get_node_device(node)
-            cpu_count += device is not None and device.type == "cpu"
-
-        pytree.tree_map_only(torch.fx.Node, add_cpu_inp, (node.args, node.kwargs))
-
-        if cpu_count:
-            cpu_indeg[node] = cpu_count
-
-    return cpu_indeg
-
-
-def move_constructors_to_cuda(graph):
+def move_constructors_to_cuda(graph: fx.Graph) -> None:
     """
     Moves intermediary tensors which are constructed on the cpu to cuda when safe
     """
-    if not torch.backends.cuda.is_built():
-        return
-
-    cuda_devices = set()
-    constructors = []
-
-    for node in graph.nodes:
-        device = get_node_device(node)
-        if device and device.type == "cuda":
-            cuda_devices.add(device)
-
-        if not isinstance(
-            node.target, torch._ops.OpOverload
-        ) or node.target.namespace not in ("prims", "aten"):
-            continue
-
-        if not torch._subclasses.fake_tensor._is_tensor_constructor(node.target):
-            continue
-
-        if not node.kwargs.get("device") == torch.device("cpu"):
-            continue
-
-        constructors.append(node)
-
-    # not handling multiple cuda devices initially
-    if not constructors or len(cuda_devices) != 1:
-        return
-
-    movable_constructors = find_movable_constructors(graph, constructors)
-
-    for node in movable_constructors:
-        kwargs = node.kwargs.copy()
-        kwargs["device"] = next(iter(cuda_devices))
-        node.kwargs = kwargs
-
-
-def find_movable_constructors(graph, constructors: List[fx.Node]) -> Set[fx.Node]:
-    """
-    Starting from the cpu constructors, iterate through the graph and test that all of their
-    downstream uses can safely be moved to cpu.
-    """
-    cpu_indeg: Dict[fx.Node, int] = get_cpu_indeg_count(graph)
-
-    # which constructors cannot be moved to cuda
-    cannot_move_to_cuda: Set[fx.Node] = set()
-
-    # For any node in the graph, which constructors does it have a dependency on
-    constructor_dependencies: Dict[fx.Node, Set[fx.Node]] = defaultdict(set)
-
-    # if a cpu node has a dependency on two different cpu constructors,
-    # then if either constructor cannot be moved to cuda, the other cannot as well.
-    # In this case any node with a dependency on one will have a dependency on the other
-    equal_constructor_sets: Dict[fx.Node, Set[fx.Node]] = {c: {c} for c in constructors}
-
-    def make_dependencies_equivalent(
-        set1: Set[fx.Node], set2: Set[fx.Node]
-    ) -> Set[fx.Node]:
-        # could use union find but not worth complexity here
-        set1.update(set2)
-        for obj in set1:
-            equal_constructor_sets[obj] = set1
-        return set1
-
-    queue: List[fx.Node] = list(constructors)
-
-    for c in queue:
-        constructor_dependencies[c].add(c)
-
-    while queue:
-        node = queue.pop()
-        dependencies = constructor_dependencies[node]
-
-        for user in node.users:
-            if cannot_be_moved_to_cuda(user):
-                cannot_move_to_cuda.update(dependencies)
-                break
-
-            # this node was used on a op which takes in multiple devices and output a cuda
-            # tensor. we can convert its cpu input to cuda without making further changes
-            node_device = get_node_device(user)
-            if (
-                allows_mixed_devices(user.target)
-                and node_device
-                and node_device.type == "cuda"
-            ):
-                del cpu_indeg[user]
-            else:
-                # otherwise, we should continue look at its downstream uses
-                cpu_indeg[user] -= 1
-                if cpu_indeg[user] == 0:
-                    del cpu_indeg[user]
-                    queue.append(user)
-
-            unioned_set = make_dependencies_equivalent(
-                dependencies, constructor_dependencies[user]
-            )
-            constructor_dependencies[user] = unioned_set
-
-    for node in cpu_indeg:
-        if constructor_dependencies[node]:
-            cannot_move_to_cuda.update(constructor_dependencies[node])
-
-    all_cannot_move_to_cuda = cannot_move_to_cuda.copy()
-    for constructor in cannot_move_to_cuda:
-        all_cannot_move_to_cuda.update(equal_constructor_sets[constructor])
-
-    return set(constructors) - all_cannot_move_to_cuda
+    ConstructorMoverPass("cuda")(graph)