| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| """ |
| Helper functions for constructing a "leaf function" in FX graph. A "leaf |
| function" will be preserved as a call node in the graph instead of |
| being traced through. |
| """ |
| |
| import torch |
| from executorch.exir.tracer import PythonTensor, unwrap_functional |
| |
| # pyre-fixme[21]: Could not find module `torch._C._functorch`. |
| from torch._C._functorch import ( # @manual=//caffe2/functorch:functorch" |
| is_functionaltensor, |
| ) |
| |
| from torch._functorch.eager_transforms import _assert_wrapped_functional # pyre-ignore |
| |
| |
def update_with_proxy(t: torch.Tensor, proxy: torch.fx.Proxy) -> torch.Tensor:
    """Point the tracer tensor underlying ``t`` at ``proxy`` and return ``t``.

    The value obtained from ``unwrap_functional(t)`` must be a
    ``PythonTensor``; its ``update_proxy`` method is called with ``proxy``.
    When ``is_functionaltensor(t)`` is true, ``_assert_wrapped_functional``
    is additionally invoked on the unwrapped tensor and ``t``.

    Args:
        t: Tensor whose underlying tracer tensor should receive ``proxy``.
        proxy: FX proxy handed to ``PythonTensor.update_proxy``.

    Returns:
        The same ``t`` object that was passed in.

    Raises:
        AssertionError: If the unwrapped value is not a ``PythonTensor``.
    """
    inner = unwrap_functional(t)
    # Invariant check; also narrows the type for the checker before the
    # ``update_proxy`` call below.
    assert isinstance(inner, PythonTensor)
    inner.update_proxy(proxy)

    if not is_functionaltensor(t):  # type: ignore
        return t

    # NOTE(review): presumably verifies that ``t`` still wraps ``inner`` --
    # confirm against torch._functorch.eager_transforms.
    _assert_wrapped_functional(inner, t)
    return t