| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
| import torch |
| from executorch.exir.dialects._ops import ops as exir_ops |
| from executorch.exir.pass_base import ExportPass, PassResult |
| |
| |
| def merge_view_copy_chains(graph: torch.fx.Graph) -> torch.fx.Graph: |
| """ |
| Find chains of view_copy nodes and merge them into one view_copy node. |
| Only merges view_copy nodes that are not used by any other nodes. |
| """ |
| ops = exir_ops.edge |
| view_op = ops.aten.view_copy.default |
| for node in graph.nodes: |
| if node.op == "call_function" and node.target == view_op: |
| # find ending view_copy node in chain |
| end_node = node |
| while ( |
| end_node.op == "call_function" |
| and end_node.target == view_op |
| and len(end_node.users) == 1 |
| and list(end_node.users)[0].target == view_op |
| ): |
| end_node = list(end_node.users)[0] |
| # we can swap the first node's shape arg with the last node's shape arg |
| if node != end_node: |
| with graph.inserting_after(node): |
| new_args = (node.args[0], end_node.args[1]) |
| node.args = new_args |
| end_node.replace_all_uses_with(node) |
| |
| graph.eliminate_dead_code() |
| return graph |
| |
| |
| class FuseViewCopyTransform(ExportPass): |
| def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
| graph_module.graph = merge_view_copy_chains(graph_module.graph) |
| return PassResult(graph_module, True) |