blob: e92ad03d776778421a216d78c16014e36aa0d308 [file] [log] [blame]
//===- ConvertStandardToSPIRV.h - Convert to SPIR-V dialect -----*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// Provides type converters and patterns to convert from standard types/ops to
// SPIR-V types and operations. Also provides utilities and base classes to use
// when targeting the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
#define MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Support/StringExtras.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
class LoadOp;
class ReturnOp;
class StoreOp;
/// Type conversion from Standard Types to SPIR-V Types.
class SPIRVBasicTypeConverter : public TypeConverter {
public:
  /// Converts `t` to a SPIR-V supported type. Overrides
  /// TypeConverter::convertType(Type); the definition lives out of line.
  Type convertType(Type t) override;
};
/// Type converter used when lowering to SPIR-V, intended to convert function
/// types according to the requirements of a SPIR-V entry function: arguments
/// need to become spv.Variables of spv.ptr type so that the runtime can bind
/// them.
///
/// Element/basic type conversion is delegated to a SPIRVBasicTypeConverter
/// supplied at construction. Only the pointer is stored; this class has no
/// destructor and never deletes it.
class SPIRVTypeConverter final : public TypeConverter {
public:
  explicit SPIRVTypeConverter(SPIRVBasicTypeConverter *basic)
      : basicTypeConverter(basic) {}

  /// Returns the basic type converter passed to the constructor.
  SPIRVBasicTypeConverter *getBasicTypeConverter() const {
    return basicTypeConverter;
  }

