Fix repeated symbol lookups in the disc_ral-to-LLVM lowering by threading a SymbolTable through the conversion patterns.

populateDiscRalToLLVMConversionPatterns now takes a SymbolTable built once per module. DispatchOpToLLVMPattern and ConvertLaunchFuncOpToRalCallPattern use it to look up global string constants and the ral dispatch function declaration instead of calling module.lookupSymbol() on every rewrite, register newly created symbols via symbol_table.insert(), and drop the redundant "Double-check under the lock" re-lookups.
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index d70c4bb..3403433 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -136,10 +136,12 @@
} // namespace chlo
class LLVMTypeConverter;
+class SymbolTable;
namespace disc_ral {
void populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter* converter,
+ SymbolTable* symbol_table,
RewritePatternSet* patterns);
} // namespace disc_ral
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc
index 7ae2f4d..aef0f66 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc
@@ -152,17 +152,13 @@
// Creates a global const string op named `name` using the value if not exists
// and returns the Loaded value of this global op.
-Value loadOrCreateGlobalString(PatternRewriter& rewriter, Operation* op,
+Value loadOrCreateGlobalString(PatternRewriter& rewriter,
+ SymbolTable& symbol_table, Operation* op,
StringRef name, StringRef value) {
ModuleOp module = op->getParentOfType<ModuleOp>();
- GlobalOp globalOp = module.lookupSymbol<GlobalOp>(name);
+ GlobalOp globalOp = symbol_table.lookup<GlobalOp>(name);
if (!globalOp) {
OpBuilder::InsertionGuard guard(rewriter);
- // Double-check under the lock
- if (globalOp = module.lookupSymbol<GlobalOp>(name)) {
- assert(checkGlobalOpContent(globalOp, value));
- return loadGlobalString(rewriter, op->getLoc(), globalOp);
- }
OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(module.getBody());
@@ -172,6 +168,9 @@
op->getLoc(), type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
rewriter.getStringAttr(value), /*alignment=*/0);
+ // Update the symbol table
+ symbol_table.insert(globalOp);
+
rewriter.restoreInsertionPoint(ip);
} else {
assert(checkGlobalOpContent(globalOp, value));
@@ -182,7 +181,10 @@
// Converts a ral.dispatch_op to its llvm format.
struct DispatchOpToLLVMPattern : ConvertOpToLLVMPattern<DispatchOp> {
- using ConvertOpToLLVMPattern<DispatchOp>::ConvertOpToLLVMPattern;
+ DispatchOpToLLVMPattern(LLVMTypeConverter& type_converter,
+ SymbolTable& symbol_table)
+ : ConvertOpToLLVMPattern<DispatchOp>(type_converter),
+ symbol_table_(symbol_table) {}
// Returns the ral dispatch function and inserts the declaration if not found.
LLVMFuncOp getOrInsertDispatchFunction(PatternRewriter& rewriter,
@@ -199,6 +201,9 @@
LogicalResult matchAndRewrite(
DispatchOp dispatch_op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override;
+
+ private:
+ SymbolTable& symbol_table_;
};
// Returns the llvm function definition of ral dispatch op and creates it first
@@ -206,16 +211,12 @@
LLVMFuncOp DispatchOpToLLVMPattern::getOrInsertDispatchFunction(
PatternRewriter& rewriter, Operation* op) const {
ModuleOp module = op->getParentOfType<ModuleOp>();
- LLVMFuncOp func = module.lookupSymbol<LLVMFuncOp>(kRalDispatchFunctionName);
+ LLVMFuncOp func = symbol_table_.lookup<LLVMFuncOp>(kRalDispatchFunctionName);
if (func) return func;
// Try to insert the function since it's not found.
OpBuilder::InsertionGuard guard(rewriter);
- // Double-check under the lock
- if (func = module.lookupSymbol<LLVMFuncOp>(kRalDispatchFunctionName)) {
- return func;
- }
OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
rewriter.setInsertionPointToStart(module.getBody());
Type llvm_pointer_type =
@@ -232,6 +233,9 @@
llvm_pointer_pointer_type /* void** args */
},
/*isVarArg=*/false));
+
+ symbol_table_.insert(func);
+
rewriter.restoreInsertionPoint(ip);
return func;
@@ -333,7 +337,8 @@
callOpOperands.push_back(adaptor.ctx());
// the second argument is the target name
callOpOperands.push_back(loadOrCreateGlobalString(
- rewriter, op, target_name.str().drop_back(), target_name.str()));
+ rewriter, symbol_table_, op, target_name.str().drop_back(),
+ target_name.str()));
// the third argument is the args for target function
callOpOperands.push_back(packedArgs);
@@ -356,8 +361,10 @@
class ConvertLaunchFuncOpToRalCallPattern
: public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
public:
- ConvertLaunchFuncOpToRalCallPattern(LLVMTypeConverter& type_converter)
- : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter) {}
+ ConvertLaunchFuncOpToRalCallPattern(LLVMTypeConverter& type_converter,
+ SymbolTable& symbol_table)
+ : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
+ symbol_table_(symbol_table) {}
private:
Value generateParamsArray(gpu::LaunchFuncOp launch_op,
@@ -368,6 +375,8 @@
LogicalResult matchAndRewrite(
gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const override;
+
+ SymbolTable& symbol_table_;
};
// Creates a struct containing all kernel parameters on the stack and returns
@@ -463,8 +472,8 @@
StrT name_buffer(kernel_module.getName());
name_buffer.append("_blob");
- Value module_blob = loadOrCreateGlobalString(rewriter, op, name_buffer.str(),
- binary_attr.getValue());
+ Value module_blob = loadOrCreateGlobalString(
+ rewriter, symbol_table_, op, name_buffer.str(), binary_attr.getValue());
// Make sure the trailing zero is included in the constant.
auto kernel_name = launch_op.getKernelName();
@@ -477,7 +486,8 @@
(kernel_module.getName() + "_" + kernel_name + "_kernel_name")
.toStringRef(kernel_name_global_name_buffer);
Value kernel_name_global = loadOrCreateGlobalString(
- rewriter, op, kernel_name_global_name, kernel_name_buffer.str());
+ rewriter, symbol_table_, op, kernel_name_global_name,
+ kernel_name_buffer.str());
auto adaptor =
gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());
@@ -517,6 +527,7 @@
public:
void runOnOperation() override {
ModuleOp m = getOperation();
+ SymbolTable symbol_table(m);
// Populate type conversions.
MLIRContext* ctx = m.getContext();
@@ -529,7 +540,8 @@
RewritePatternSet patterns(&getContext());
populateStdExpandOpsPatterns(patterns);
populateStdToLLVMConversionPatterns(type_converter, patterns);
- populateDiscRalToLLVMConversionPatterns(&type_converter, &patterns);
+ populateDiscRalToLLVMConversionPatterns(&type_converter, &symbol_table,
+ &patterns);
// Set target.
ConversionTarget target(*ctx);
@@ -556,12 +568,13 @@
} // namespace
void populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter* converter,
+ SymbolTable* symbol_table,
RewritePatternSet* patterns) {
// clang-format off
patterns->insert<
ConvertLaunchFuncOpToRalCallPattern,
DispatchOpToLLVMPattern
- >(*converter);
+ >(*converter, *symbol_table);
// clang-format on
}