| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import copy |
| import json |
| import traceback |
| from contextlib import contextmanager |
| from dataclasses import asdict, dataclass |
| from typing import ( |
| Any, |
| Callable, |
| Dict, |
| Generator, |
| Iterable, |
| List, |
| Optional, |
| Set, |
| Tuple, |
| Union, |
| ) |
| |
| import executorch.extension.pytree as ex_pytree |
| import torch |
| import torch._dynamo as torchdynamo |
| import torch.fx as fx |
| |
| import torch.fx._pytree as fx_pytree |
| import torch.utils._pytree as pytree |
| |
| from executorch.exir.common import ( |
| extract_out_arguments, |
| format_schema_name, |
| no_dispatch, |
| setting_python_recursive_limit, |
| ) |
| from executorch.exir.error import ExportError, ExportErrorType, InternalError |
| from executorch.exir.graph_module import LeafValue |
| from executorch.exir.operator.convert import is_out_variant |
| from executorch.exir.types import ValueSpec |
| |
| from torch._C import _EnableTorchFunction, DisableTorchFunctionSubclass # @manual |
| from torch._decomp import core_aten_decompositions, get_decompositions |
| from torch._dynamo.guards import Guard |
| from torch._functorch.eager_transforms import _maybe_unwrap_functional_tensor |
| from torch.func import functionalize |
| from torch.fx.operator_schemas import normalize_function |
| from torch.utils._pytree import TreeSpec |
| |
| from typing_extensions import TypeAlias |
| |
| |
| Value: TypeAlias = Union[ |
| LeafValue, |
| Tuple["Value", ...], |
| List["Value"], |
| Dict[str, "Value"], |
| ] |
| |
| torchdynamo_enabled = False |
| |
| |
| def get_stacktrace() -> List[Dict[str, str]]: |
| """ |
| Get the current stacktrace (between trace() and __torch_dispatch__()), |
| including the filename, function name, line number, and source code from |
| the start of the function to the given instruction. |
| |
| Returns: |
| A list of stacktraces for each instruction along with the source code |
| context surrounding each instruction |
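| |
| For illustration, each returned frame is a dict shaped like (values are |
| made up): |
| {"filename": "my_model.py", "lineno": "42", "name": "forward", |
| "line": "return x + y", "context": "..."} |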
| """ |
| |
| stacktrace = traceback.extract_stack() |
| |
| # The stacktrace typically looks like this: |
| # |
| # 1. I stack frames from the top level runner (e.g., the |
| # test suite runner) |
| # 2. J frames in executorch/exir/tracer.py setting up tracing |
| # (call this INIT_EXIR) |
| # 3. K frames in user model code (this is what we want to save!) |
| # 4. 1 frame in executorch/exir/tracer.py __torch_function__ |
| # returning to tracer (call this TRACE_EXIR) |
| # 5. H frames in executorch/exir/tracer.py AND torch/_tensor.py |
| # doing all of the internal tracer handling |
| # |
| # The PyE tests assert that executorch/exir/tracer.py never shows |
| # up in the user provided stack traces, so we must oblige them. |
| # |
| # Assumptions: |
| # - Reentrant tracing is not a thing. Thus, the first time |
| # executorch/exir/tracer.py shows up in the trace, we know |
| # THAT is the point at which we start tracing. (An alternative |
| # is that the tracer entry point could record the stack trace |
| # at this time, but I didn't do this.) |
| # |
| # Our plan is to run a miniature state machine over these |
| # stack frames. |
| |
| # Remove parts before the trace function and parts after entering |
| # __torch_dispatch__. Defaults to returning the entire stack trace. |
| init_exir_end = 0 |
| trace_exir_start = None |
| # A miniature state machine, referring to the frame segments described |
| # above. The locations form a closed-open interval. |
| FIND_INIT_EXIR_START, FIND_INIT_EXIR_END, FIND_TRACE_EXIR_START = range(3) |
| state = FIND_INIT_EXIR_START |
| for i, frame in enumerate(stacktrace): |
| if state == FIND_INIT_EXIR_START: |
| if "executorch/exir/tracer.py" in frame.filename: |
| state = FIND_INIT_EXIR_END |
| elif state == FIND_INIT_EXIR_END: |
| if "executorch/exir/tracer.py" not in frame.filename: |
| init_exir_end = i |
| state = FIND_TRACE_EXIR_START |
| elif state == FIND_TRACE_EXIR_START: |
| if "executorch/exir/tracer.py" in frame.filename: |
| trace_exir_start = i |
| break |
| |
| stacktrace = stacktrace[init_exir_end:trace_exir_start] |
| |
| # Get the source code context surrounding each frame's line |
| contexts: List[str] = [] |
| for s in stacktrace: |
| try: |
| with open(s.filename) as file: |
| # pyre-fixme[6]: For 1st param expected `Union[SupportsTrunc, bytes, |
| # str, SupportsInt, SupportsIndex]` but got `Optional[int]`. |
| lineno = int(s.lineno) |
| # Get the source code 5 lines above/below the current instruction |
| file_contents = [ |
| str(index + 1) + line for index, line in enumerate(file.readlines()) |
| ] |
| file_contents_above = "".join( |
| file_contents[max(lineno - 5, 0) : lineno] |
| ) |
| file_contents_below = "".join( |
| file_contents[lineno : min(lineno + 5, len(file_contents))] |
| ) |
| context = ( |
| file_contents_above |
| + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n" |
| + file_contents_below |
| ) |
| contexts.append(context) |
| except FileNotFoundError: |
| contexts.append("<unknown file: unknown line>") |
| |
| # torch.fx stack preservation logic expects strings to be passed |
| # around. Working with a dictionary makes it much easier to convert |
| # to a string and back. |
| frames: List[Dict[str, str]] = [] |
| for i, frame in enumerate(stacktrace): |
| frames.append( |
| { |
| "filename": str(frame.filename), |
| "lineno": str(frame.lineno), |
| "name": str(frame.name), |
| "line": str(frame.line), |
| "context": contexts[i], |
| } |
| ) |
| |
| return frames |
| |
| |
| def unwrap_functional(t: torch.Tensor) -> torch.Tensor: |
| assert isinstance(t, torch.Tensor) |
| return _maybe_unwrap_functional_tensor(t, reapply_views=False) |
| |
| |
| def unwrap_proxy(t: LeafValue) -> Union[LeafValue, torch.fx.Proxy]: |
| if not isinstance(t, torch.Tensor): |
| return t |
| t = unwrap_functional(t) |
| return t.proxy if isinstance(t, PythonTensor) else t |
| |
| |
| def single_return( |
| output: LeafValue, |
| proxy: torch.fx.Proxy, |
| wrapper: Callable[..., LeafValue], |
| ) -> LeafValue: |
| if isinstance(output, torch.Tensor): |
| return wrapper(output, proxy) |
| |
| return output |
| |
| |
| def tree_return( |
| outputs: Value, |
| proxy: torch.fx.Proxy, |
| wrapper: Callable[..., LeafValue], |
| meta_type: Callable[..., Iterable[ValueSpec]] = tuple, |
| ) -> Value: |
| i: int = 0 |
| |
| def wrap(o: LeafValue) -> LeafValue: |
| nonlocal i |
| ret = single_return(o, proxy[i], wrapper) |
| i += 1 |
| return ret |
| |
| return pytree.tree_map(wrap, outputs) |
| |
| |
| class DummyProxy: |
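| """ |
| A stand-in for torch.fx.Proxy used when no tracer is active, so that |
| __torch_dispatch__ can run the real computation without recording graph |
| nodes. Indexing returns another DummyProxy so tuple-style proxy access |
| (e.g. proxy[i] in tree_return) still works. |
| """ |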
| def __init__(self) -> None: |
| class DummyNode: |
| def __init__(self): |
| self.meta = {} |
| |
| self.node = DummyNode() |
| |
| def __getitem__(self, key: str) -> "DummyProxy": |
| return DummyProxy() |
| |
| |
| class PythonTensor(torch.Tensor): |
| """ |
| A wrapper tensor subclass used in the DispatchTracer to keep track of |
| proxies to construct the FX graph. |
| |
| Wrapping something in PythonTensor implicitly detaches gradients. If |
| something requires grad, we will collect it as if it were a leaf. A |
| consequence of detaching in this way is that you need to maintain a |
| parameter cache when translating tensors into PythonTensor, so you don't create |
| multiple copies of a gradient (they are aliased, but they would count as |
| independent leaves). An alternate strategy would be to avoid implicitly |
| detaching and instead "catch" gradients as they exit the PythonTensor |
| boundary. |
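| |
| A rough sketch of how the tracer wraps placeholder inputs (mirroring |
| DispatchTracer._trace below): |
| |
| proxy = tracer.create_proxy("placeholder", "ph_0", (), {}) |
| wrapped = PythonTensor(torch.randn(2, 2), proxy, is_immutable=True) |
| |
| Subsequent aten ops on `wrapped` are intercepted by __torch_dispatch__ |
| and recorded as nodes on the proxy's graph. |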
| """ |
| |
| __slots__ = ["proxy", "is_immutable"] |
| |
| @staticmethod |
| def __new__( |
| cls, elem: torch.Tensor, proxy: torch.fx.Proxy, is_immutable: bool = False |
| ) -> torch.Tensor: |
| # assert not elem.requires_grad or not torch.is_grad_enabled() |
| |
| r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) |
| assert isinstance(r, PythonTensor) |
| r.is_immutable: bool = is_immutable |
| r.update_proxy(proxy) |
| return r |
| |
| def update_proxy(self, proxy: torch.fx.Proxy) -> None: |
| self.proxy = proxy |
| |
| def __repr__(self, *, tensor_contents: None = None) -> str: |
| with no_dispatch(): |
| return f"PythonTensor({self.as_subclass(torch.Tensor)})" |
| |
| @classmethod |
| def __torch_function__( |
| cls, |
| # pyre-ignore: Missing parameter annotation [2] |
| func, |
| # pyre-ignore: Missing parameter annotation [2] |
| types, |
| args: Tuple[Value, ...] = (), |
| kwargs: Optional[Dict[str, Value]] = None, |
| ) -> Value: |
| if kwargs is None: |
| kwargs = {} |
| if torch.is_inference_mode_enabled(): |
| if func is torch.nn.functional.layer_norm: |
| args, kwargs = normalize_function(func, args, kwargs) |
| input, normalized_shape = args |
| normalized_shape = list(normalized_shape) |
| return cls.__torch_dispatch__( |
| torch.ops.aten.layer_norm.default, |
| types, |
| (input, normalized_shape), |
| kwargs, |
| ) |
| elif func is torch.nn.functional.linear: |
| return cls.__torch_dispatch__( |
| torch.ops.aten.linear.default, types, args, kwargs |
| ) |
| with DisableTorchFunctionSubclass(): |
| return func(*args, **kwargs) |
| |
| @classmethod |
| def __torch_dispatch__( # noqa: C901 |
| cls, |
| func_overload: torch._ops.OpOverload, |
| # pyre-ignore: Missing parameter annotation [2] |
| types, |
| args: Tuple[Value, ...] = (), |
| kwargs: Optional[Dict[str, Value]] = None, |
| ) -> Value: |
| """ |
| This function is invoked every time an aten operation is called. |
| |
| Args: |
| func_overload: The OpOverload that was called and triggered this |
| torch_dispatch call |
| types: The tensor subclass types involved in this dispatch call. |
| args: Arguments that were passed into the function. Tensor arguments |
| may be PythonTensor instances. |
| kwargs: Keyword arguments that were passed into the function. Tensor |
| arguments may be PythonTensor instances. |
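| |
| For example, calling torch.add(x, y) on PythonTensor inputs lands here |
| with func_overload == torch.ops.aten.add.Tensor; the proxies are |
| unwrapped to record an FX node, the real result is computed under |
| no_dispatch(), and the output is wrapped back into a PythonTensor. |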
| """ |
| func = func_overload.overloadpacket |
| |
| kwargs = kwargs or {} |
| if is_out_variant(func._qualified_op_name, func_overload._overloadname): |
| out_args = extract_out_arguments(func_overload._schema, kwargs) |
| out_args_iter = [out_args] if not isinstance(out_args, list) else out_args |
| for out_arg_name, out_arg_val in out_args_iter: |
| if isinstance(out_arg_val, PythonTensor) and out_arg_val.is_immutable: |
| raise RuntimeError( |
| "Immutable tensor `{}` is potentially getting modified by {}".format( |
| out_arg_name, format_schema_name(func_overload._schema) |
| ) |
| ) |
| |
| # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`. |
| proxy_args = ex_pytree.tree_map(unwrap_proxy, args) |
| # pyre-fixme[16]: Module `pytree` has no attribute `tree_map`. |
| proxy_kwargs = ex_pytree.tree_map(unwrap_proxy, kwargs) |
| |
| # Get the output of the function |
| g = _EnableTorchFunction() |
| try: |
| proxy_out = ( |
| func_overload(*proxy_args, **proxy_kwargs) |
| if DispatchTracer.get() or torchdynamo_enabled |
| # Disable node creation when no tracer is active. |
| else DummyProxy() |
| ) |
| finally: |
| del g |
| |
| with no_dispatch(): |
| real_out = func_overload(*args, **kwargs) |
| |
| # Kind of a hacky way to test if an op is in-place or not |
| if func.__name__[-1] == "_" and func.__name__[0] != "_": |
| if type(args[0]) == PythonTensor: |
| args[0].proxy = proxy_out |
| |
| if not torch.fx.traceback.has_preserved_node_meta(): |
| proxy_out.node.meta["stack_trace"] = json.dumps(get_stacktrace()) |
| |
| # Wrap the output tensors with the PythonTensor subclass to propagate to |
| # future tracing |
| def wrap_with_proxy(e: LeafValue, proxy: torch.fx.Proxy) -> LeafValue: |
| # Some ops (like native_batch_norm_backward) return undefined tensors that get |
| # converted into None in python. |
| # As the function signature expects tensors, if we directly return these None |
| # tensors back to C++, we'll error. |
| if e is None: |
| e = torch.empty(()) |
| |
| if type(e) == torch.Tensor: |
| return PythonTensor(e, proxy) |
| |
| # Inplace and out-variant ops may return one of their arguments, which is already |
| # a PythonTensor. In this case, we need to update the PythonTensor's associated |
| # proxy to the newly created proxy. |
| if type(e) == PythonTensor: |
| e.update_proxy(proxy) |
| return e |
| |
| return e |
| |
| retval = None |
| if not isinstance(real_out, (list, tuple)): |
| retval = single_return(real_out, proxy_out, wrap_with_proxy) |
| else: |
| retval = tree_return(real_out, proxy_out, wrap_with_proxy, type(real_out)) |
| return retval |
| |
| |
| @contextmanager |
| def using_tracer(tracer: Optional["DispatchTracer"]) -> Generator[None, None, None]: |
| """ |
| Set the "current" global tracer within the scope of using_tracer |
| context manager. |
| |
| Since various things we want to capture today with torch_dispatch |
| do not really "trap" into the dispatcher (for example, cond() and |
| shape()), we need a separate singleton tracer exposed to user space, |
| in addition to the Dispatcher, to trigger graph capturing. |
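| |
| Rough usage (this is how DispatchTracer.trace() uses it): |
| |
| with using_tracer(tracer): |
| return tracer._trace(root, concrete_args=args, in_spec=in_spec) |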
| """ |
| global TRACER |
| TRACER, prev = tracer, TRACER |
| try: |
| yield |
| finally: |
| TRACER = prev |
| |
| |
| class DispatchTracer(fx.Tracer): |
| def __init__(self) -> None: |
| super().__init__() |
| self.root: torch.nn.Module = torch.nn.Module() |
| self.tensor_attrs: Dict[torch.Tensor, str] = {} |
| self.submodules: Dict[fx.GraphModule, str] = {} |
| |
| def call_module( |
| self, |
| m: torch.nn.Module, |
| forward: Callable[..., Value], |
| args: Tuple[Value, ...], |
| kwargs: Dict[str, Value], |
| ) -> Value: |
| return forward(*args, **kwargs) |
| |
| def _module_getattr( |
| self, attr: str, attr_val: Value, parameter_proxy_cache: Dict[str, torch.Tensor] |
| ) -> Value: |
| if isinstance(attr_val, torch.nn.Parameter): |
| for n, p in self.root.named_parameters(): |
| if attr_val is p: |
| if n not in parameter_proxy_cache: |
| proxy = self.create_proxy("get_attr", n, (), {}) |
| parameter_proxy_cache[n] = PythonTensor(attr_val, proxy) |
| return parameter_proxy_cache[n] |
| return attr_val |
| return attr_val |
| |
| def create_arg(self, a: Value) -> torch.fx.Node: # noqa: C901 |
| if isinstance(a, torch.nn.Parameter): |
| for n, p in self.root.named_parameters(): |
| if a is p: |
| return self.create_node("get_attr", n, (), {}) |
| qualname: Optional[str] = None |
| |
| if not qualname: |
| i = 0 |
| while True: |
| qualname = f"_param_constant{i}" |
| if not hasattr(self.root, qualname): |
| break |
| i += 1 |
| setattr(self.root, qualname, a) |
| |
| return self.create_node("get_attr", qualname, (), {}) |
| |
| if isinstance(a, torch.Tensor): |
| qualname: Optional[str] = self.tensor_attrs.get(a) |
| |
| if not qualname: |
| i = 0 |
| while True: |
| qualname = f"_tensor_constant{i}" |
| if not hasattr(self.root, qualname): |
| break |
| i += 1 |
| self.tensor_attrs[a] = qualname |
| self.root.register_buffer(qualname, a) |
| |
| return self.create_node("get_attr", qualname, (), {}) |
| |
| # higher-order operator |
| if isinstance(a, fx.GraphModule): |
| if a not in self.submodules: |
| name_submodule = f"submodule_{len(self.submodules)}" |
| self.root.add_module(name_submodule, a) |
| self.submodules[a] = name_submodule |
| return self.create_node("get_attr", self.submodules[a], (), {}) |
| |
| return super().create_arg(a) |
| |
| @staticmethod |
| def get() -> "DispatchTracer": |
| return TRACER |
| |
| def trace( |
| self, |
| root: Callable[..., Value], |
| concrete_args: Tuple[Value, ...] = (), |
| in_spec: Optional[TreeSpec] = None, |
| ) -> Value: |
| """ |
| Traces the given callable, recording dispatched aten ops into an FX graph. |
| """ |
| with using_tracer(self): |
| return self._trace(root, concrete_args=concrete_args, in_spec=in_spec) |
| |
| def _trace( |
| self, |
| root: Callable[..., Value], |
| concrete_args: Tuple[Value, ...], |
| in_spec: Optional[TreeSpec], |
| ) -> Value: |
| self.root = torch.nn.Module() |
| root_fn = root |
| |
| tracer_cls = getattr(self, "__class__", None) |
| self.graph = fx.Graph(tracer_cls=tracer_cls) |
| # Modules are not supported as trace roots, so tensor_attrs starts empty |
| self.tensor_attrs = {} |
| |
| # Wrap all inputs as a PythonTensor subclass and insert them into the FX |
| # graph as placeholder nodes |
| def wrap(arg: Value, i: int) -> Value: |
| placeholder = self.create_proxy("placeholder", f"ph_{i}", (), {}) |
| if isinstance(arg, torch.Tensor): |
| return PythonTensor(arg, placeholder, is_immutable=True) |
| else: |
| # torch._assert( |
| # placeholder == arg, |
| # f"ph_{i} has been specialized to have value {arg}", |
| # ) |
| return arg |
| |
| tree_args = [wrap(arg, i) for i, arg in enumerate(concrete_args)] |
| if in_spec: |
| tree_args = pytree.tree_unflatten(tree_args, in_spec) |
| |
| tree_out = root_fn(*tree_args) |
| |
| out_args, _ = pytree.tree_flatten(tree_out) |
| |
| def unwrap(out: LeafValue) -> Union[LeafValue, torch.fx.Proxy]: |
| # It is legitimate for a model to return a list of items, some of |
| # which are None |
| if out is None: |
| return None |
| if not isinstance(out, torch.Tensor): |
| raise TypeError( |
| f"Expect model to return torch.Tensor, got type: '{type(out)}' (value: {out})." |
| ) |
| return unwrap_proxy(out) |
| |
| returns = [unwrap(out) for out in out_args] |
| |
| return_annotation = None |
| # Some ops, like torch.sub, don't have annotations |
| if hasattr(root_fn, "__annotations__"): |
| return_annotation = root_fn.__annotations__.get("return", None) |
| |
| self.create_proxy( |
| "output", |
| "output", |
| (returns,), |
| {}, |
| type_expr=return_annotation, |
| ) |
| |
| self.submodule_paths = None |
| |
| return tree_out |
| |
| |
| TRACER: Optional[DispatchTracer] = None |
| TORCHDYNAMO_ENABLED: bool = False |
| |
| |
| @contextmanager |
| def using_dynamo(val: bool) -> Generator[None, None, None]: |
| global TORCHDYNAMO_ENABLED |
| TORCHDYNAMO_ENABLED, prev = val, TORCHDYNAMO_ENABLED |
| try: |
| yield |
| finally: |
| TORCHDYNAMO_ENABLED = prev |
| |
| |
| def flattened_dispatch_trace( |
| f: Callable[..., Value], |
| args: Tuple[LeafValue, ...], |
| guards: Set[Guard], |
| in_spec: Optional[TreeSpec] = None, |
| enable_functionalization: bool = True, |
| ) -> Tuple[torch.fx.GraphModule, Value]: |
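| """ |
| Traces `f` with DispatchTracer on an already-flattened tuple of args, |
| optionally functionalizing it first, and returns the resulting |
| GraphModule together with the (possibly nested) outputs of the traced call. |
| """ |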
| if not isinstance(args, tuple): |
| raise TypeError(f"Expecting 'args' to be a tuple, got: {type(args)}") |
| |
| tracer = DispatchTracer() |
| |
| if enable_functionalization: |
| f = functionalize(f, remove="mutations_and_views") |
| tree_out = tracer.trace(f, concrete_args=args, in_spec=in_spec) |
| |
| name = type(f).__name__ if isinstance(f, torch.nn.Module) else f.__name__ |
| gm = torch.fx.GraphModule(tracer.root, tracer.graph, name) |
| |
| return (gm, tree_out) |
| |
| |
| @dataclass |
| class ExirDynamoConfig: |
| """ |
| Manage Exir-specific configurations of Dynamo. |
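| |
| Usage sketch: the config is converted with dataclasses.asdict() and |
| applied via torchdynamo.config.patch() inside dynamo_trace(), e.g. |
| dynamo_trace(f, args, aten_graph=True, |
| dynamo_config=ExirDynamoConfig(assume_static_by_default=True)). |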
| """ |
| |
| allow_rnn: bool = True |
| verbose: bool = True |
| assume_static_by_default: bool = False |
| |
| |
| def flatten_output(gm: torch.fx.GraphModule) -> None: |
| """ |
| Modifies the output node of the given graph module so that it returns the |
| result as a flattened list. This keeps it consistent with the output of |
| EXIR's tracer. |
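| |
| For example, an output node whose args are ((a, b), c) is rewritten so |
| that the graph returns [a, b, c] instead. |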
| """ |
| for node in reversed(gm.graph.nodes): |
| if node.op == "output": |
| assert len(node.args) == 1 |
| outputs = node.args[0] |
| returns, _ = pytree.tree_flatten(outputs) |
| node.args = (returns,) |
| return |
| raise RuntimeError(f"Could not find an output node in {gm.graph}") |
| |
| |
| def _default_decomposition_table( |
| _use_old_decomp_table=False, |
| ) -> Dict[torch._ops.OpOverload, Callable[..., Value]]: |
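| """ |
| Returns the decomposition table used when exporting an ATen graph: either |
| a small legacy opset or the core ATen decompositions. |
| """ |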
| if _use_old_decomp_table: |
| decomp_opset = [ |
| torch.ops.aten.log_sigmoid_forward, |
| torch.ops.aten.ones, |
| torch.ops.aten.arange.default, |
| torch.ops.aten.arange.start, |
| torch.ops.aten.transpose, |
| ] |
| # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.e... |
| return get_decompositions(decomp_opset) |
| # pyre-fixme[7]: Expected `Dict[OpOverload, typing.Callable[..., executorch.exir.... |
| return core_aten_decompositions() |
| |
| |
| def dynamo_trace( |
| f: Callable[..., Value], |
| # pyre-ignore |
| args: Tuple[Any, ...], |
| aten_graph: bool, |
| tracing_mode: str = "real", |
| dynamo_config: Optional[ExirDynamoConfig] = None, |
| # pyre-ignore |
| dynamic_shapes: Optional[List[Any]] = None, |
| _use_old_decomp_table: bool = False, |
| ) -> Tuple[torch.fx.GraphModule, Set[Guard]]: |
| """ |
| TODO: Once we fully migrate to the torchdynamo frontend, we will remove |
| this config option altogether. For now, it helps with quick |
| experiments when playing around with TorchDynamo. |
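| |
| Rough usage sketch (`fn` and the example input are made up): |
| |
| gm, guards = dynamo_trace(fn, (torch.randn(2, 2),), aten_graph=True) |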
| """ |
| if dynamo_config is None: |
| dynamo_config = ExirDynamoConfig() |
| |
| with torchdynamo.config.patch( |
| asdict(dynamo_config) |
| ), setting_python_recursive_limit(2000): |
| torchdynamo.reset() |
| try: |
| # TODO merge executorch functionalization with official |
| # functionalization |
| # pyre-fixme[7]: Expected `Tuple[GraphModule, Set[Guard]]` but got |
| # `ExportResult`. |
| return torchdynamo.export( |
| f, |
| aten_graph=aten_graph, |
| tracing_mode=tracing_mode, |
| assume_static_by_default=dynamo_config.assume_static_by_default, |
| decomposition_table=( |
| _default_decomposition_table(_use_old_decomp_table) |
| if aten_graph |
| else None |
| ), |
| dynamic_shapes=dynamic_shapes, |
| )( |
| *copy.deepcopy(args), |
| ) |
| except torchdynamo.exc.Unsupported as exc: |
| raise ExportError( |
| ExportErrorType.NOT_SUPPORTED, |
| "The user code is using a feature we don't support. " |
| "Please try torchdynamo.explain() to get possible the reasons", |
| ) from exc |
| except Exception as exc: |
| raise InternalError( |
| "torchdynamo internal error occured. Please see above stacktrace" |
| ) from exc |
| |
| |
| def dispatch_trace( |
| f: Callable[..., Value], |
| args: Tuple[Value, ...], |
| ) -> torch.fx.GraphModule: |
| """ |
| Executes a given callable `f` with a given tuple of arguments. During |
| execution, Tensor operations are recorded in an fx.GraphModule, which is then |
| returned. |
| |
| Args: |
| f: A `nn.Module` or a Python function that implements an ML program. |
| args: A tuple of arguments of any type to be used as inputs for the tracing run. |
| |
| Returns: |
| EXIR contained in an fx.GraphModule |
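| |
| Rough usage sketch (`MyModule` is a hypothetical user module): |
| |
| gm = dispatch_trace(MyModule(), (torch.randn(2, 2),)) |
| print(gm.graph) |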
| """ |
| trace_func = f |
| guards = set() |
| if TORCHDYNAMO_ENABLED: |
| # Copying args is safer in case downstream implementations of trace_func mutate them |
| trace_func, guards = dynamo_trace(trace_func, args, False) |
| |
| # Copying args is safer in case downstream implementations of trace_func mutate them |
| trace_args, in_spec = pytree.tree_flatten(args) |
| |
| in_args = copy.deepcopy(tuple(trace_args)) |
| gm, tree_out = flattened_dispatch_trace( |
| trace_func, |
| in_args, |
| guards, |
| in_spec, |
| enable_functionalization=False, |
| ) |
| |
| _, out_spec = pytree.tree_flatten(tree_out) |
| |
| gm.in_spec = in_spec |
| gm.out_spec = out_spec |
| |
| # TODO (tmanlaibaatar) This is a bit clowny, but our |
| # dispatch_trace sometimes creates unused nodes that |
| # break functionalization. It seems like too much trouble |
| # to fix properly since dispatch_trace will be deprecated soon. |
| # Basically dispatch_trace struggles on: |
| # def f(x: torch.Tensor) -> torch.Tensor: |
| # return torch.ones(6, dtype=x.dtype) |
| changed = gm.graph.eliminate_dead_code() |
| if changed: |
| gm.recompile() |
| |
| in_args = copy.deepcopy(tuple(trace_args)) |
| assert callable(gm) |
| |
| # This wrapper is used for preserving the stacktrace |
| # during the second round of tracing. |
| # pyre-ignore |
| def graph_with_interpreter(*args): |
| try: |
| args = fx_pytree.tree_flatten_spec(args, gm.in_spec) # type: ignore[assignment] |
| except Exception: |
| _, received_spec = pytree.tree_flatten(args) |
| raise RuntimeError( |
| "Trying to flatten user inputs with exported input tree spec: \n" |
| f"{gm.in_spec}\n" |
| "but actually got inputs with tree spec of: \n" |
| f"{received_spec}" |
| ) |
| with torch.fx.traceback.preserve_node_meta(): |
| res = gm(*args) |
| |
| if gm.out_spec is not None: |
| try: |
| res = pytree.tree_unflatten(res, gm.out_spec) |
| except Exception: |
| _, received_spec = pytree.tree_flatten(res) |
| raise RuntimeError( |
| "Trying to flatten user outputs with exported output tree spec: \n" |
| f"{gm.out_spec}\n" |
| "but actually got outputs with tree spec of: \n" |
| f"{received_spec}" |
| ) |
| return res |
| |
| gm, tree_out = flattened_dispatch_trace( |
| graph_with_interpreter, |
| in_args, |
| guards, |
| in_spec, |
| enable_functionalization=True, |
| ) |
| |
| gm.in_spec = in_spec |
| gm.out_spec = out_spec |
| |
| return gm |