| //===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| |
| #include "mlir/IR/StandardTypes.h" |
| #include "TypeDetail.h" |
| #include "mlir/IR/AffineMap.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/Support/STLExtras.h" |
| #include "llvm/ADT/APFloat.h" |
| #include "llvm/ADT/Twine.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| using namespace mlir; |
| using namespace mlir::detail; |
| |
| //===----------------------------------------------------------------------===// |
| // Type |
| //===----------------------------------------------------------------------===// |
| |
| bool Type::isBF16() { return getKind() == StandardTypes::BF16; } |
| bool Type::isF16() { return getKind() == StandardTypes::F16; } |
| bool Type::isF32() { return getKind() == StandardTypes::F32; } |
| bool Type::isF64() { return getKind() == StandardTypes::F64; } |
| |
| bool Type::isIndex() { return isa<IndexType>(); } |
| |
| /// Return true if this is an integer type with the specified width. |
| bool Type::isInteger(unsigned width) { |
| if (auto intTy = dyn_cast<IntegerType>()) |
| return intTy.getWidth() == width; |
| return false; |
| } |
| |
| bool Type::isIntOrIndex() { return isa<IndexType>() || isa<IntegerType>(); } |
| |
| bool Type::isIntOrIndexOrFloat() { |
| return isa<IndexType>() || isa<IntegerType>() || isa<FloatType>(); |
| } |
| |
| bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); } |
| |
| //===----------------------------------------------------------------------===// |
| // Integer Type |
| //===----------------------------------------------------------------------===// |
| |
| // Out-of-line definition of the static constexpr member; required until C++17, where such members become inline variables. |
| constexpr unsigned IntegerType::kMaxWidth; |
| |
| /// Verify the construction of an integer type. |
| LogicalResult IntegerType::verifyConstructionInvariants( |
| llvm::Optional<Location> loc, MLIRContext *context, unsigned width) { |
| if (width > IntegerType::kMaxWidth) { |
| if (loc) |
| emitError(*loc) << "integer bitwidth is limited to " |
| << IntegerType::kMaxWidth << " bits"; |
| return failure(); |
| } |
| return success(); |
| } |
| |
| unsigned IntegerType::getWidth() const { return getImpl()->width; } |
| |
| //===----------------------------------------------------------------------===// |
| // Float Type |
| //===----------------------------------------------------------------------===// |
| |
| unsigned FloatType::getWidth() { |
| switch (getKind()) { |
| case StandardTypes::BF16: |
| case StandardTypes::F16: |
| return 16; |
| case StandardTypes::F32: |
| return 32; |
| case StandardTypes::F64: |
| return 64; |
| default: |
| llvm_unreachable("unexpected type"); |
| } |
| } |
| |
| /// Returns the floating semantics for the given type. |
| const llvm::fltSemantics &FloatType::getFloatSemantics() { |
| if (isBF16()) |
| // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is |
| // not defined in LLVM. |
| // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc |
| // else one could add it. |
| // static const fltSemantics semBF16 = {127, -126, 8, 16}; |
| return APFloat::IEEEdouble(); |
| if (isF16()) |
| return APFloat::IEEEhalf(); |
| if (isF32()) |
| return APFloat::IEEEsingle(); |
| if (isF64()) |
| return APFloat::IEEEdouble(); |
| llvm_unreachable("non-floating point type used"); |
| } |
| |
| unsigned Type::getIntOrFloatBitWidth() { |
| assert(isIntOrFloat() && "only ints and floats have a bitwidth"); |
| if (auto intType = dyn_cast<IntegerType>()) { |
| return intType.getWidth(); |
| } |
| |
| auto floatType = cast<FloatType>(); |
| return floatType.getWidth(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ShapedType |
| //===----------------------------------------------------------------------===// |
| |
| Type ShapedType::getElementType() const { |
| return static_cast<ImplType *>(impl)->elementType; |
| } |
| |
| unsigned ShapedType::getElementTypeBitWidth() const { |
| return getElementType().getIntOrFloatBitWidth(); |
| } |
| |
| int64_t ShapedType::getNumElements() const { |
| assert(hasStaticShape() && "cannot get element count of dynamic shaped type"); |
| auto shape = getShape(); |
| int64_t num = 1; |
| for (auto dim : shape) |
| num *= dim; |
| return num; |
| } |
| |
| int64_t ShapedType::getRank() const { return getShape().size(); } |
| |
| bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); } |
| |
| int64_t ShapedType::getDimSize(int64_t i) const { |
| assert(i >= 0 && i < getRank() && "invalid index for shaped type"); |
| return getShape()[i]; |
| } |
| |
| /// Get the number of bits require to store a value of the given shaped type. |
| /// Compute the value recursively since tensors are allowed to have vectors as |
| /// elements. |
| int64_t ShapedType::getSizeInBits() const { |
| assert(hasStaticShape() && |
| "cannot get the bit size of an aggregate with a dynamic shape"); |
| |
| auto elementType = getElementType(); |
| if (elementType.isIntOrFloat()) |
| return elementType.getIntOrFloatBitWidth() * getNumElements(); |
| |
| // Tensors can have vectors and other tensors as elements, other shaped types |
| // cannot. |
| assert(isa<TensorType>() && "unsupported element type"); |
| assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) && |
| "unsupported tensor element type"); |
| return getNumElements() * elementType.cast<ShapedType>().getSizeInBits(); |
| } |
| |
| ArrayRef<int64_t> ShapedType::getShape() const { |
| switch (getKind()) { |
| case StandardTypes::Vector: |
| return cast<VectorType>().getShape(); |
| case StandardTypes::RankedTensor: |
| return cast<RankedTensorType>().getShape(); |
| case StandardTypes::MemRef: |
| return cast<MemRefType>().getShape(); |
| default: |
| llvm_unreachable("not a ShapedType or not ranked"); |
| } |
| } |
| |
| int64_t ShapedType::getNumDynamicDims() const { |
| return llvm::count_if(getShape(), isDynamic); |
| } |
| |
| bool ShapedType::hasStaticShape() const { |
| return hasRank() && llvm::none_of(getShape(), isDynamic); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // VectorType |
| //===----------------------------------------------------------------------===// |
| |
| VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) { |
| return Base::get(elementType.getContext(), StandardTypes::Vector, shape, |
| elementType); |
| } |
| |
| VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, |
| Location location) { |
| return Base::getChecked(location, elementType.getContext(), |
| StandardTypes::Vector, shape, elementType); |
| } |
| |
| LogicalResult VectorType::verifyConstructionInvariants( |
| llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape, |
| Type elementType) { |
| if (shape.empty()) { |
| if (loc) |
| emitError(*loc, "vector types must have at least one dimension"); |
| return failure(); |
| } |
| |
| if (!isValidElementType(elementType)) { |
| if (loc) |
| emitError(*loc, "vector elements must be int or float type"); |
| return failure(); |
| } |
| |
| if (any_of(shape, [](int64_t i) { return i <= 0; })) { |
| if (loc) |
| emitError(*loc, "vector types must have positive constant sizes"); |
| return failure(); |
| } |
| return success(); |
| } |
| |
| ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); } |
| |
| //===----------------------------------------------------------------------===// |
| // TensorType |
| //===----------------------------------------------------------------------===// |
| |
| // Check if "elementType" can be an element type of a tensor. Emit errors if |
| // location is not nullptr. Returns failure if check failed. |
| static inline LogicalResult checkTensorElementType(Optional<Location> location, |
| MLIRContext *context, |
| Type elementType) { |
| if (!TensorType::isValidElementType(elementType)) { |
| if (location) |
| emitError(*location, "invalid tensor element type"); |
| return failure(); |
| } |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // RankedTensorType |
| //===----------------------------------------------------------------------===// |
| |
| RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape, |
| Type elementType) { |
| return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape, |
| elementType); |
| } |
| |
| RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, |
| Type elementType, |
| Location location) { |
| return Base::getChecked(location, elementType.getContext(), |
| StandardTypes::RankedTensor, shape, elementType); |
| } |
| |
| LogicalResult RankedTensorType::verifyConstructionInvariants( |
| llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape, |
| Type elementType) { |
| for (int64_t s : shape) { |
| if (s < -1) { |
| if (loc) |
| emitError(*loc, "invalid tensor dimension size"); |
| return failure(); |
| } |
| } |
| return checkTensorElementType(loc, context, elementType); |
| } |
| |
| ArrayRef<int64_t> RankedTensorType::getShape() const { |
| return getImpl()->getShape(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // UnrankedTensorType |
| //===----------------------------------------------------------------------===// |
| |
| UnrankedTensorType UnrankedTensorType::get(Type elementType) { |
| return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor, |
| elementType); |
| } |
| |
| UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, |
| Location location) { |
| return Base::getChecked(location, elementType.getContext(), |
| StandardTypes::UnrankedTensor, elementType); |
| } |
| |
| LogicalResult UnrankedTensorType::verifyConstructionInvariants( |
| llvm::Optional<Location> loc, MLIRContext *context, Type elementType) { |
| return checkTensorElementType(loc, context, elementType); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // MemRefType |
| //===----------------------------------------------------------------------===// |
| |
| /// Get or create a new MemRefType based on shape, element type, affine |
| /// map composition, and memory space. Assumes the arguments define a |
| /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType |
| /// construction failures. |
| MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, |
| ArrayRef<AffineMap> affineMapComposition, |
| unsigned memorySpace) { |
| auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, |
| /*location=*/llvm::None); |
| assert(result && "Failed to construct instance of MemRefType."); |
| return result; |
| } |
| |
| /// Get or create a new MemRefType based on shape, element type, affine |
| /// map composition, and memory space declared at the given location. |
| /// If the location is unknown, the last argument should be an instance of |
| /// UnknownLoc. If the MemRefType defined by the arguments would be |
| /// ill-formed, emits errors (to the handler registered with the context or to |
| /// the error stream) and returns nullptr. |
| MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType, |
| ArrayRef<AffineMap> affineMapComposition, |
| unsigned memorySpace, Location location) { |
| return getImpl(shape, elementType, affineMapComposition, memorySpace, |
| location); |
| } |
| |
| /// Get or create a new MemRefType defined by the arguments. If the resulting |
| /// type would be ill-formed, return nullptr. If the location is provided, |
| /// emit detailed error messages. To emit errors when the location is unknown, |
| /// pass in an instance of UnknownLoc. |
| MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType, |
| ArrayRef<AffineMap> affineMapComposition, |
| unsigned memorySpace, |
| Optional<Location> location) { |
| auto *context = elementType.getContext(); |
| |
| for (int64_t s : shape) { |
| // Negative sizes are not allowed except for `-1` that means dynamic size. |
| if (s < -1) { |
| if (location) |
| emitError(*location, "invalid memref size"); |
| return {}; |
| } |
| } |
| |
| // Check that the structure of the composition is valid, i.e. that each |
| // subsequent affine map has as many inputs as the previous map has results. |
| // Take the dimensionality of the MemRef for the first map. |
| auto dim = shape.size(); |
| unsigned i = 0; |
| for (const auto &affineMap : affineMapComposition) { |
| if (affineMap.getNumDims() != dim) { |
| if (location) |
| emitError(*location) |
| << "memref affine map dimension mismatch between " |
| << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) |
| << " and affine map" << i + 1 << ": " << dim |
| << " != " << affineMap.getNumDims(); |
| return nullptr; |
| } |
| |
| dim = affineMap.getNumResults(); |
| ++i; |
| } |
| |
| // Drop identity maps from the composition. |
| // This may lead to the composition becoming empty, which is interpreted as an |
| // implicit identity. |
| llvm::SmallVector<AffineMap, 2> cleanedAffineMapComposition; |
| for (const auto &map : affineMapComposition) { |
| if (map.isIdentity()) |
| continue; |
| cleanedAffineMapComposition.push_back(map); |
| } |
| |
| return Base::get(context, StandardTypes::MemRef, shape, elementType, |
| cleanedAffineMapComposition, memorySpace); |
| } |
| |
| ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); } |
| |
| ArrayRef<AffineMap> MemRefType::getAffineMaps() const { |
| return getImpl()->getAffineMaps(); |
| } |
| |
| unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; } |
| |
| //===----------------------------------------------------------------------===// |
| // ComplexType |
| //===----------------------------------------------------------------------===// |
| |
| ComplexType ComplexType::get(Type elementType) { |
| return Base::get(elementType.getContext(), StandardTypes::Complex, |
| elementType); |
| } |
| |
| ComplexType ComplexType::getChecked(Type elementType, Location location) { |
| return Base::getChecked(location, elementType.getContext(), |
| StandardTypes::Complex, elementType); |
| } |
| |
| /// Verify the construction of an integer type. |
| LogicalResult ComplexType::verifyConstructionInvariants( |
| llvm::Optional<Location> loc, MLIRContext *context, Type elementType) { |
| if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) { |
| if (loc) |
| emitError(*loc, "invalid element type for complex"); |
| return failure(); |
| } |
| return success(); |
| } |
| |
| Type ComplexType::getElementType() { return getImpl()->elementType; } |
| |
| //===----------------------------------------------------------------------===// |
| // TupleType |
| //===----------------------------------------------------------------------===// |
| |
| /// Get or create a new TupleType with the provided element types. Assumes the |
| /// arguments define a well-formed type. |
| TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) { |
| return Base::get(context, StandardTypes::Tuple, elementTypes); |
| } |
| |
| /// Return the elements types for this tuple. |
| ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } |
| |
| /// Accumulate the types contained in this tuple and tuples nested within it. |
| /// Note that this only flattens nested tuples, not any other container type, |
| /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to |
| /// (i32, tensor<i32>, f32, i64) |
| void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { |
| for (Type type : getTypes()) { |
| if (auto nestedTuple = type.dyn_cast<TupleType>()) |
| nestedTuple.getFlattenedTypes(types); |
| else |
| types.push_back(type); |
| } |
| } |
| |
| /// Return the number of element types. |
| size_t TupleType::size() const { return getImpl()->size(); } |