| //===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
| // This file implements a pass to convert MLIR standard and builtin dialects |
| // into the SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" |
| #include "mlir/Dialect/SPIRV/SPIRVDialect.h" |
| #include "mlir/Dialect/SPIRV/SPIRVOps.h" |
| #include "mlir/Dialect/StandardOps/Ops.h" |
| #include "llvm/ADT/SetVector.h" |
| |
| using namespace mlir; |
| |
| //===----------------------------------------------------------------------===// |
| // Type Conversion |
| //===----------------------------------------------------------------------===// |
| |
| static Type basicTypeConversion(Type t) { |
| // Check if the type is SPIR-V supported. If so return the type. |
| if (spirv::SPIRVDialect::isValidType(t)) { |
| return t; |
| } |
| |
| if (auto indexType = t.dyn_cast<IndexType>()) { |
| // Return I32 for index types. |
| return IntegerType::get(32, t.getContext()); |
| } |
| |
| if (auto memRefType = t.dyn_cast<MemRefType>()) { |
| auto elementType = memRefType.getElementType(); |
| if (memRefType.hasStaticShape()) { |
| // Convert to a multi-dimensional spv.array if size is known. |
| for (auto size : reverse(memRefType.getShape())) { |
| elementType = spirv::ArrayType::get(elementType, size); |
| } |
| return spirv::PointerType::get(elementType, |
| spirv::StorageClass::StorageBuffer); |
| } else { |
| // Vulkan SPIR-V validation rules require runtime array type to be the |
| // last member of a struct. |
| return spirv::PointerType::get(spirv::RuntimeArrayType::get(elementType), |
| spirv::StorageClass::StorageBuffer); |
| } |
| } |
| return Type(); |
| } |
| |
| Type SPIRVBasicTypeConverter::convertType(Type t) { |
| return basicTypeConversion(t); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Entry Function signature Conversion |
| //===----------------------------------------------------------------------===// |
| |
| /// Generates the type of variable given the type of object. |
| static Type getGlobalVarTypeForEntryFnArg(Type t) { |
| auto convertedType = basicTypeConversion(t); |
| if (auto ptrType = convertedType.dyn_cast<spirv::PointerType>()) { |
| if (!ptrType.getPointeeType().isa<spirv::StructType>()) { |
| return spirv::PointerType::get( |
| spirv::StructType::get(ptrType.getPointeeType()), |
| ptrType.getStorageClass()); |
| } |
| } else { |
| return spirv::PointerType::get(spirv::StructType::get(convertedType), |
| spirv::StorageClass::StorageBuffer); |
| } |
| return convertedType; |
| } |
| |
| Type SPIRVTypeConverter::convertType(Type t) { |
| return getGlobalVarTypeForEntryFnArg(t); |
| } |
| |
| /// Computes the replacement value for an argument of an entry function. It |
| /// allocates a global variable for this argument and adds statements in the |
| /// entry block to get a replacement value within function scope. |
| static Value *createAndLoadGlobalVarForEntryFnArg(PatternRewriter &rewriter, |
| size_t origArgNum, |
| Value *origArg) { |
| // Create a global variable for this argument. |
| auto insertionOp = rewriter.getInsertionBlock()->getParent(); |
| auto module = insertionOp->getParentOfType<spirv::ModuleOp>(); |
| if (!module) { |
| return nullptr; |
| } |
| auto funcOp = insertionOp->getParentOfType<FuncOp>(); |
| spirv::GlobalVariableOp var; |
| { |
| OpBuilder::InsertionGuard moduleInsertionGuard(rewriter); |
| rewriter.setInsertionPointToStart(&module.getBlock()); |
| std::string varName = |
| funcOp.getName().str() + "_arg_" + std::to_string(origArgNum); |
| var = rewriter.create<spirv::GlobalVariableOp>( |
| funcOp.getLoc(), |
| rewriter.getTypeAttr(getGlobalVarTypeForEntryFnArg(origArg->getType())), |
| rewriter.getStringAttr(varName), nullptr); |
| var.setAttr( |
| spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), |
| rewriter.getI32IntegerAttr(0)); |
| var.setAttr( |
| spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding), |
| rewriter.getI32IntegerAttr(origArgNum)); |
| } |
| // Insert the addressOf and load instructions, to get back the converted value |
| // type. |
| auto addressOf = rewriter.create<spirv::AddressOfOp>(funcOp.getLoc(), var); |
| auto zero = rewriter.create<spirv::ConstantOp>(funcOp.getLoc(), |
| rewriter.getIntegerType(32), |
| rewriter.getI32IntegerAttr(0)); |
| auto accessChain = rewriter.create<spirv::AccessChainOp>( |
| funcOp.getLoc(), addressOf.pointer(), zero.constant()); |
| // If the original argument is a tensor/memref type, the value is not |
| // loaded. Instead the pointer value is returned to allow its use in access |
| // chain ops. |
| auto origArgType = origArg->getType(); |
| if (origArgType.isa<MemRefType>()) { |
| return accessChain; |
| } |
| return rewriter.create<spirv::LoadOp>( |
| funcOp.getLoc(), accessChain.component_ptr(), /*memory_access=*/nullptr, |
| /*alignment=*/nullptr); |
| } |
| |
| static FuncOp applySignatureConversion( |
| FuncOp funcOp, ConversionPatternRewriter &rewriter, |
| TypeConverter::SignatureConversion &signatureConverter) { |
| // Create a new function with an updated signature. |
| auto newFuncOp = rewriter.cloneWithoutRegions(funcOp); |
| rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), |
| newFuncOp.end()); |
| newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(), |
| llvm::None, funcOp.getContext())); |
| |
| // Tell the rewriter to convert the region signature. |
| rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); |
| rewriter.replaceOp(funcOp.getOperation(), llvm::None); |
| return newFuncOp; |
| } |
| |
| /// Gets the global variables that need to be specified as interface variable |
| /// with an spv.EntryPointOp. Traverses the body of a entry function to do so. |
| LogicalResult getInterfaceVariables(FuncOp funcOp, |
| SmallVectorImpl<Attribute> &interfaceVars) { |
| auto module = funcOp.getParentOfType<spirv::ModuleOp>(); |
| if (!module) { |
| return failure(); |
| } |
| llvm::SetVector<Operation *> interfaceVarSet; |
| for (auto &block : funcOp) { |
| // TODO(ravishankarm) : This should in reality traverse the entry function |
| // call graph and collect all the interfaces. For now, just traverse the |
| // instructions in this function. |
| auto callOps = block.getOps<CallOp>(); |
| if (std::distance(callOps.begin(), callOps.end())) { |
| return funcOp.emitError("Collecting interface variables through function " |
| "calls unimplemented"); |
| } |
| for (auto op : block.getOps<spirv::AddressOfOp>()) { |
| auto var = module.lookupSymbol<spirv::GlobalVariableOp>(op.variable()); |
| if (var.type().cast<spirv::PointerType>().getStorageClass() == |
| spirv::StorageClass::StorageBuffer) { |
| continue; |
| } |
| interfaceVarSet.insert(var.getOperation()); |
| } |
| } |
| for (auto &var : interfaceVarSet) { |
| interfaceVars.push_back(SymbolRefAttr::get( |
| cast<spirv::GlobalVariableOp>(var).sym_name(), funcOp.getContext())); |
| } |
| return success(); |
| } |
| |
| namespace mlir { |
| LogicalResult lowerFunction(FuncOp funcOp, SPIRVTypeConverter *typeConverter, |
| ConversionPatternRewriter &rewriter, |
| FuncOp &newFuncOp) { |
| auto fnType = funcOp.getType(); |
| if (fnType.getNumResults()) { |
| return funcOp.emitError("SPIR-V lowering only supports functions with no " |
| "return values right now"); |
| } |
| TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); |
| auto basicTypeConverter = typeConverter->getBasicTypeConverter(); |
| for (auto origArgType : enumerate(fnType.getInputs())) { |
| auto convertedType = basicTypeConverter->convertType(origArgType.value()); |
| if (!convertedType) { |
| return funcOp.emitError("unable to convert argument of type '") |
| << convertedType << "'"; |
| } |
| signatureConverter.addInputs(origArgType.index(), convertedType); |
| } |
| newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter); |
| return success(); |
| } |
| |
| LogicalResult lowerAsEntryFunction(FuncOp funcOp, |
| SPIRVTypeConverter *typeConverter, |
| ConversionPatternRewriter &rewriter, |
| FuncOp &newFuncOp) { |
| auto fnType = funcOp.getType(); |
| if (fnType.getNumResults()) { |
| return funcOp.emitError("SPIR-V lowering only supports functions with no " |
| "return values right now"); |
| } |
| // For entry functions need to make the signature void(void). Compute the |
| // replacement value for all arguments and replace all uses. |
| TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); |
| { |
| OpBuilder::InsertionGuard moduleInsertionGuard(rewriter); |
| rewriter.setInsertionPointToStart(&funcOp.front()); |
| for (auto origArg : enumerate(funcOp.getArguments())) { |
| auto replacement = createAndLoadGlobalVarForEntryFnArg( |
| rewriter, origArg.index(), origArg.value()); |
| rewriter.replaceUsesOfBlockArgument(origArg.value(), replacement); |
| } |
| } |
| newFuncOp = applySignatureConversion(funcOp, rewriter, signatureConverter); |
| return success(); |
| } |
| |
| LogicalResult finalizeEntryFunction(FuncOp newFuncOp, OpBuilder &builder) { |
| // Add the spv.EntryPointOp after collecting all the interface variables |
| // needed. |
| SmallVector<Attribute, 1> interfaceVars; |
| if (failed(getInterfaceVariables(newFuncOp, interfaceVars))) { |
| return failure(); |
| } |
| builder.create<spirv::EntryPointOp>(newFuncOp.getLoc(), |
| spirv::ExecutionModel::GLCompute, |
| newFuncOp, interfaceVars); |
| // Specify the spv.ExecutionModeOp. |
| |
| /// TODO(ravishankarm): Vulkan environment for SPIR-V requires "either a |
| /// LocalSize execution mode or an object decorated with the WorkgroupSize |
| /// decoration must be specified." Better approach is to use the |
| /// WorkgroupSize GlobalVariable with initializer being a specialization |
| /// constant. But current support for specialization constant does not allow |
| /// for this. So for now use the execution mode. Hard-wiring this to {1, 1, |
| /// 1} for now. To be fixed ASAP. |
| builder.create<spirv::ExecutionModeOp>(newFuncOp.getLoc(), newFuncOp, |
| spirv::ExecutionMode::LocalSize, |
| ArrayRef<int32_t>{1, 1, 1}); |
| return success(); |
| } |
| } // namespace mlir |
| |
| //===----------------------------------------------------------------------===// |
| // Operation conversion |
| //===----------------------------------------------------------------------===// |
| |
| namespace { |
| |
| /// Convert integer binary operations to SPIR-V operations. Cannot use tablegen |
| /// for this. If the integer operation is on variables of IndexType, the type of |
| /// the return value of the replacement operation differs from that of the |
| /// replaced operation. This is not handled in tablegen-based pattern |
| /// specification. |
| template <typename StdOp, typename SPIRVOp> |
| class IntegerOpConversion final : public ConversionPattern { |
| public: |
| IntegerOpConversion(MLIRContext *context) |
| : ConversionPattern(StdOp::getOperationName(), 1, context) {} |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| rewriter.template replaceOpWithNewOp<SPIRVOp>( |
| op, operands[0]->getType(), operands, ArrayRef<NamedAttribute>()); |
| return this->matchSuccess(); |
| } |
| }; |
| |
| /// Convert load -> spv.LoadOp. The operands of the replaced operation are of |
| /// IndexType while that of the replacement operation are of type i32. This is |
| /// not suppored in tablegen based pattern specification. |
| // TODO(ravishankarm) : These could potentially be templated on the operation |
| // being converted, since the same logic should work for linalg.load. |
| class LoadOpConversion final : public ConversionPattern { |
| public: |
| LoadOpConversion(MLIRContext *context) |
| : ConversionPattern(LoadOp::getOperationName(), 1, context) {} |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| LoadOpOperandAdaptor loadOperands(operands); |
| auto basePtr = loadOperands.memref(); |
| auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>(); |
| if (!ptrType) { |
| return matchFailure(); |
| } |
| auto loadPtr = rewriter.create<spirv::AccessChainOp>( |
| op->getLoc(), basePtr, loadOperands.indices()); |
| auto loadPtrType = loadPtr.getType().cast<spirv::PointerType>(); |
| rewriter.replaceOpWithNewOp<spirv::LoadOp>( |
| op, loadPtrType.getPointeeType(), loadPtr, /*memory_access =*/nullptr, |
| /*alignment =*/nullptr); |
| return matchSuccess(); |
| } |
| }; |
| |
| /// Convert return -> spv.Return. |
| class ReturnToSPIRVConversion : public ConversionPattern { |
| public: |
| ReturnToSPIRVConversion(MLIRContext *context) |
| : ConversionPattern(ReturnOp::getOperationName(), 1, context) {} |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| if (op->getNumOperands()) { |
| return matchFailure(); |
| } |
| rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op); |
| return matchSuccess(); |
| } |
| }; |
| |
| /// Convert store -> spv.StoreOp. The operands of the replaced operation are of |
| /// IndexType while that of the replacement operation are of type i32. This is |
| /// not suppored in tablegen based pattern specification. |
| // TODO(ravishankarm) : These could potentially be templated on the operation |
| // being converted, since the same logic should work for linalg.store. |
| class StoreOpConversion final : public ConversionPattern { |
| public: |
| StoreOpConversion(MLIRContext *context) |
| : ConversionPattern(StoreOp::getOperationName(), 1, context) {} |
| |
| PatternMatchResult |
| matchAndRewrite(Operation *op, ArrayRef<Value *> operands, |
| ConversionPatternRewriter &rewriter) const override { |
| StoreOpOperandAdaptor storeOperands(operands); |
| auto value = storeOperands.value(); |
| auto basePtr = storeOperands.memref(); |
| auto ptrType = basePtr->getType().dyn_cast<spirv::PointerType>(); |
| if (!ptrType) { |
| return matchFailure(); |
| } |
| auto storePtr = rewriter.create<spirv::AccessChainOp>( |
| op->getLoc(), basePtr, storeOperands.indices()); |
| rewriter.replaceOpWithNewOp<spirv::StoreOp>(op, storePtr, value, |
| /*memory_access =*/nullptr, |
| /*alignment =*/nullptr); |
| return matchSuccess(); |
| } |
| }; |
| |
| } // namespace |
| |
| namespace { |
| /// Import the Standard Ops to SPIR-V Patterns. |
| #include "StandardToSPIRV.cpp.inc" |
| } // namespace |
| |
| namespace mlir { |
| void populateStandardToSPIRVPatterns(MLIRContext *context, |
| OwningRewritePatternList &patterns) { |
| populateWithGenerated(context, &patterns); |
| // Add the return op conversion. |
| patterns.insert<IntegerOpConversion<AddIOp, spirv::IAddOp>, |
| IntegerOpConversion<MulIOp, spirv::IMulOp>, LoadOpConversion, |
| ReturnToSPIRVConversion, StoreOpConversion>(context); |
| } |
| } // namespace mlir |