fix: use a cached SymbolTable for symbol lookups in DISC RAL-to-LLVM lowering

Thread a SymbolTable& through DispatchOpToLLVMPattern and
ConvertLaunchFuncOpToRalCallPattern so global-string and dispatch-function
lookups no longer perform linear module-wide symbol scans on every rewrite.
Newly created GlobalOp/LLVMFuncOp symbols are inserted into the table to keep
it consistent, and the obsolete "double-check under the lock" re-lookups are
removed.
diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
index d70c4bb..3403433 100644
--- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
+++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h
@@ -136,10 +136,12 @@
 }  // namespace chlo
 
 class LLVMTypeConverter;
+class SymbolTable;
 
 namespace disc_ral {
 
 void populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter* converter,
+                                             SymbolTable* symbol_table,
                                              RewritePatternSet* patterns);
 
 }  // namespace disc_ral
diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc
index 7ae2f4d..aef0f66 100644
--- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc
+++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ral_legalize_to_llvm.cc
@@ -152,17 +152,13 @@
 
 // Creates a global const string op named `name` using the value if not exists
 // and returns the Loaded value of this global op.
-Value loadOrCreateGlobalString(PatternRewriter& rewriter, Operation* op,
+Value loadOrCreateGlobalString(PatternRewriter& rewriter,
+                               SymbolTable& symbol_table, Operation* op,
                                StringRef name, StringRef value) {
   ModuleOp module = op->getParentOfType<ModuleOp>();
-  GlobalOp globalOp = module.lookupSymbol<GlobalOp>(name);
+  GlobalOp globalOp = symbol_table.lookup<GlobalOp>(name);
   if (!globalOp) {
     OpBuilder::InsertionGuard guard(rewriter);
-    // Double-check under the lock
-    if (globalOp = module.lookupSymbol<GlobalOp>(name)) {
-      assert(checkGlobalOpContent(globalOp, value));
-      return loadGlobalString(rewriter, op->getLoc(), globalOp);
-    }
     OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
     rewriter.setInsertionPointToStart(module.getBody());
 
@@ -172,6 +168,9 @@
         op->getLoc(), type, /*isConstant=*/true, LLVM::Linkage::Internal, name,
         rewriter.getStringAttr(value), /*alignment=*/0);
 
+    // Update the symbol table
+    symbol_table.insert(globalOp);
+
     rewriter.restoreInsertionPoint(ip);
   } else {
     assert(checkGlobalOpContent(globalOp, value));
@@ -182,7 +181,10 @@
 
 // Converts a ral.dispatch_op to its llvm format.
 struct DispatchOpToLLVMPattern : ConvertOpToLLVMPattern<DispatchOp> {
-  using ConvertOpToLLVMPattern<DispatchOp>::ConvertOpToLLVMPattern;
+  DispatchOpToLLVMPattern(LLVMTypeConverter& type_converter,
+                          SymbolTable& symbol_table)
+      : ConvertOpToLLVMPattern<DispatchOp>(type_converter),
+        symbol_table_(symbol_table) {}
 
   // Returns the ral dispatch function and inserts the declaration if not found.
   LLVMFuncOp getOrInsertDispatchFunction(PatternRewriter& rewriter,
@@ -199,6 +201,9 @@
   LogicalResult matchAndRewrite(
       DispatchOp dispatch_op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const override;
+
+ private:
+  SymbolTable& symbol_table_;
 };
 
 // Returns the llvm function definition of ral dispatch op and creates it first
@@ -206,16 +211,12 @@
 LLVMFuncOp DispatchOpToLLVMPattern::getOrInsertDispatchFunction(
     PatternRewriter& rewriter, Operation* op) const {
   ModuleOp module = op->getParentOfType<ModuleOp>();
-  LLVMFuncOp func = module.lookupSymbol<LLVMFuncOp>(kRalDispatchFunctionName);
+  LLVMFuncOp func = symbol_table_.lookup<LLVMFuncOp>(kRalDispatchFunctionName);
 
   if (func) return func;
 
   // Try to insert the function since it's not found.
   OpBuilder::InsertionGuard guard(rewriter);
-  // Double-check under the lock
-  if (func = module.lookupSymbol<LLVMFuncOp>(kRalDispatchFunctionName)) {
-    return func;
-  }
   OpBuilder::InsertPoint ip = rewriter.saveInsertionPoint();
   rewriter.setInsertionPointToStart(module.getBody());
   Type llvm_pointer_type =
@@ -232,6 +233,9 @@
               llvm_pointer_pointer_type /* void** args */
           },
           /*isVarArg=*/false));
+
+  symbol_table_.insert(func);
+
   rewriter.restoreInsertionPoint(ip);
 
   return func;
@@ -333,7 +337,8 @@
   callOpOperands.push_back(adaptor.ctx());
   // the second argument is the target name
   callOpOperands.push_back(loadOrCreateGlobalString(
-      rewriter, op, target_name.str().drop_back(), target_name.str()));
+      rewriter, symbol_table_, op, target_name.str().drop_back(),
+      target_name.str()));
   // the third argument is the args for target function
   callOpOperands.push_back(packedArgs);
 
