| //===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row and a column.
/// In the AMX intrinsics interface the shape is passed as the 1st and 2nd
/// parameters, which are lowered to the 1st and 2nd machine operands of the
/// AMX pseudo instructions. The ShapeT class supports tile configuration and
/// register allocation; its row and column are those machine operands of the
/// AMX pseudo instructions.
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef LLVM_CODEGEN_TILESHAPEINFO_H |
| #define LLVM_CODEGEN_TILESHAPEINFO_H |
| |
| #include "llvm/ADT/DenseMapInfo.h" |
| #include "llvm/CodeGen/MachineInstr.h" |
| #include "llvm/CodeGen/MachineOperand.h" |
| #include "llvm/CodeGen/MachineRegisterInfo.h" |
| #include "llvm/CodeGen/Register.h" |
| |
| namespace llvm { |
| |
| class ShapeT { |
| public: |
| ShapeT(MachineOperand *Row, MachineOperand *Col, |
| const MachineRegisterInfo *MRI = nullptr) |
| : Row(Row), Col(Col) { |
| if (MRI) |
| deduceImm(MRI); |
| } |
| ShapeT() |
| : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), |
| ColImm(InvalidImmShape) {} |
| bool operator==(const ShapeT &Shape) const { |
| MachineOperand *R = Shape.Row; |
| MachineOperand *C = Shape.Col; |
| if (!R || !C) |
| return false; |
| if (!Row || !Col) |
| return false; |
| if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) |
| return true; |
| if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape)) |
| return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); |
| return false; |
| } |
| |
| bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); } |
| |
| MachineOperand *getRow() const { return Row; } |
| |
| MachineOperand *getCol() const { return Col; } |
| |
| int64_t getRowImm() const { return RowImm; } |
| |
| int64_t getColImm() const { return ColImm; } |
| |
| bool isValid() { return (Row != nullptr) && (Col != nullptr); } |
| |
| void deduceImm(const MachineRegisterInfo *MRI) { |
| // All def must be the same value, otherwise it is invalid MIs. |
| // Find the immediate. |
| // TODO copy propagation. |
| auto GetImm = [&](Register Reg) { |
| int64_t Imm = InvalidImmShape; |
| for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { |
| const auto *MI = DefMO.getParent(); |
| if (MI->isMoveImmediate()) { |
| Imm = MI->getOperand(1).getImm(); |
| break; |
| } |
| } |
| return Imm; |
| }; |
| RowImm = GetImm(Row->getReg()); |
| ColImm = GetImm(Col->getReg()); |
| } |
| |
| private: |
| static constexpr int64_t InvalidImmShape = -1; |
| MachineOperand *Row; |
| MachineOperand *Col; |
| int64_t RowImm; |
| int64_t ColImm; |
| }; |
| |
| } // namespace llvm |
| |
| #endif |