| # Copyright 2023-2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| import logging |
| import os |
| from typing import Callable, final, List, Optional, Tuple |
| |
| import torch |
| from executorch.backends.arm.arm_backend import ArmBackend # usort: skip |
| from executorch.backends.arm._passes.tag_io_quant_pass import TagIOQuantPass |
| from executorch.backends.arm.operator_support.tosa_supported_operators import ( |
| TOSASupportedOperators, |
| ) |
| from executorch.backends.arm.tosa_specification import TosaSpecification |
| from executorch.exir.backend.compile_spec_schema import CompileSpec |
| from executorch.exir.backend.partitioner import ( |
| DelegationSpec, |
| Partitioner, |
| PartitionResult, |
| ) |
| from executorch.exir.backend.utils import tag_constant_data |
| from executorch.exir.passes import PassManager |
| from torch.export.exported_program import ExportedProgram |
| from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner |
| |
# Module-level logging setup. The logger reports at WARNING unless the
# environment variable TOSA_DBG_VERBOSE is exactly "1", in which case
# logging.basicConfig is called and both it and this logger use INFO.
TOSA_DBG_VERBOSE = os.environ.get("TOSA_DBG_VERBOSE") == "1"
logger = logging.getLogger(__name__)
if TOSA_DBG_VERBOSE:
    logging.basicConfig(level=logging.INFO)
    logger.setLevel(logging.INFO)
else:
    logger.setLevel(logging.WARNING)
| |
| |
@final
class ArmPartitioner(Partitioner):
    """Partitioner that tags nodes for delegation to ``ArmBackend``.

    Node support is decided by ``TOSASupportedOperators`` for the TOSA
    specification derived from the compile specs given at construction.
    """

    def __init__(self, compile_spec: List[CompileSpec]) -> None:
        """Build the delegation spec from the backend name and ``compile_spec``.

        Args:
            compile_spec: Compile specs stored in the ``DelegationSpec`` and
                later read back by ``partition``.
        """
        backend_id = ArmBackend.__name__
        self.delegation_spec = DelegationSpec(backend_id, compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """Tag the nodes of ``exported_program`` that are proposed for delegation.

        Args:
            exported_program: Program whose graph module is partitioned; its
                nodes are tagged in place through ``node.meta``.

        Returns:
            A ``PartitionResult`` holding the same ``exported_program`` and a
            mapping from each assigned tag to this partitioner's delegation
            spec.
        """
        logger.info("ArmPartitioner::partition")

        compile_specs = self.delegation_spec.compile_specs
        tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
        logger.info(f"Partitioning for {tosa_spec}")

        for entry in compile_specs:
            if entry.key != "quantize_io":
                continue
            if entry.value.decode() != "True":
                continue
            # Exclude IO quantization from the partition.
            io_quant_passes = PassManager(passes=[TagIOQuantPass()])
            io_quant_passes(exported_program.graph_module)

        # Run the CapabilityBasedPartitioner to get the largest possible
        # subgraphs made of supported nodes.
        supported_ops = TOSASupportedOperators(tosa_spec)
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            supported_ops,
            allows_single_node_partition=True,
        )

        tag_to_spec = {}
        for candidate in capability_partitioner.propose_partitions():
            tag_name = f"tag{candidate.id}"
            for fx_node in candidate.nodes:
                fx_node.meta["delegation_tag"] = tag_name
                # Registered per node so a tag only appears once at least
                # one node actually carries it.
                tag_to_spec[tag_name] = self.delegation_spec

        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program,
            partition_tags=tag_to_spec,
        )

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """Return the operator overloads listed as not to be decomposed.

        Args:
            ep: Exported program; not inspected by this implementation.

        Returns:
            A pair of the fixed operator list (``aten.linear.default`` and
            ``aten.upsample_nearest2d.vec``) and ``None`` in place of the
            optional per-node callable.
        """
        preserved_ops = [
            torch.ops.aten.linear.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        node_filter = None
        return (preserved_ops, node_filter)