# Copyright (c) Meta Platforms, Inc. and affiliates

import logging
from typing import Dict, Optional

import torch
from torch import fx
from torch.export.unflatten import InterpreterModule

logger = logging.getLogger(__name__)


def flatten_args_detach(args):
"""
    Flatten the args into a list form and detach the tensors from the
    computational graph.
"""
flat_detached_args = []

    def extract_tensor_args(a):
nonlocal flat_detached_args
if isinstance(a, torch.Tensor):
val = a.detach().requires_grad_(a.requires_grad)
flat_detached_args.append(val)
return val
else:
flat_detached_args.append(a)
return a
new_args = fx.node.map_aggregate(
args,
extract_tensor_args,
)
return new_args, flat_detached_args
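
# Illustrative usage sketch (the tensor values are made up): `new_args` mirrors
# the structure of `args`, while `flat_detached_args` holds the leaves in
# flattened order, with tensors detached from the autograd graph but keeping
# their `requires_grad` flag.
#
#     x = torch.randn(2, 2, requires_grad=True)
#     new_args, flat = flatten_args_detach((x, 3, {"bias": torch.zeros(2)}))
#     assert flat[0] is not x                          # detached copy (new autograd leaf)
#     assert flat[0].requires_grad == x.requires_grad  # grad flag preserved
#     assert new_args[1] == 3                          # non-tensors pass through unchanged
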
def flatten_args(args):
"""
Flatten the args into a list form.
"""
flat_args = []

    def extract_tensor_args(a):
nonlocal flat_args
flat_args.append(a)
return a
fx.node.map_aggregate(
args,
extract_tensor_args,
)
return flat_args
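
# Minimal sketch: `flatten_args` collects the leaves in the same traversal order
# but does not detach them, e.g. flatten_args((x, {"k": y}, 1)) -> [x, y, 1]
# (x and y being hypothetical tensors).
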
def modify_graph_op_device(
gm: torch.fx.GraphModule,
new_device: torch.device,
):
"""
    Modify the device argument of all "call_function" nodes in the graph. This
    is useful for moving the graph to a different device, in particular for
    generator ops like torch.ones.
"""
modified = False
for node in gm.graph.nodes:
if node.op == "call_function":
if "device" in node.kwargs and node.kwargs["device"] != new_device:
logger.debug(
f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004
)
node.update_kwarg("device", new_device)
modified = True
elif node.op == "call_module":
# Recursively modify "device" in submodules
submod = gm.get_submodule(node.target)
if isinstance(submod, torch.fx.GraphModule):
modify_graph_op_device(submod, new_device)
elif isinstance(submod, InterpreterModule):
                # If unflattening has been performed, access the underlying graph module via `.graph_module`
modify_graph_op_device(submod.graph_module, new_device)
else:
logger.warning(
f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004
)
if modified:
gm.recompile()
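
# Rough usage sketch (the model and target device are illustrative): after
# tracing, a graph may still hard-code the tracing device in generator ops such
# as torch.ones(..., device="cpu"); rewriting the `device` kwarg retargets those
# nodes, recursing into submodules and recompiling each GraphModule it modifies.
#
#     gm = fx.symbolic_trace(MyModel())  # hypothetical model
#     modify_graph_op_device(gm, torch.device("cuda", 0))
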
class QualnameMapMixin:
"""
    A mixin class that helps a `Pipe` object remap its qualnames back to the
    original qualnames.
"""
def __init__(
self,
splitter_qualname_map: Optional[Dict[str, str]] = None,
tracer_qualname_map: Optional[Dict[str, str]] = None,
):
self.new_to_old_qualname_mapping: Dict[str, str] = splitter_qualname_map or {}
self.tracer_qualname_map = tracer_qualname_map

    def remap_qualname(self, qualname: str):
# TODO: annoying
if qualname.startswith("split_gm."):
qualname = qualname[len("split_gm.") :]
name_before_split = None
if qualname in self.new_to_old_qualname_mapping:
name_before_split = self.new_to_old_qualname_mapping[qualname]
else:
            # The qualname map does not store recursive items, so when passed a
            # qualname with extra leaves we perform a longest-prefix match,
            # splitting from the right one level at a time.
split_names = qualname.rsplit(".", 1)
leaf = split_names[-1]
while len(split_names) > 1:
prefix = split_names[0]
if prefix in self.new_to_old_qualname_mapping:
old_prefix = self.new_to_old_qualname_mapping[prefix]
name_before_split = ".".join([old_prefix, leaf])
break
split_names = prefix.rsplit(".", 1)
leaf = ".".join([split_names[-1], leaf])
if name_before_split is None:
raise RuntimeError(f"Could not find mapping for {qualname}")
if self.tracer_qualname_map is not None:
return self.tracer_qualname_map[name_before_split]
else:
return name_before_split
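
# Sketch of the remapping (the mapping entry is illustrative): suppose splitting
# recorded {"submod_0.moved_linear": "layers.0.linear"} in
# `new_to_old_qualname_mapping` of some `Pipe` instance `pipe`. A deeper
# qualname then resolves through the longest-prefix match above:
#
#     pipe.remap_qualname("split_gm.submod_0.moved_linear.weight")
#     # -> "layers.0.linear.weight"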