  /// Converts `t` to a SPIR-V type using the basic type converter; defined
  /// out of line.
  Type convertType(Type t) override;

private:
  /// Delegate for basic type conversion (not owned).
  SPIRVBasicTypeConverter *basicTypeConverter;
};
/// Base class to define a conversion pattern to translate Ops into SPIR-V.
///
/// The pattern is registered for `OpTy::getOperationName()` with benefit 1.
/// Derived patterns get access to the SPIR-V type converter and to helpers
/// that materialize SPIR-V builtin variables.
template <typename OpTy> class SPIRVOpLowering : public ConversionPattern {
public:
  SPIRVOpLowering(MLIRContext *context, SPIRVTypeConverter &typeConverter)
      : ConversionPattern(OpTy::getOperationName(), 1, context),
        typeConverter(typeConverter) {}

protected:
  /// Loads the value of the global variable associated with `builtin`,
  /// adding that variable to the enclosing spv.module if it doesn't exist.
  ///
  /// Creates a spirv::AddressOfOp of the variable followed by a
  /// spirv::LoadOp (both at `op`'s location, at the rewriter's current
  /// insertion point) and returns the loaded value.
  ///
  /// Returns nullptr, after an error has been emitted, if `op` is not nested
  /// in a spirv::ModuleOp or if no variable can be generated for `builtin`.
  Value *loadFromBuiltinVariable(Operation *op, spirv::BuiltIn builtin,
                                 ConversionPatternRewriter &rewriter) const {
    auto moduleOp = op->getParentOfType<spirv::ModuleOp>();
    if (!moduleOp) {
      op->emitError("expected operation to be within a SPIR-V module");
      return nullptr;
    }
    auto varOp =
        getOrInsertBuiltinVariable(moduleOp, op->getLoc(), builtin, rewriter);
    // getOrInsertBuiltinVariable returns a null op (having already emitted a
    // diagnostic) for builtins it cannot generate. Bail out instead of
    // dereferencing the null op below.
    if (!varOp)
      return nullptr;
    auto ptr = rewriter
                   .create<spirv::AddressOfOp>(op->getLoc(), varOp.type(),
                                               rewriter.getSymbolRefAttr(varOp))
                   .pointer();
    return rewriter.create<spirv::LoadOp>(
        op->getLoc(),
        ptr->getType().template cast<spirv::PointerType>().getPointeeType(),
        ptr, /*memory_access =*/nullptr, /*alignment =*/nullptr);
  }

  /// Type lowering class.
  SPIRVTypeConverter &typeConverter;

private:
  /// Looks through all global variables in `moduleOp` and returns the first
  /// spv.globalVariable whose builtin decoration attribute (the snake_case
  /// form of the BuiltIn decoration name, holding a string) names `builtin`.
  /// Returns a null op if there is none.
  spirv::GlobalVariableOp getBuiltinVariable(spirv::ModuleOp &moduleOp,
                                             spirv::BuiltIn builtin) const {
    for (auto varOp : moduleOp.getBlock().getOps<spirv::GlobalVariableOp>()) {
      if (auto builtinAttr = varOp.getAttrOfType<StringAttr>(convertToSnakeCase(
              stringifyDecoration(spirv::Decoration::BuiltIn)))) {
        auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
        if (varBuiltIn && varBuiltIn.getValue() == builtin) {
          return varOp;
        }
      }
    }
    return nullptr;
  }

  /// Gets the name of the global variable for a builtin:
  /// "__builtin_var_<BuiltInName>__".
  std::string getBuiltinVarName(spirv::BuiltIn builtin) const {
    return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() +
           "__";
  }

  /// Gets or inserts a global variable for a builtin within a module.
  ///
  /// New variables are created at the start of the module body. Only
  /// NumWorkgroups, WorkgroupSize, WorkgroupId, LocalInvocationId and
  /// GlobalInvocationId are supported. Each becomes an Input-storage pointer
  /// to a vector of three 32-bit integers, carrying the builtin decoration
  /// attribute. For any other builtin an error is emitted at `loc` and a null
  /// op is returned. The builder's insertion point is restored before
  /// returning in all cases.
  spirv::GlobalVariableOp
  getOrInsertBuiltinVariable(spirv::ModuleOp &moduleOp, Location loc,
                             spirv::BuiltIn builtin,
                             ConversionPatternRewriter &builder) const {
    if (auto varOp = getBuiltinVariable(moduleOp, builtin)) {
      return varOp;
    }
    // Remember where the caller was building so it can continue from there
    // once the variable has been placed at the top of the module.
    auto ip = builder.saveInsertionPoint();
    builder.setInsertionPointToStart(&moduleOp.getBlock());
    auto name = getBuiltinVarName(builtin);
    spirv::GlobalVariableOp newVarOp;
    switch (builtin) {
    case spirv::BuiltIn::NumWorkgroups:
    case spirv::BuiltIn::WorkgroupSize:
    case spirv::BuiltIn::WorkgroupId:
    case spirv::BuiltIn::LocalInvocationId:
    case spirv::BuiltIn::GlobalInvocationId: {
      auto ptrType = spirv::PointerType::get(
          builder.getVectorType({3}, builder.getIntegerType(32)),
          spirv::StorageClass::Input);
      newVarOp = builder.create<spirv::GlobalVariableOp>(
          loc, builder.getTypeAttr(ptrType), builder.getStringAttr(name),
          nullptr);
      newVarOp.setAttr(
          convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)),
          builder.getStringAttr(stringifyBuiltIn(builtin)));
      break;
    }
    default:
      // newVarOp stays null; callers must check for that.
      emitError(loc, "unimplemented builtin variable generation for ")
          << stringifyBuiltIn(builtin);
      break;
    }
    builder.restoreInsertionPoint(ip);
    return newVarOp;
  }
};
/// Legalizes `funcOp` as a non-entry SPIR-V function, using `typeConverter`
/// for type conversion. `newFuncOp` is an out-parameter; presumably it
/// receives the legalized function (TODO confirm against the definition).
LogicalResult lowerFunction(FuncOp funcOp,
                            SPIRVTypeConverter *typeConverter,
                            ConversionPatternRewriter &rewriter,
                            FuncOp &newFuncOp);

/// Legalizes `funcOp` as a SPIR-V entry function. `newFuncOp` is an
/// out-parameter, presumably receiving the legalized function (TODO confirm).
LogicalResult lowerAsEntryFunction(FuncOp funcOp,
                                   SPIRVTypeConverter *typeConverter,
                                   ConversionPatternRewriter &rewriter,
                                   FuncOp &newFuncOp);

/// Finishes legalizing an entry function by inserting the spv.EntryPoint and
/// spv.ExecutionMode ops for `newFuncOp` through `builder`.
LogicalResult finalizeEntryFunction(FuncOp newFuncOp,
                                    OpBuilder &builder);

/// Appends to `patterns` additional patterns that translate StandardOps into
/// SPIR-V ops.
void populateStandardToSPIRVPatterns(MLIRContext *context,
                                     OwningRewritePatternList &patterns);
} // namespace mlir
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_CONVERTSTANDARDTOSPIRV_H