blob: 88c3bc2e6b72b9505f85f9773059e5254207159b [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h"
namespace mlir {
namespace kernel_gen {
namespace tf_framework {
namespace {
using LLVM::LLVMFuncOp;
using LLVM::LLVMType;
static constexpr StringRef kCInterfaceAlloc = "_mlir_ciface_tf_alloc";
static constexpr StringRef kCInterfaceDealloc = "_mlir_ciface_tf_dealloc";
static constexpr StringRef kCInterfaceReportError =
"_mlir_ciface_tf_report_error";
/// Base class for patterns converting TF Framework ops to function calls.
template <typename OpTy>
class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
using ConvertOpToLLVMPattern<OpTy>::ConvertOpToLLVMPattern;
// Attempts to find function symbol in the module, adds it if not found.
FlatSymbolRefAttr getOrInsertTFFunction(PatternRewriter &rewriter,
Operation *op) const {
ModuleOp module = op->getParentOfType<ModuleOp>();
StringRef tf_func_name = GetFuncName();
auto tf_func = module.lookupSymbol<LLVMFuncOp>(tf_func_name);
if (!tf_func) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
auto func_type = GetFuncType();
tf_func = rewriter.create<LLVMFuncOp>(rewriter.getUnknownLoc(),
tf_func_name, func_type);
}
return SymbolRefAttr::get(tf_func_name, rewriter.getContext());
}
protected:
virtual StringRef GetFuncName() const = 0;
virtual LLVMType GetFuncType() const = 0;
};
class TFAllocOpConverter : public ConvertToLLVMCallOpPattern<TFAllocOp> {
public:
using ConvertToLLVMCallOpPattern<TFAllocOp>::ConvertToLLVMCallOpPattern;
LogicalResult matchAndRewrite(
TFAllocOp tf_alloc_op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
mlir::Operation *op = tf_alloc_op.getOperation();
Location loc = op->getLoc();
TFAllocOp::Adaptor transformed(operands);
MemRefType memref_type = tf_alloc_op.getType();
// Get memref descriptor sizes.
SmallVector<Value, 4> sizes;
SmallVector<Value, 4> strides;
Value sizeBytes;
getMemRefDescriptorSizes(loc, memref_type,
llvm::to_vector<4>(transformed.dyn_sizes()),
rewriter, sizes, strides, sizeBytes);
// Get number of elements.
Value num_elements = getNumElements(loc, sizes, rewriter);
// Get element size.
Value element_size =
getSizeInBytes(loc, memref_type.getElementType(), rewriter);
// Convert `output_index` or set it to -1 if the attribute is missing.
LLVM::LLVMType llvmInt32Type =
LLVM::LLVMType::getInt32Ty(rewriter.getContext());
Value output_index = rewriter.create<LLVM::ConstantOp>(
loc, llvmInt32Type,
rewriter.getI32IntegerAttr(tf_alloc_op.output_index().hasValue()
? tf_alloc_op.output_index().getValue()
: -1));
// Convert `candidate_input_indices`.
auto candidates_count_and_ptr = ConvertI32ArrayAttrToStackAllocatedArray(
loc, tf_alloc_op.input_indices(), &rewriter);
// Insert function call.
FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
Value allocated_byte_ptr =
rewriter
.create<LLVM::CallOp>(
loc, getVoidPtrType(), tf_func_ref,
llvm::makeArrayRef({transformed.ctx(), num_elements,
element_size, output_index,
candidates_count_and_ptr.first,
candidates_count_and_ptr.second}))
.getResult(0);
MemRefDescriptor memRefDescriptor = CreateMemRefDescriptor(
loc, rewriter, memref_type, allocated_byte_ptr, sizes);
// Return the final value of the descriptor.
rewriter.replaceOp(op, {memRefDescriptor});
return success();
}
protected:
StringRef GetFuncName() const override { return kCInterfaceAlloc; }
LLVMType GetFuncType() const override {
LLVMType llvm_i32_type =
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
LLVMType llvm_void_ptr_type = getVoidPtrType();
return LLVMType::getFunctionTy(
llvm_void_ptr_type,
llvm::makeArrayRef(
{/*void* op_kernel_ctx*/ llvm_void_ptr_type,
/*size_t num_elements*/ getIndexType(),
/*size_t element_size*/ getIndexType(),
/*int32_t output_index*/ llvm_i32_type,
/*int32_t num_candidates*/ llvm_i32_type,
/*int32_t* candidate_input_indices*/ llvm_i32_ptr_type}),
/*isVarArg=*/false);
}
private:
MemRefDescriptor CreateMemRefDescriptor(Location loc,
ConversionPatternRewriter &rewriter,
MemRefType memref_type,
Value allocated_byte_ptr,
ArrayRef<Value> sizes) const {
auto memref_desc = MemRefDescriptor::undef(
rewriter, loc, typeConverter->convertType(memref_type));
// TF AllocateRaw returns aligned pointer => AllocatedPtr == AlignedPtr.
Value allocated_type_ptr = rewriter.create<LLVM::BitcastOp>(
loc, getElementPtrType(memref_type), allocated_byte_ptr);
memref_desc.setAllocatedPtr(rewriter, loc, allocated_type_ptr);
memref_desc.setAlignedPtr(rewriter, loc, allocated_type_ptr);
memref_desc.setConstantOffset(rewriter, loc, 0);
if (memref_type.getRank() == 0) {
return memref_desc;
}
// Compute strides and populate descriptor `size` and `stride` fields.
Value stride_carried = createIndexConstant(rewriter, loc, 1);
for (int pos = sizes.size() - 1; pos >= 0; --pos) {
Value size = sizes[pos];
memref_desc.setSize(rewriter, loc, pos, size);
memref_desc.setStride(rewriter, loc, pos, stride_carried);
// Update stride
if (pos > 0) {
stride_carried =
rewriter.create<LLVM::MulOp>(loc, stride_carried, size);
}
}
return memref_desc;
}
std::pair<Value, Value> ConvertI32ArrayAttrToStackAllocatedArray(
Location loc, llvm::Optional<ArrayAttr> attr,
ConversionPatternRewriter *rewriter) const {
LLVMType llvm_i32_type =
LLVM::LLVMType::getInt32Ty(getDialect().getContext());
LLVMType llvm_i32_ptr_type = llvm_i32_type.getPointerTo();
// If the attribute is missing or empty, set the element count to 0 and
// return NULL.
if (!attr.hasValue() || attr.getValue().empty()) {
Value zero = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(0));
Value null_ptr = rewriter->create<LLVM::NullOp>(loc, llvm_i32_ptr_type);
return std::make_pair(zero, null_ptr);
}
// Allocate array to store the elements.
auto &array_attr = attr.getValue();
Value array_size = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(array_attr.size()));
Value array_ptr = rewriter->create<LLVM::AllocaOp>(
loc, llvm_i32_ptr_type, array_size, /*alignment=*/0);
for (auto &dim : llvm::enumerate(array_attr)) {
Value index = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type, rewriter->getI32IntegerAttr(dim.index()));
Value elem_ptr = rewriter->create<LLVM::GEPOp>(loc, llvm_i32_ptr_type,
array_ptr, index);
Value elem = rewriter->create<LLVM::ConstantOp>(
loc, llvm_i32_type,
rewriter->getI32IntegerAttr(
dim.value().cast<IntegerAttr>().getInt()));
rewriter->create<LLVM::StoreOp>(loc, elem, elem_ptr);
}
return std::make_pair(array_size, array_ptr);
}
};
class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern<TFDeallocOp> {
public:
using ConvertToLLVMCallOpPattern<TFDeallocOp>::ConvertToLLVMCallOpPattern;
LogicalResult matchAndRewrite(
TFDeallocOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
TFDeallocOp::Adaptor transformed(operands);
MemRefDescriptor memref(transformed.memref());
Value allocated_bytes_ptr = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), getVoidPtrType(),
memref.allocatedPtr(rewriter, op.getLoc()));
// Insert function call.
FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, llvm::None, tf_func_ref,
llvm::makeArrayRef({transformed.ctx(), allocated_bytes_ptr}));
return success();
}
protected:
StringRef GetFuncName() const override { return kCInterfaceDealloc; }
LLVMType GetFuncType() const override {
return LLVM::LLVMType::getFunctionTy(getVoidType(),
{getVoidPtrType(), getVoidPtrType()},
/*isVarArg=*/false);
}
};
class ReportErrorOpConverter
: public ConvertToLLVMCallOpPattern<ReportErrorOp> {
public:
using ConvertToLLVMCallOpPattern<ReportErrorOp>::ConvertToLLVMCallOpPattern;
LogicalResult matchAndRewrite(
ReportErrorOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
ReportErrorOp::Adaptor transformed(operands,
op.getOperation()->getAttrDictionary());
Location loc = op.getLoc();
auto module = op.getParentOfType<ModuleOp>();
Value message_constant = GenerateErrorMessageConstant(
loc, module, transformed.msg().getValue(), rewriter);
// Insert function call.
FlatSymbolRefAttr tf_func_ref = getOrInsertTFFunction(rewriter, op);
Value error_code = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter->convertType(rewriter.getI32Type()),
transformed.error_code());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, llvm::None, tf_func_ref,
llvm::makeArrayRef({transformed.ctx(), error_code, message_constant}));
return success();
}
protected:
StringRef GetFuncName() const override { return kCInterfaceReportError; }
LLVMType GetFuncType() const override {
MLIRContext *ctx = &getTypeConverter()->getContext();
auto i8_ptr_type = LLVM::LLVMType::getInt8Ty(ctx).getPointerTo();
auto i32_type = LLVM::LLVMType::getInt32Ty(ctx);
return LLVM::LLVMType::getFunctionTy(
getVoidType(), {getVoidPtrType(), i32_type, i8_ptr_type},
/*isVarArg=*/false);
}
private:
// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
Value GenerateErrorMessageConstant(Location loc, Operation *module,
StringRef message,
OpBuilder &builder) const {
std::string loc_str;
llvm::raw_string_ostream loc_stream(loc_str);
loc_stream << message << " at ";
loc.print(loc_stream);
StringRef generated_error(loc_stream.str().c_str());
std::string global_name =
llvm::formatv("error_message_{0}", llvm::hash_value(generated_error));
Operation *global_constant =
SymbolTable::lookupNearestSymbolFrom(module, global_name);
if (global_constant) {
Value globalPtr = builder.create<LLVM::AddressOfOp>(
loc, cast<LLVM::GlobalOp>(global_constant));
MLIRContext *ctx = &getTypeConverter()->getContext();
Value c0 = builder.create<LLVM::ConstantOp>(
loc, LLVM::LLVMType::getInt64Ty(ctx),
builder.getIntegerAttr(builder.getIndexType(), 0));
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMType::getInt8PtrTy(ctx),
globalPtr, ValueRange{c0, c0});
}
return LLVM::createGlobalString(loc, builder, global_name, generated_error,
LLVM::Linkage::Internal);
}
};
class NullContextOpConverter : public ConvertOpToLLVMPattern<NullContextOp> {
public:
using ConvertOpToLLVMPattern<NullContextOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
NullContextOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::NullOp>(op, getVoidPtrType());
return success();
}
};
class NullMemRefOpConverter : public ConvertOpToLLVMPattern<NullMemRefOp> {
public:
using ConvertOpToLLVMPattern<NullMemRefOp>::ConvertOpToLLVMPattern;
LogicalResult matchAndRewrite(
NullMemRefOp null_memref_op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
mlir::Operation *op = null_memref_op.getOperation();
Location loc = op->getLoc();
auto result_type = null_memref_op.getType().cast<UnrankedMemRefType>();
LLVMType llvm_result_type =
typeConverter->convertType(result_type).cast<LLVMType>();
auto desc =
UnrankedMemRefDescriptor::undef(rewriter, loc, llvm_result_type);
Value zero = createIndexConstant(rewriter, loc, 0);
desc.setRank(rewriter, loc, zero);
// Due to the current way of handling unranked memref results escaping, we
// have to actually construct a ranked underlying descriptor instead of just
// setting its pointer to NULL.
SmallVector<Value, 4> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
desc, sizes);
Value underlying_desc_ptr = rewriter.create<LLVM::AllocaOp>(
loc, getVoidPtrType(), sizes.front(), llvm::None);
// Populate underlying ranked descriptor.
unsigned address_space = result_type.getMemorySpace();
Type elem_type = result_type.getElementType();
LLVM::LLVMType llvm_elem_type =
typeConverter->convertType(elem_type).cast<LLVMType>();
LLVM::LLVMType elem_ptr_ptr_type =
llvm_elem_type.getPointerTo(address_space).getPointerTo();
auto nullPtr = rewriter.create<LLVM::NullOp>(
loc, llvm_elem_type.getPointerTo(address_space));
UnrankedMemRefDescriptor::setAllocatedPtr(
rewriter, loc, underlying_desc_ptr, elem_ptr_ptr_type, nullPtr);
UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(),
underlying_desc_ptr,
elem_ptr_ptr_type, nullPtr);
UnrankedMemRefDescriptor::setOffset(rewriter, loc, *getTypeConverter(),
underlying_desc_ptr, elem_ptr_ptr_type,
zero);
desc.setMemRefDescPtr(rewriter, loc, underlying_desc_ptr);
rewriter.replaceOp(op, {desc});
return success();
}
};
} // namespace
void PopulateTFFrameworkToLLVMConversionPatterns(
LLVMTypeConverter *converter, OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
NullContextOpConverter,
NullMemRefOpConverter,
ReportErrorOpConverter,
TFAllocOpConverter,
TFDeallocOpConverter
>(*converter);
// clang-format on
}
} // namespace tf_framework
} // namespace kernel_gen
} // namespace mlir