commit | d3a11a01980fdc2fd811a95330d342f280441d27 | [log] [tgz] |
---|---|---|
author | xinan.lin <xinan.lin@intel.com> | Tue Jul 16 10:50:10 2024 -0700 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Wed Jul 17 10:13:36 2024 +0000 |
tree | 1a39db62873a0837a7849b7b9916f68e390cd21b | |
parent | 2af2d26562d0103571146374413a021d43ff5489 [diff] |
[Inductor] Handle device_put op in constant folding. (#130824) Fix #130823 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130824 Approved by: https://github.com/eellison, https://github.com/EikanWang ghstack dependencies: #130817
diff --git a/torch/_inductor/fx_passes/joint_graph.py b/torch/_inductor/fx_passes/joint_graph.py index 680651e..791f366 100644 --- a/torch/_inductor/fx_passes/joint_graph.py +++ b/torch/_inductor/fx_passes/joint_graph.py
@@ -266,6 +266,10 @@ if isinstance(out, torch.Tensor) and out.numel() == 1: return out + # handle device_put op + if node.target == prims.device_put.default: + return super(ConstantFolder, self).run_node(node) + # constructors ops if ( node.op == "call_function"