blob: 035de4f815dfdbb93454c9195ad8645c3049e137 [file] [log] [blame]
//===- ConvertStandardToSPIRV.cpp - Standard to SPIR-V dialect conversion--===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/StandardOps/Ops.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Type Conversion
//===----------------------------------------------------------------------===//
SPIRVBasicTypeConverter::SPIRVBasicTypeConverter(MLIRContext *context)
: spirvDialect(context->getRegisteredDialect<spirv::SPIRVDialect>()) {}
Type SPIRVBasicTypeConverter::convertType(Type t) {
// Check if the type is SPIR-V supported. If so return the type.
if (spirvDialect->isValidSPIRVType(t)) {
return t;
}
if (auto memRefType = t.dyn_cast<MemRefType>()) {
if (memRefType.hasStaticShape()) {
// Convert MemrefType to spv.array if size is known.
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
// to support other Storage Classes.
return spirv::PointerType::get(
spirv::ArrayType::get(memRefType.getElementType(),
memRefType.getNumElements()),
spirv::StorageClass::StorageBuffer);
}
}
return Type();
}
//===----------------------------------------------------------------------===//
// Entry Function signature Conversion
//===----------------------------------------------------------------------===//
LogicalResult
SPIRVTypeConverter::convertSignatureArg(unsigned inputNo, Type type,
SignatureConversion &result) {
// Try to convert the given input type.
auto convertedType = basicTypeConverter->convertType(type);
// TODO(ravishankarm) : Vulkan spec requires these to be a
// spirv::StructType. This is not a SPIR-V requirement, so just making this a
// pointer type for now.
if (!convertedType)
return failure();
// For arguments to entry functions, convert the type into a pointer type if
// it is already not one.
if (!convertedType.isa<spirv::PointerType>()) {
// TODO(ravishankarm) : For now hard-coding this to be StorageBuffer. Need
// to support other Storage classes.
convertedType = spirv::PointerType::get(convertedType,
spirv::StorageClass::StorageBuffer);
}
// Add the new inputs.
result.addInputs(inputNo, convertedType);
return success();
}
static LogicalResult lowerFunctionImpl(
FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter, TypeConverter *typeConverter,
TypeConverter::SignatureConversion &signatureConverter, FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
if (fnType.getNumResults()) {
return funcOp.emitError("SPIR-V dialect only supports functions with no "
"return values right now");
}
for (auto &argType : enumerate(fnType.getInputs())) {
// Get the type of the argument
if (failed(typeConverter->convertSignatureArg(
argType.index(), argType.value(), signatureConverter))) {
return funcOp.emitError("unable to convert argument type ")
<< argType.value() << " to SPIR-V type";
}
}
// Create a new function with an updated signature.
newFuncOp = rewriter.cloneWithoutRegions(funcOp);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
newFuncOp.setType(FunctionType::get(signatureConverter.getConvertedTypes(),
llvm::None, funcOp.getContext()));
// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
return success();
}
namespace mlir {
LogicalResult lowerFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
return lowerFunctionImpl(funcOp, operands, rewriter,
typeConverter->getBasicTypeConverter(),
signatureConverter, newFuncOp);
}
LogicalResult lowerAsEntryFunction(FuncOp funcOp, ArrayRef<Value *> operands,
SPIRVTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FuncOp &newFuncOp) {
auto fnType = funcOp.getType();
TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
if (failed(lowerFunctionImpl(funcOp, operands, rewriter, typeConverter,
signatureConverter, newFuncOp))) {
return failure();
}
// Create spv.globalVariable ops for each of the arguments. These need to be
// bound by the runtime. For now use descriptor_set 0, and arg number as the
// binding number.
auto module = funcOp.getParentOfType<spirv::ModuleOp>();
if (!module) {
return funcOp.emitError("expected op to be within a spv.module");
}
OpBuilder builder(module.getOperation()->getRegion(0));
SmallVector<Attribute, 4> interface;
for (auto &convertedArgType :
llvm::enumerate(signatureConverter.getConvertedTypes())) {
std::string varName = funcOp.getName().str() + "_arg_" +
std::to_string(convertedArgType.index());
auto variableOp = builder.create<spirv::GlobalVariableOp>(
funcOp.getLoc(), builder.getTypeAttr(convertedArgType.value()),
builder.getStringAttr(varName), nullptr);
variableOp.setAttr("descriptor_set", builder.getI32IntegerAttr(0));
variableOp.setAttr("binding",
builder.getI32IntegerAttr(convertedArgType.index()));
interface.push_back(builder.getSymbolRefAttr(variableOp.sym_name()));
}
// Create an entry point instruction for this function.
// TODO(ravishankarm) : Add execution mode for the entry function
builder.setInsertionPoint(&(module.getBlock().back()));
builder.create<spirv::EntryPointOp>(
funcOp.getLoc(),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::ExecutionModel::GLCompute)),
builder.getSymbolRefAttr(newFuncOp.getName()),
builder.getArrayAttr(interface));
return success();
}
} // namespace mlir
//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//
namespace {
/// Convert return -> spv.Return.
class ReturnToSPIRVConversion : public ConversionPattern {
public:
ReturnToSPIRVConversion(MLIRContext *context)
: ConversionPattern(ReturnOp::getOperationName(), 1, context) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override {
if (op->getNumOperands()) {
return matchFailure();
}
rewriter.replaceOpWithNewOp<spirv::ReturnOp>(op);
return matchSuccess();
}
};
} // namespace
namespace {
/// Import the Standard Ops to SPIR-V Patterns.
#include "StandardToSPIRV.cpp.inc"
} // namespace
namespace mlir {
void populateStandardToSPIRVPatterns(MLIRContext *context,
OwningRewritePatternList &patterns) {
populateWithGenerated(context, &patterns);
// Add the return op conversion.
patterns.insert<ReturnToSPIRVConversion>(context);
}
} // namespace mlir