[dynamo] Fix for #127696 (#128358)
Test Plan:
`buck2 test @//mode/dev-nosan //executorch/exir/backend/...`
https://www.internalfb.com/intern/testinfra/testrun/12666373989243932
Differential Revision: D58384518
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128358
Approved by: https://github.com/ydwu4
diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py
index 00932f9..59f8c26 100644
--- a/torch/_dynamo/variables/higher_order_ops.py
+++ b/torch/_dynamo/variables/higher_order_ops.py
@@ -12,7 +12,7 @@
import torch.fx
import torch.nn
import torch.onnx.operators
-from torch._dynamo.utils import deepcopy_to_fake_tensor, get_fake_value, get_real_value
+from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.base import VariableTracker
from torch._dynamo.variables.builtin import BuiltinVariable
@@ -1149,17 +1149,15 @@
p_args = tuple(arg.as_proxy() for arg in args[1:])
real_sub_args = pytree.tree_map_only(
- torch.fx.Proxy, lambda a: get_real_value(a.node, tx.output), p_args
+ torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args
)
- example_res = lowered_module.original_module.module()(*real_sub_args)
+ example_value = lowered_module.original_module.module()(*real_sub_args)
# NOTE [Guaranteeing the 1-1 correspondence of FakeTensors and real tensors]:
# executorch modules promise not to alias inputs and outputs.
# Thus, output FakeTensors will correctly not alias input FakeTensors.
- _assert_tensors_nonaliasing(real_sub_args, example_res)
-
- example_value = deepcopy_to_fake_tensor(example_res, tx.fake_mode)
+ _assert_tensors_nonaliasing(real_sub_args, example_value)
p_args = (lowered_node,) + p_args