| //===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| /// \file |
| /// |
| /// The header file for the GVN pass that contains expression handling |
| /// classes |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H |
| #define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H |
| |
| #include "llvm/ADT/Hashing.h" |
| #include "llvm/ADT/iterator_range.h" |
| #include "llvm/Analysis/MemorySSA.h" |
| #include "llvm/IR/Constant.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/Value.h" |
| #include "llvm/Support/Allocator.h" |
| #include "llvm/Support/ArrayRecycler.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/Compiler.h" |
| #include "llvm/Support/raw_ostream.h" |
| #include <algorithm> |
| #include <cassert> |
| #include <iterator> |
| #include <utility> |
| |
| namespace llvm { |
| |
| class BasicBlock; |
| class Type; |
| |
| namespace GVNExpression { |
| |
/// Discriminator for the Expression class hierarchy, consumed by the
/// LLVM-style RTTI `classof` helpers below.
///
/// Order matters: each *Start/*End pair brackets a family of kinds so that a
/// range test such as `ET > ET_BasicStart && ET < ET_BasicEnd` matches every
/// member of that family. The memory family is nested inside the basic one.
enum ExpressionType {
  // Kinds deriving directly from Expression (ET_Base is Expression itself).
  ET_Base, ET_Constant, ET_Variable, ET_Dead, ET_Unknown,
  // Kinds deriving from BasicExpression.
  ET_BasicStart, ET_Basic, ET_AggregateValue, ET_Phi,
  // Kinds deriving from MemoryExpression.
  ET_MemoryStart, ET_Call, ET_Load, ET_Store, ET_MemoryEnd,
  // Closes the BasicExpression range.
  ET_BasicEnd
};
| |
| class Expression { |
| private: |
| ExpressionType EType; |
| unsigned Opcode; |
| mutable hash_code HashVal = 0; |
| |
| public: |
| Expression(ExpressionType ET = ET_Base, unsigned O = ~2U) |
| : EType(ET), Opcode(O) {} |
| Expression(const Expression &) = delete; |
| Expression &operator=(const Expression &) = delete; |
| virtual ~Expression(); |
| |
| static unsigned getEmptyKey() { return ~0U; } |
| static unsigned getTombstoneKey() { return ~1U; } |
| |
| bool operator!=(const Expression &Other) const { return !(*this == Other); } |
| bool operator==(const Expression &Other) const { |
| if (getOpcode() != Other.getOpcode()) |
| return false; |
| if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey()) |
| return true; |
| // Compare the expression type for anything but load and store. |
| // For load and store we set the opcode to zero to make them equal. |
| if (getExpressionType() != ET_Load && getExpressionType() != ET_Store && |
| getExpressionType() != Other.getExpressionType()) |
| return false; |
| |
| return equals(Other); |
| } |
| |
| hash_code getComputedHash() const { |
| // It's theoretically possible for a thing to hash to zero. In that case, |
| // we will just compute the hash a few extra times, which is no worse that |
| // we did before, which was to compute it always. |
| if (static_cast<unsigned>(HashVal) == 0) |
| HashVal = getHashValue(); |
| return HashVal; |
| } |
| |
| virtual bool equals(const Expression &Other) const { return true; } |
| |
| // Return true if the two expressions are exactly the same, including the |
| // normally ignored fields. |
| virtual bool exactlyEquals(const Expression &Other) const { |
| return getExpressionType() == Other.getExpressionType() && equals(Other); |
| } |
| |
| unsigned getOpcode() const { return Opcode; } |
| void setOpcode(unsigned opcode) { Opcode = opcode; } |
| ExpressionType getExpressionType() const { return EType; } |
| |
| // We deliberately leave the expression type out of the hash value. |
| virtual hash_code getHashValue() const { return getOpcode(); } |
| |
| // Debugging support |
| virtual void printInternal(raw_ostream &OS, bool PrintEType) const { |
| if (PrintEType) |
| OS << "etype = " << getExpressionType() << ","; |
| OS << "opcode = " << getOpcode() << ", "; |
| } |
| |
| void print(raw_ostream &OS) const { |
| OS << "{ "; |
| printInternal(OS, true); |
| OS << "}"; |
| } |
| |
| LLVM_DUMP_METHOD void dump() const; |
| }; |
| |
| inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) { |
| E.print(OS); |
| return OS; |
| } |
| |
| class BasicExpression : public Expression { |
| private: |
| using RecyclerType = ArrayRecycler<Value *>; |
| using RecyclerCapacity = RecyclerType::Capacity; |
| |
| Value **Operands = nullptr; |
| unsigned MaxOperands; |
| unsigned NumOperands = 0; |
| Type *ValueType = nullptr; |
| |
| public: |
| BasicExpression(unsigned NumOperands) |
| : BasicExpression(NumOperands, ET_Basic) {} |
| BasicExpression(unsigned NumOperands, ExpressionType ET) |
| : Expression(ET), MaxOperands(NumOperands) {} |
| BasicExpression() = delete; |
| BasicExpression(const BasicExpression &) = delete; |
| BasicExpression &operator=(const BasicExpression &) = delete; |
| ~BasicExpression() override; |
| |
| static bool classof(const Expression *EB) { |
| ExpressionType ET = EB->getExpressionType(); |
| return ET > ET_BasicStart && ET < ET_BasicEnd; |
| } |
| |
| /// Swap two operands. Used during GVN to put commutative operands in |
| /// order. |
| void swapOperands(unsigned First, unsigned Second) { |
| std::swap(Operands[First], Operands[Second]); |
| } |
| |
| Value *getOperand(unsigned N) const { |
| assert(Operands && "Operands not allocated"); |
| assert(N < NumOperands && "Operand out of range"); |
| return Operands[N]; |
| } |
| |
| void setOperand(unsigned N, Value *V) { |
| assert(Operands && "Operands not allocated before setting"); |
| assert(N < NumOperands && "Operand out of range"); |
| Operands[N] = V; |
| } |
| |
| unsigned getNumOperands() const { return NumOperands; } |
| |
| using op_iterator = Value **; |
| using const_op_iterator = Value *const *; |
| |
| op_iterator op_begin() { return Operands; } |
| op_iterator op_end() { return Operands + NumOperands; } |
| const_op_iterator op_begin() const { return Operands; } |
| const_op_iterator op_end() const { return Operands + NumOperands; } |
| iterator_range<op_iterator> operands() { |
| return iterator_range<op_iterator>(op_begin(), op_end()); |
| } |
| iterator_range<const_op_iterator> operands() const { |
| return iterator_range<const_op_iterator>(op_begin(), op_end()); |
| } |
| |
| void op_push_back(Value *Arg) { |
| assert(NumOperands < MaxOperands && "Tried to add too many operands"); |
| assert(Operands && "Operandss not allocated before pushing"); |
| Operands[NumOperands++] = Arg; |
| } |
| bool op_empty() const { return getNumOperands() == 0; } |
| |
| void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) { |
| assert(!Operands && "Operands already allocated"); |
| Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator); |
| } |
| void deallocateOperands(RecyclerType &Recycler) { |
| Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands); |
| } |
| |
| void setType(Type *T) { ValueType = T; } |
| Type *getType() const { return ValueType; } |
| |
| bool equals(const Expression &Other) const override { |
| if (getOpcode() != Other.getOpcode()) |
| return false; |
| |
| const auto &OE = cast<BasicExpression>(Other); |
| return getType() == OE.getType() && NumOperands == OE.NumOperands && |
| std::equal(op_begin(), op_end(), OE.op_begin()); |
| } |
| |
| hash_code getHashValue() const override { |
| return hash_combine(this->Expression::getHashValue(), ValueType, |
| hash_combine_range(op_begin(), op_end())); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeBasic, "; |
| |
| this->Expression::printInternal(OS, false); |
| OS << "operands = {"; |
| for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { |
| OS << "[" << i << "] = "; |
| Operands[i]->printAsOperand(OS); |
| OS << " "; |
| } |
| OS << "} "; |
| } |
| }; |
| |
| class op_inserter { |
| private: |
| using Container = BasicExpression; |
| |
| Container *BE; |
| |
| public: |
| using iterator_category = std::output_iterator_tag; |
| using value_type = void; |
| using difference_type = void; |
| using pointer = void; |
| using reference = void; |
| |
| explicit op_inserter(BasicExpression &E) : BE(&E) {} |
| explicit op_inserter(BasicExpression *E) : BE(E) {} |
| |
| op_inserter &operator=(Value *val) { |
| BE->op_push_back(val); |
| return *this; |
| } |
| op_inserter &operator*() { return *this; } |
| op_inserter &operator++() { return *this; } |
| op_inserter &operator++(int) { return *this; } |
| }; |
| |
| class MemoryExpression : public BasicExpression { |
| private: |
| const MemoryAccess *MemoryLeader; |
| |
| public: |
| MemoryExpression(unsigned NumOperands, enum ExpressionType EType, |
| const MemoryAccess *MemoryLeader) |
| : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {} |
| MemoryExpression() = delete; |
| MemoryExpression(const MemoryExpression &) = delete; |
| MemoryExpression &operator=(const MemoryExpression &) = delete; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() > ET_MemoryStart && |
| EB->getExpressionType() < ET_MemoryEnd; |
| } |
| |
| hash_code getHashValue() const override { |
| return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader); |
| } |
| |
| bool equals(const Expression &Other) const override { |
| if (!this->BasicExpression::equals(Other)) |
| return false; |
| const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other); |
| |
| return MemoryLeader == OtherMCE.MemoryLeader; |
| } |
| |
| const MemoryAccess *getMemoryLeader() const { return MemoryLeader; } |
| void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; } |
| }; |
| |
| class CallExpression final : public MemoryExpression { |
| private: |
| CallInst *Call; |
| |
| public: |
| CallExpression(unsigned NumOperands, CallInst *C, |
| const MemoryAccess *MemoryLeader) |
| : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {} |
| CallExpression() = delete; |
| CallExpression(const CallExpression &) = delete; |
| CallExpression &operator=(const CallExpression &) = delete; |
| ~CallExpression() override; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_Call; |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeCall, "; |
| this->BasicExpression::printInternal(OS, false); |
| OS << " represents call at "; |
| Call->printAsOperand(OS); |
| } |
| }; |
| |
| class LoadExpression final : public MemoryExpression { |
| private: |
| LoadInst *Load; |
| |
| public: |
| LoadExpression(unsigned NumOperands, LoadInst *L, |
| const MemoryAccess *MemoryLeader) |
| : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {} |
| |
| LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L, |
| const MemoryAccess *MemoryLeader) |
| : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {} |
| |
| LoadExpression() = delete; |
| LoadExpression(const LoadExpression &) = delete; |
| LoadExpression &operator=(const LoadExpression &) = delete; |
| ~LoadExpression() override; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_Load; |
| } |
| |
| LoadInst *getLoadInst() const { return Load; } |
| void setLoadInst(LoadInst *L) { Load = L; } |
| |
| bool equals(const Expression &Other) const override; |
| bool exactlyEquals(const Expression &Other) const override { |
| return Expression::exactlyEquals(Other) && |
| cast<LoadExpression>(Other).getLoadInst() == getLoadInst(); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeLoad, "; |
| this->BasicExpression::printInternal(OS, false); |
| OS << " represents Load at "; |
| Load->printAsOperand(OS); |
| OS << " with MemoryLeader " << *getMemoryLeader(); |
| } |
| }; |
| |
| class StoreExpression final : public MemoryExpression { |
| private: |
| StoreInst *Store; |
| Value *StoredValue; |
| |
| public: |
| StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue, |
| const MemoryAccess *MemoryLeader) |
| : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S), |
| StoredValue(StoredValue) {} |
| StoreExpression() = delete; |
| StoreExpression(const StoreExpression &) = delete; |
| StoreExpression &operator=(const StoreExpression &) = delete; |
| ~StoreExpression() override; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_Store; |
| } |
| |
| StoreInst *getStoreInst() const { return Store; } |
| Value *getStoredValue() const { return StoredValue; } |
| |
| bool equals(const Expression &Other) const override; |
| |
| bool exactlyEquals(const Expression &Other) const override { |
| return Expression::exactlyEquals(Other) && |
| cast<StoreExpression>(Other).getStoreInst() == getStoreInst(); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeStore, "; |
| this->BasicExpression::printInternal(OS, false); |
| OS << " represents Store " << *Store; |
| OS << " with StoredValue "; |
| StoredValue->printAsOperand(OS); |
| OS << " and MemoryLeader " << *getMemoryLeader(); |
| } |
| }; |
| |
| class AggregateValueExpression final : public BasicExpression { |
| private: |
| unsigned MaxIntOperands; |
| unsigned NumIntOperands = 0; |
| unsigned *IntOperands = nullptr; |
| |
| public: |
| AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands) |
| : BasicExpression(NumOperands, ET_AggregateValue), |
| MaxIntOperands(NumIntOperands) {} |
| AggregateValueExpression() = delete; |
| AggregateValueExpression(const AggregateValueExpression &) = delete; |
| AggregateValueExpression & |
| operator=(const AggregateValueExpression &) = delete; |
| ~AggregateValueExpression() override; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_AggregateValue; |
| } |
| |
| using int_arg_iterator = unsigned *; |
| using const_int_arg_iterator = const unsigned *; |
| |
| int_arg_iterator int_op_begin() { return IntOperands; } |
| int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; } |
| const_int_arg_iterator int_op_begin() const { return IntOperands; } |
| const_int_arg_iterator int_op_end() const { |
| return IntOperands + NumIntOperands; |
| } |
| unsigned int_op_size() const { return NumIntOperands; } |
| bool int_op_empty() const { return NumIntOperands == 0; } |
| void int_op_push_back(unsigned IntOperand) { |
| assert(NumIntOperands < MaxIntOperands && |
| "Tried to add too many int operands"); |
| assert(IntOperands && "Operands not allocated before pushing"); |
| IntOperands[NumIntOperands++] = IntOperand; |
| } |
| |
| virtual void allocateIntOperands(BumpPtrAllocator &Allocator) { |
| assert(!IntOperands && "Operands already allocated"); |
| IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands); |
| } |
| |
| bool equals(const Expression &Other) const override { |
| if (!this->BasicExpression::equals(Other)) |
| return false; |
| const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other); |
| return NumIntOperands == OE.NumIntOperands && |
| std::equal(int_op_begin(), int_op_end(), OE.int_op_begin()); |
| } |
| |
| hash_code getHashValue() const override { |
| return hash_combine(this->BasicExpression::getHashValue(), |
| hash_combine_range(int_op_begin(), int_op_end())); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeAggregateValue, "; |
| this->BasicExpression::printInternal(OS, false); |
| OS << ", intoperands = {"; |
| for (unsigned i = 0, e = int_op_size(); i != e; ++i) { |
| OS << "[" << i << "] = " << IntOperands[i] << " "; |
| } |
| OS << "}"; |
| } |
| }; |
| |
| class int_op_inserter { |
| private: |
| using Container = AggregateValueExpression; |
| |
| Container *AVE; |
| |
| public: |
| using iterator_category = std::output_iterator_tag; |
| using value_type = void; |
| using difference_type = void; |
| using pointer = void; |
| using reference = void; |
| |
| explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {} |
| explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {} |
| |
| int_op_inserter &operator=(unsigned int val) { |
| AVE->int_op_push_back(val); |
| return *this; |
| } |
| int_op_inserter &operator*() { return *this; } |
| int_op_inserter &operator++() { return *this; } |
| int_op_inserter &operator++(int) { return *this; } |
| }; |
| |
| class PHIExpression final : public BasicExpression { |
| private: |
| BasicBlock *BB; |
| |
| public: |
| PHIExpression(unsigned NumOperands, BasicBlock *B) |
| : BasicExpression(NumOperands, ET_Phi), BB(B) {} |
| PHIExpression() = delete; |
| PHIExpression(const PHIExpression &) = delete; |
| PHIExpression &operator=(const PHIExpression &) = delete; |
| ~PHIExpression() override; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_Phi; |
| } |
| |
| bool equals(const Expression &Other) const override { |
| if (!this->BasicExpression::equals(Other)) |
| return false; |
| const PHIExpression &OE = cast<PHIExpression>(Other); |
| return BB == OE.BB; |
| } |
| |
| hash_code getHashValue() const override { |
| return hash_combine(this->BasicExpression::getHashValue(), BB); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypePhi, "; |
| this->BasicExpression::printInternal(OS, false); |
| OS << "bb = " << BB; |
| } |
| }; |
| |
| class DeadExpression final : public Expression { |
| public: |
| DeadExpression() : Expression(ET_Dead) {} |
| DeadExpression(const DeadExpression &) = delete; |
| DeadExpression &operator=(const DeadExpression &) = delete; |
| |
| static bool classof(const Expression *E) { |
| return E->getExpressionType() == ET_Dead; |
| } |
| }; |
| |
| class VariableExpression final : public Expression { |
| private: |
| Value *VariableValue; |
| |
| public: |
| VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {} |
| VariableExpression() = delete; |
| VariableExpression(const VariableExpression &) = delete; |
| VariableExpression &operator=(const VariableExpression &) = delete; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_Variable; |
| } |
| |
| Value *getVariableValue() const { return VariableValue; } |
| void setVariableValue(Value *V) { VariableValue = V; } |
| |
| bool equals(const Expression &Other) const override { |
| const VariableExpression &OC = cast<VariableExpression>(Other); |
| return VariableValue == OC.VariableValue; |
| } |
| |
| hash_code getHashValue() const override { |
| return hash_combine(this->Expression::getHashValue(), |
| VariableValue->getType(), VariableValue); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeVariable, "; |
| this->Expression::printInternal(OS, false); |
| OS << " variable = " << *VariableValue; |
| } |
| }; |
| |
| class ConstantExpression final : public Expression { |
| private: |
| Constant *ConstantValue = nullptr; |
| |
| public: |
| ConstantExpression() : Expression(ET_Constant) {} |
| ConstantExpression(Constant *constantValue) |
| : Expression(ET_Constant), ConstantValue(constantValue) {} |
| ConstantExpression(const ConstantExpression &) = delete; |
| ConstantExpression &operator=(const ConstantExpression &) = delete; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_Constant; |
| } |
| |
| Constant *getConstantValue() const { return ConstantValue; } |
| void setConstantValue(Constant *V) { ConstantValue = V; } |
| |
| bool equals(const Expression &Other) const override { |
| const ConstantExpression &OC = cast<ConstantExpression>(Other); |
| return ConstantValue == OC.ConstantValue; |
| } |
| |
| hash_code getHashValue() const override { |
| return hash_combine(this->Expression::getHashValue(), |
| ConstantValue->getType(), ConstantValue); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeConstant, "; |
| this->Expression::printInternal(OS, false); |
| OS << " constant = " << *ConstantValue; |
| } |
| }; |
| |
| class UnknownExpression final : public Expression { |
| private: |
| Instruction *Inst; |
| |
| public: |
| UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {} |
| UnknownExpression() = delete; |
| UnknownExpression(const UnknownExpression &) = delete; |
| UnknownExpression &operator=(const UnknownExpression &) = delete; |
| |
| static bool classof(const Expression *EB) { |
| return EB->getExpressionType() == ET_Unknown; |
| } |
| |
| Instruction *getInstruction() const { return Inst; } |
| void setInstruction(Instruction *I) { Inst = I; } |
| |
| bool equals(const Expression &Other) const override { |
| const auto &OU = cast<UnknownExpression>(Other); |
| return Inst == OU.Inst; |
| } |
| |
| hash_code getHashValue() const override { |
| return hash_combine(this->Expression::getHashValue(), Inst); |
| } |
| |
| // Debugging support |
| void printInternal(raw_ostream &OS, bool PrintEType) const override { |
| if (PrintEType) |
| OS << "ExpressionTypeUnknown, "; |
| this->Expression::printInternal(OS, false); |
| OS << " inst = " << *Inst; |
| } |
| }; |
| |
| } // end namespace GVNExpression |
| |
| } // end namespace llvm |
| |
| #endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H |