| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-strict |
| |
from typing import Any, Dict, List, Optional, Tuple

import torch
from executorch.exir.pass_base import ExportPass, PassResult
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx
| |
| |
class DecomposeScaledDotProductAttention(ExportPass):
    """
    Decompose ``aten.scaled_dot_product_attention`` into multiple nodes.

    Each SDPA node is traced with ``make_fx`` in fake mode, using the
    decomposition registered for
    ``aten._scaled_dot_product_flash_attention_for_cpu``. The resulting ops are
    spliced into the graph in place of the original node. Arguments that are
    not graph nodes (e.g. ``dropout_p``, ``is_causal`` or the keyword-only
    ``scale``) are baked into the decomposition as constants.

    Args:
        allow_non_fake_inputs: Forwarded to ``make_fx`` as
            ``_allow_non_fake_inputs`` unless ``call`` is given an explicit
            override.
    """

    def __init__(self, allow_non_fake_inputs: bool = True) -> None:
        super().__init__()
        # With allow_non_fake_inputs=False we don't get _unsafe_view ops in
        # the graph, so callers may disable it here.
        self._allow_non_fake_inputs = allow_non_fake_inputs

    def call(
        self,
        graph_module: torch.fx.GraphModule,
        allow_non_fake_inputs: Optional[bool] = None,
    ) -> PassResult:
        """
        Decompose every SDPA node of ``graph_module`` in place.

        Args:
            graph_module: Module whose graph is rewritten and recompiled.
            allow_non_fake_inputs: Per-call override for ``make_fx``'s
                ``_allow_non_fake_inputs``; ``None`` (the default) uses the
                value given to ``__init__``.

        Returns:
            ``PassResult`` wrapping ``graph_module``. ``modified`` is True if
            any node was decomposed or dead code was eliminated.
        """
        if allow_non_fake_inputs is None:
            # Fall back to the flag configured at construction time.
            allow_non_fake_inputs = self._allow_non_fake_inputs

        graph = graph_module.graph
        modified = False
        # Iterate over a snapshot: decomposition inserts and erases nodes.
        for node in list(graph.nodes):
            if node.target != torch.ops.aten.scaled_dot_product_attention.default:
                continue
            decomposed_module, flat_inputs = self._trace_decomposition(
                node, allow_non_fake_inputs
            )
            self._splice_decomposition(graph, node, decomposed_module, flat_inputs)
            modified = True

        # Always run DCE, then fold in whether it changed anything.
        modified = graph.eliminate_dead_code() or modified
        graph_module.recompile()
        return PassResult(graph_module, modified)

    @staticmethod
    def _trace_decomposition(
        node: torch.fx.Node, allow_non_fake_inputs: bool
    ) -> Tuple[torch.fx.GraphModule, List[Any]]:
        """
        Trace the SDPA call described by ``node`` into a decomposed module.

        Returns:
            The traced module and the flat list of original arguments
            (positional args first, then keyword values in ``node.kwargs``
            order) that its ``arg{i}_1`` placeholders stand for.
        """
        num_positional = len(node.args)
        kwarg_names = list(node.kwargs)
        flat_inputs: List[Any] = [
            *node.args,
            *(node.kwargs[name] for name in kwarg_names),
        ]

        def call_sdpa(*values: Any) -> Any:
            # Keyword arguments (e.g. ``scale``) must reach the op; otherwise
            # the decomposition silently uses their defaults. The varargs
            # signature is deliberate: make_fx gives varargs callables (like
            # the bare op overload) a synthetic ``arg0, arg1, ...`` signature,
            # which yields the ``arg{i}_1`` placeholder names relied on in
            # ``_splice_decomposition``.
            return torch.ops.aten.scaled_dot_product_attention.default(
                *values[:num_positional],
                **dict(zip(kwarg_names, values[num_positional:])),
            )

        # Graph nodes are traced through their fake ``meta["val"]``; anything
        # else (None, floats, bools) is passed through as a constant.
        example_inputs = [
            value.meta["val"] if isinstance(value, torch.fx.Node) else value
            for value in flat_inputs
        ]

        # refer to pytorch/test/test_decomp.py
        decomposed_module = make_fx(
            call_sdpa,
            decomposition_table=get_decompositions(  # pyre-fixme[6]
                [
                    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
                ]
            ),
            tracing_mode="fake",
            _allow_non_fake_inputs=allow_non_fake_inputs,
        )(*example_inputs)
        return decomposed_module, flat_inputs

    @staticmethod
    def _splice_decomposition(
        graph: torch.fx.Graph,
        node: torch.fx.Node,
        decomposed_module: torch.fx.GraphModule,
        flat_inputs: List[Any],
    ) -> None:
        """
        Copy the ops of ``decomposed_module`` in front of ``node``, reroute
        ``node``'s users to the decomposed result, and erase ``node``.

        Raises:
            KeyError: If a placeholder is not named ``arg{i}_1`` for a known i.
            RuntimeError: If the decomposed graph does not return a single node.
        """
        input_by_placeholder_name = {
            f"arg{i}_1": value for i, value in enumerate(flat_inputs)
        }
        # Maps nodes of the decomposed graph to their counterparts in ``graph``.
        env: Dict[torch.fx.Node, Any] = {}
        replacement: Optional[torch.fx.Node] = None

        with graph.inserting_before(node):
            # fx graphs are topologically ordered, so a single pass suffices:
            # placeholders come first and the output node comes last.
            for decomposed_node in decomposed_module.graph.nodes:
                if decomposed_node.op == "placeholder":
                    env[decomposed_node] = input_by_placeholder_name[
                        decomposed_node.name
                    ]
                elif decomposed_node.op == "output":
                    result = decomposed_node.args[0]
                    if not isinstance(result, torch.fx.Node):
                        raise RuntimeError(
                            "Expected the decomposed scaled_dot_product_attention "
                            f"graph to return a single node, got {result!r}"
                        )
                    replacement = env[result]
                else:
                    copied = graph.node_copy(
                        decomposed_node, arg_transform=env.__getitem__
                    )
                    copied.meta["source_fn_stack"] = [(copied, copied.target)]
                    env[decomposed_node] = copied

        if replacement is None:
            raise RuntimeError(
                "Decomposed scaled_dot_product_attention graph has no output"
            )
        node.replace_all_uses_with(replacement)
        graph.erase_node(node)