import copy
from typing import Callable, Tuple
from unittest.mock import patch
import torch
import torch._dynamo as torchdynamo
from torch._decomp import core_aten_decompositions
from torch._dispatch.python import enable_python_dispatcher
from torch.nn.utils import stateless
from torch.utils import _pytree as pytree
from torch._functorch.aot_autograd import (
    AOTConfig,
    create_aot_dispatcher_function,
    default_partition,
    run_functionalized_fw_and_collect_metadata,
)
from torch.fx.experimental.proxy_tensor import (
    get_torch_dispatch_modes,
    has_proxy_slot,
    ProxyTorchDispatchMode,
    set_proxy_slot,
)
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional
from .workflow import ExportedProgram

CORE_ATEN_DECOMPOSITIONS_TABLE = core_aten_decompositions()

__all__ = ["do_not_use_experimental_export"]


def _aot_capture(mod, flat_args):
    """
    A wrapper around aot_autograd() to mix AOT Autograd + torch.export.
    Some assumptions were made about the AOT Autograd internals:
    1. The functionalization metadata format.
    2. The calling convention of the returned forward graph.
    3. make_fx() internal proxy storage.
    In the current context we're just experimenting with the idea, so it's possible
    that things could break. As a next step we should find a way to upstream
    something reasonable.
    """
    param_list = [
        *mod.named_parameters(remove_duplicate=False),
        *mod.named_buffers(remove_duplicate=False),
    ]
    params = dict(param_list)
    params_flat, params_spec = pytree.tree_flatten(params)
    params_len = len(params_flat)

    full_args = []
    full_args.extend(params_flat)
    full_args.extend(flat_args)
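    # Run the dynamo-produced graph module with the first params_len args
    # re-parametrized onto it as parameters/buffers; the remaining args are the
    # actual user inputs.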
    def functional_call(*args):
        with stateless._reparametrize_module(
            mod,
            pytree.tree_unflatten(args[:params_len], params_spec),  # type: ignore[arg-type]
        ):
            return torch.fx.Interpreter(mod).run(*args[params_len:])

    out_spec = None
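    # First pass (assumption 1 in the docstring): run the functionalized forward on
    # a deep copy of the inputs to collect metadata about which inputs get mutated.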
    with enable_python_dispatcher():
        fw_metadata = run_functionalized_fw_and_collect_metadata(
            lambda *args: pytree.tree_flatten(functional_call(*args))[0],
            keep_input_mutations=False,
        )(*copy.deepcopy(full_args))  # type: ignore[operator]

    assert len(fw_metadata.input_info) == len(full_args)
    mutated_input_indices = [
        i
        for i, input_info in enumerate(fw_metadata.input_info)
        if input_info.mutates_data or input_info.mutates_metadata
    ]

    graph_module = None
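    # The forward "compiler" only stashes the traced forward graph; the compiled
    # callable AOT Autograd builds from it is never invoked here.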
    def fw_compiler(gm, inputs):
        nonlocal graph_module
        graph_module = gm

    num_fwd_returns = None
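    # Wrap the default partitioner so we can record how many outputs belong to the
    # forward graph (only used as a sanity check further below).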
    def partition_fn(joint_module, joint_inputs, *, num_fwd_outputs, **kwargs):
        nonlocal num_fwd_returns
        num_fwd_returns = num_fwd_outputs
        return default_partition(
            joint_module, joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs
        )
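    # For each functional state tensor being traced, find the tracer that knows
    # about it and point its proxy slot at the original parameter/buffer tensor
    # (assumption 3 in the docstring: make_fx() internal proxy storage). This is
    # what leaves the parameter placeholders unused so they can be erased below.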
    def set_state_proxies(state_args):
        modes = get_torch_dispatch_modes()
        proxy_tensor_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
        if len(proxy_tensor_modes) == 0:
            return
        assert len(state_args) == len(params_flat)
        for i, arg in enumerate(state_args):
            tracer = next(
                m.tracer for m in proxy_tensor_modes if has_proxy_slot(arg, m.tracer)
            )
            set_proxy_slot(arg, tracer, params_flat[i])
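    # Configure AOT Autograd: decompose to the core ATen opset, treat the first
    # params_len inputs as parameters/buffers, and use a no-op backward compiler
    # since we only care about the captured forward graph.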
    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=lambda gm, inputs: None,
        partition_fn=partition_fn,
        decompositions=CORE_ATEN_DECOMPOSITIONS_TABLE,  # type: ignore[arg-type]
        num_params_buffers=params_len,
        aot_id=-1,
        keep_inference_input_mutations=False,
        dynamic_shapes=True,
    )
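    # The callable actually handed to AOT Autograd: unwrap the functional wrappers
    # around the state inputs, redirect their proxies, and record the output pytree
    # spec while tracing the module with node metadata preserved.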
    def exported_call(*args):
        state_args = args[:params_len]
        unwrapped_state_args = _unwrap_all_tensors_from_functional(
            state_args, reapply_views=False
        )
        set_state_proxies(unwrapped_state_args)
        with torch.fx.traceback.preserve_node_meta():
            outputs = functional_call(*args)
        nonlocal out_spec
        outputs, out_spec = pytree.tree_flatten(outputs)
        return outputs
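    # Grad mode is enabled so AOT Autograd takes the joint (autograd) path, which
    # is what invokes partition_fn and thereby records num_fwd_returns.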
    with torch.enable_grad():
        create_aot_dispatcher_function(
            exported_call,
            full_args,
            aot_config,
        )
    assert graph_module is not None
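    # The first params_len placeholders correspond to the lifted parameters/buffers.
    # Thanks to set_state_proxies they should have no users, so erase them and keep
    # only the user inputs as graph placeholders.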
    for i, node in enumerate(graph_module.graph.nodes):
        if i == len(params_flat):
            break
        assert node.op == "placeholder" and len(node.users) == 0
        graph_module.graph.erase_node(node)

    output_node = next(iter(reversed(graph_module.graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    assert num_fwd_returns is not None
    # Truncate the outputs so we only return what we need: the mutated inputs and
    # the user outputs, dropping the tensors AOT Autograd saves for the backward.
    output_node.args = (
        output_node.args[0][
            : len(mutated_input_indices) + len(fw_metadata.output_info)
        ],
    )
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()
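    # Map a mutated state tensor back to the attribute name(s) it is stored under
    # in the captured graph module.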
    def find_mutation_destinations(gm, w):
        assert isinstance(w, torch.Tensor)
        ret = [
            name for name, x in [*gm.named_parameters(), *gm.named_buffers()] if x is w
        ]
        assert len(ret) != 0, "Cannot find mutation destination."
        return ret
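    # For each mutated input (assumed to be a parameter/buffer), record which
    # forward output should be copy_'d back into which attribute: the k-th forward
    # output corresponds to the k-th mutated input (assumption 2 in the docstring).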
    mutation = [
        (
            "copy_",
            output_node.args[0][k].name,
            find_mutation_destinations(graph_module, param_list[i][1]),
        )
        for k, i in enumerate(mutated_input_indices)
    ]

    assert out_spec is not None
    return graph_module, mutation, out_spec
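

# torchdynamo is patched into an export-friendly configuration below: dynamic
# shapes, scalar-output capture, nn.Module guards, int specialization, RNN support,
# and verbose logging.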
@patch.object(torchdynamo.config, "dynamic_shapes", True)
@patch.object(torchdynamo.config, "capture_scalar_outputs", True)
@patch.object(torchdynamo.config, "guard_nn_modules", True)
@patch.object(torchdynamo.config, "specialize_int", True)
@patch.object(torchdynamo.config, "allow_rnn", True)
@patch.object(torchdynamo.config, "verbose", True)
def do_not_use_experimental_export(f: Callable, args: Tuple, training=False):
    """
    This prototype is under heavy development. Please don't use it unless you are
    part of the PyTorch 2.0 Export team.
    """
    if training:
        raise NotImplementedError("training mode is not supported yet")

    flattened_args, in_spec = pytree.tree_flatten(args)
    # Keep two tuples so that, even if graph_module accidentally modifies the
    # flattened inputs, we still hand the original example inputs to the
    # ExportedProgram.
    original_flat_args = tuple(flattened_args)
    flat_args = tuple(flattened_args)
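    # Stage 1: torchdynamo traces the callable into a torch-level FX graph
    # (aten_graph=False); stage 2 (_aot_capture) lowers it to core ATen.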
    graph_module, guards = torchdynamo.export(f, *args, aten_graph=False)
    # TODO (tmanlaibaatar) do something with the guards?
    graph_module, _, out_spec = _aot_capture(graph_module, flat_args)
    return ExportedProgram(
        fw_module=graph_module,
        example_inputs=original_flat_args,
        in_spec=in_spec,
        out_spec=out_spec,
    )
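

# A minimal usage sketch (hypothetical module and inputs; this prototype is not a
# stable API and may break at any time):
#
#   class M(torch.nn.Module):
#       def forward(self, x):
#           return torch.sin(x) + 1
#
#   ep = do_not_use_experimental_export(M(), (torch.randn(3),))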