| # Copyright 2023-2024 Arm Limited and/or its affiliates. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| # pyre-unsafe |
| |
| from typing import List |
| |
| import serializer.tosa_serializer as ts |
| import torch |
| from executorch.backends.arm.operators.node_visitor import ( |
| NodeVisitor, |
| register_node_visitor, |
| ) |
| from executorch.backends.arm.tosa_mapping import TosaArg |
| from serializer.tosa_serializer import TosaOp |
| |
| |
def permutation_vector_to_matrix(permutation_vector: list[int]) -> torch.Tensor:
    """
    Build the NxN permutation matrix for a permutation vector of length N.

    Row ``i`` of the result holds a single 1, in column ``permutation_vector[i]``;
    every other entry is 0. The matrix uses torch's default floating dtype.
    For example:
    (1,0,2)
    ->
    [0 1 0]
    |1 0 0|
    [0 0 1]
    """
    size = len(permutation_vector)
    # Row k of the identity matrix has its 1 in column k, so picking identity
    # rows in the order given by the vector places row i's 1 in column
    # permutation_vector[i]. An explicit long dtype keeps an empty vector valid
    # as an index (it yields a 0x0 result).
    column_of_one = torch.tensor(permutation_vector, dtype=torch.long)
    return torch.eye(size)[column_of_one]
| |
| |
def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]:
    """
    Converts a NxN permutation matrix to a permutation vector of length N that describes the same permutation.
    [0 1 0]
    |1 0 0|
    [0 0 1]
    ->
    (1,0,2)

    Entry ``i`` of the returned list is the column holding the 1 in row ``i``.

    Raises:
        AssertionError: if the matrix is not square, contains a value other
            than 0 or 1, or does not have exactly one 1 in every row and
            every column (i.e. is not a permutation matrix).
    """
    N = len(permutation_matrix)
    # A 0x0 matrix is the (empty) permutation of length 0. Return early so the
    # square check below never has to look at a non-existent row.
    if N == 0:
        return []
    assert all(
        len(row) == N for row in permutation_matrix
    ), f"A permutation matrix must be square, got shape {permutation_matrix.shape}"

    p = [0] * N
    # Columns already claimed by an earlier row; a permutation uses each once.
    used_columns: set[int] = set()
    for row_index, row in enumerate(permutation_matrix):
        saw_one = False
        for col_index, value in enumerate(row):
            if value == 1:
                assert (
                    not saw_one
                ), f"A permutation matrix can only have one 1 per row, got row {row}."
                assert (
                    col_index not in used_columns
                ), f"A permutation matrix can only have one 1 per column, got a repeated 1 in column {col_index}."
                p[row_index] = col_index
                used_columns.add(col_index)
                saw_one = True
            else:
                assert (
                    value == 0
                ), f"A permutation matrix only contains 1's and 0's, got value {value}."
        # Without this check an all-zero row would silently map to column 0.
        assert saw_one, f"A permutation matrix must have a 1 in every row, got row {row}."
    return p
| |
| |
@register_node_visitor
class PermuteVisitor(NodeVisitor):
    """Lower ``aten.permute_copy.default`` to a TOSA ``TRANSPOSE`` operator.

    The permutation given to ``permute_copy`` is expressed in PyTorch's
    default dim_order. When the output tensor uses a different dim_order,
    the permutation is rewritten so that it applies to that dim_order before
    it is handed to TOSA.
    """

    target = "aten.permute_copy.default"

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _permutation_for_dim_order(
        permutation_vector: list[int], dim_order: tuple[int, ...]
    ) -> list[int]:
        """Rewrite a default-dim_order permutation so it applies in ``dim_order``.

        Let P be the permutation matrix of ``permutation_vector`` and S the
        permutation matrix of ``dim_order``. Transforming to the default
        dim_order, applying P, and transforming back is the single
        permutation S * P * S^-1. Since S is a permutation matrix,
        S^-1 = S^T. Only one dim_order is supplied, so the input tensor is
        assumed to share it (NOTE(review): confirm at the call site).
        """
        # torch.permute accepts negative dims, while TOSA TRANSPOSE expects
        # indices in [0, rank). Wrap them here; non-negative dims are unchanged.
        rank = len(permutation_vector)
        permutation_vector = [dim % rank for dim in permutation_vector]

        # Normalize so a list-typed dim_order still compares equal to the
        # default order below.
        dim_order = tuple(dim_order)
        if dim_order == tuple(range(len(dim_order))):
            # Default dim_order: the permutation can be used directly.
            return permutation_vector

        S = permutation_vector_to_matrix(dim_order)
        P = permutation_vector_to_matrix(permutation_vector)
        transformation_matrix = S.matmul(P.matmul(S.transpose(1, 0)))
        # A product of permutation matrices is again a permutation matrix,
        # so it can be converted back into a permutation vector.
        return permutation_matrix_to_vector(transformation_matrix)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit a TOSA TRANSPOSE from ``inputs[0]`` into ``output``.

        ``inputs[1].special`` holds the permutation passed to ``permute_copy``
        in PyTorch's default dim_order (for rank 4 that is NCHW). Output dim
        ``i`` takes input dim ``permutation[i]``; e.g. (2, 3, 0, 1) permutes
        (n, c, h, w) to (h, w, n, c). ``node`` and ``is_quant_node`` are not
        used by this lowering.
        """
        permutation_vector = self._permutation_for_dim_order(
            inputs[1].special, output.dim_order
        )

        attr = ts.TosaSerializerAttribute()
        attr.TransposeAttribute(permutation_vector)
        tosa_graph.addOperator(
            TosaOp.Op().TRANSPOSE, [inputs[0].name], [output.name], attr
        )