blob: bbc155dc4516860c5f666c04619824c095e463ba [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, Optional, Set

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
def merge_view_copy_chains(
    graph: torch.fx.Graph,
    *,
    view_op: Optional[Callable[..., object]] = None,
) -> torch.fx.Graph:
    """
    Find chains of view_copy nodes and merge them into one view_copy node.

    Only merges view_copy nodes that are not used by any other nodes: for a
    chain ``x -> view(s1) -> view(s2) -> ... -> view(sN)`` in which every view
    except the last has exactly one user (the next view), the head node is
    rewritten to compute ``view(x, sN)`` directly and takes over the users of
    the last view. The remaining, now-unused views are removed by dead-code
    elimination at the end.

    Args:
        graph: FX graph to rewrite in place.
        view_op: Operator to treat as the view op. Defaults to the edge-dialect
            ``aten.view_copy.default``. Any op whose node args are
            ``(input, shape)`` may be supplied instead, e.g.
            ``torch.ops.aten.view_copy.default`` for an ATen-dialect graph.

    Returns:
        The same ``graph`` object, modified in place.
    """
    if view_op is None:
        view_op = exir_ops.edge.aten.view_copy.default

    def _is_view(candidate: torch.fx.Node) -> bool:
        return candidate.op == "call_function" and candidate.target == view_op

    # Nodes already folded into an earlier chain head. They end up without
    # live users and are erased below, so there is no point re-walking them.
    absorbed: Set[torch.fx.Node] = set()

    for node in graph.nodes:
        if node in absorbed or not _is_view(node):
            continue

        # Follow the chain while the current view's only consumer is a view.
        end_node = node
        while len(end_node.users) == 1:
            (user,) = end_node.users
            if not _is_view(user):
                break
            end_node = user
            absorbed.add(end_node)

        if end_node is node:
            continue

        # Reshaping through intermediate shapes does not reorder elements, so
        # the head can go straight from its input to the tail's shape.
        node.args = (node.args[0], end_node.args[1])
        # The head now produces what the tail produced; carry over the
        # shape-dependent metadata so consumers do not see the head's old
        # intermediate shape.
        for key in ("val", "tensor_meta"):
            if key in end_node.meta:
                node.meta[key] = end_node.meta[key]
        end_node.replace_all_uses_with(node)

    graph.eliminate_dead_code()
    return graph
class FuseViewCopyTransform(ExportPass):
    """
    Pass that collapses chains of single-use ``view_copy`` nodes into one node.

    The rewrite is performed by ``merge_view_copy_chains``. This wrapper
    installs the rewritten graph on the module and reports whether anything
    actually changed.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        nodes_before = len(graph_module.graph.nodes)
        graph_module.graph = merge_view_copy_chains(graph_module.graph)
        # Every merge leaves the chain's last view without users, and
        # merge_view_copy_chains finishes with dead-code elimination. The graph
        # was therefore rewritten exactly when its node count dropped.
        modified = len(graph_module.graph.nodes) != nodes_before
        return PassResult(graph_module, modified)