@@ -356,8 +361,10 @@
 class ConvertLaunchFuncOpToRalCallPattern
     : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
  public:
-  ConvertLaunchFuncOpToRalCallPattern(LLVMTypeConverter& type_converter)
-      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter) {}
+  ConvertLaunchFuncOpToRalCallPattern(LLVMTypeConverter& type_converter,
+                                      SymbolTable& symbol_table)
+      : ConvertOpToLLVMPattern<gpu::LaunchFuncOp>(type_converter),
+        symbol_table_(symbol_table) {}
 
  private:
   Value generateParamsArray(gpu::LaunchFuncOp launch_op,
@@ -368,6 +375,8 @@
   LogicalResult matchAndRewrite(
       gpu::LaunchFuncOp launch_op, ArrayRef<Value> operands,
       ConversionPatternRewriter& rewriter) const override;
+
+  SymbolTable& symbol_table_;
 };
 
 // Creates a struct containing all kernel parameters on the stack and returns
@@ -463,8 +472,8 @@
   StrT name_buffer(kernel_module.getName());
   name_buffer.append("_blob");
 
-  Value module_blob = loadOrCreateGlobalString(rewriter, op, name_buffer.str(),
-                                               binary_attr.getValue());
+  Value module_blob = loadOrCreateGlobalString(
+      rewriter, symbol_table_, op, name_buffer.str(), binary_attr.getValue());
 
   // Make sure the trailing zero is included in the constant.
   auto kernel_name = launch_op.getKernelName();
@@ -477,7 +486,8 @@
       (kernel_module.getName() + "_" + kernel_name + "_kernel_name")
           .toStringRef(kernel_name_global_name_buffer);
   Value kernel_name_global = loadOrCreateGlobalString(
-      rewriter, op, kernel_name_global_name, kernel_name_buffer.str());
+      rewriter, symbol_table_, op, kernel_name_global_name,
+      kernel_name_buffer.str());
 
   auto adaptor =
       gpu::LaunchFuncOpAdaptor(operands, launch_op->getAttrDictionary());
@@ -517,6 +527,7 @@
  public:
   void runOnOperation() override {
     ModuleOp m = getOperation();
+    SymbolTable symbol_table(m);
 
     // Populate type conversions.
     MLIRContext* ctx = m.getContext();
@@ -529,7 +540,8 @@
     RewritePatternSet patterns(&getContext());
     populateStdExpandOpsPatterns(patterns);
     populateStdToLLVMConversionPatterns(type_converter, patterns);
-    populateDiscRalToLLVMConversionPatterns(&type_converter, &patterns);
+    populateDiscRalToLLVMConversionPatterns(&type_converter, &symbol_table,
+                                            &patterns);
 
     // Set target.
     ConversionTarget target(*ctx);
@@ -556,12 +568,13 @@
 }  // namespace
 
 void populateDiscRalToLLVMConversionPatterns(LLVMTypeConverter* converter,
+                                             SymbolTable* symbol_table,
                                              RewritePatternSet* patterns) {
   // clang-format off
   patterns->insert<
       ConvertLaunchFuncOpToRalCallPattern,
       DispatchOpToLLVMPattern
-    >(*converter);
+    >(*converter, *symbol_table);
   // clang-format on
 }