[XLA:CPU] Plumb through a minimal emitter for matmuls using the mlir linalg dialect

This is just the most basic lowering and will generate linalg.matmul for small
matmuls and then convert to loops. The result is fairly slow, but we can
iterate on that.

To make XLA use it set XLA_FLAGS=--xla_backend_extra_options=xla_use_linalg_for_dot

PiperOrigin-RevId: 312443993
Change-Id: Icaf20764c954803f5b1bacfaa9456839b28ba52c
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 2f432cd..3460e65 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -118,6 +118,9 @@
         ":target_machine_features",
         "@com_google_absl//absl/base",
         "@com_google_absl//absl/types:span",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
+        "@llvm-project//mlir:ExecutionEngineUtils",
+        "@llvm-project//mlir:LLVMDialect",
         "//tensorflow/compiler/xla/service:copy_insertion",
         "//tensorflow/compiler/xla/service:hlo_casting_utils",
         "//tensorflow/compiler/xla/service:dump",
@@ -366,6 +369,7 @@
         "@llvm-project//llvm:core",
         "@llvm-project//llvm:support",
         "@llvm-project//llvm:target",
+        "@llvm-project//mlir:IR",
     ],
 )
 
@@ -456,6 +460,7 @@
         ":cpu_options",
         ":cpu_runtime",
         ":ir_emission_utils",
+        ":mlir_emitter",
         ":target_machine_features",
         ":tiled_dot_emitter",
         ":vector_support_library",
@@ -474,6 +479,10 @@
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:core",
+        "@llvm-project//mlir:EDSC",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgOps",
+        "@llvm-project//mlir:StandardOps",
     ],
 )
 
@@ -1070,3 +1079,24 @@
         "@llvm-project//llvm:target",
     ],
 )
+
+cc_library(
+    name = "mlir_emitter",
+    srcs = ["mlir_emitter.cc"],
+    hdrs = ["mlir_emitter.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/xla:hlo_utils",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status",
+        "@llvm-project//llvm:core",
+        "@llvm-project//llvm:ipo",
+        "@llvm-project//llvm:linker",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LLVMTransforms",
+        "@llvm-project//mlir:LinalgToLLVM",
+        "@llvm-project//mlir:LinalgTransforms",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:TargetLLVMIR",
+        "@llvm-project//mlir:VectorToLLVM",
+    ],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index fe769bb..b2416ac 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -42,6 +42,8 @@
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
+#include "mlir/InitAllDialects.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/map_util.h"
@@ -158,6 +160,8 @@
   // Initialize LLVM's MC layer for the native target.
   llvm::InitializeNativeTarget();
   llvm::InitializeNativeTargetAsmPrinter();
+
+  mlir::registerAllDialects();
 }
 
 namespace {
@@ -606,9 +610,11 @@
                        user_post_optimization_hook_);
 
   // Compile must be thread-safe so create a new LLVM context for the module.
-  auto llvm_context = absl::make_unique<llvm::LLVMContext>();
-  auto llvm_module =
-      absl::make_unique<llvm::Module>("__compute_module", *llvm_context);
+  mlir::MLIRContext mlir_context;
+  auto llvm_module = absl::make_unique<llvm::Module>(
+      "__compute_module",
+      mlir_context.getRegisteredDialect<mlir::LLVM::LLVMDialect>()
+          ->getLLVMContext());
 
   auto jit = absl::make_unique<SimpleOrcJIT>(
       CompilerTargetOptions(module->config()),
@@ -662,7 +668,7 @@
   // before a caller computation.
 
   LLVMTargetMachineFeatures target_machine_features(jit->target_machine());
-  IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
+  IrEmitter ir_emitter(&mlir_context, *module, *assignment, llvm_module.get(),
                        std::move(instruction_to_profile_idx),
                        std::move(computation_to_profile_idx),
                        &target_machine_features,
@@ -816,8 +822,11 @@
           opt_level));
 
   // Compile must be thread-safe so create a new LLVM context for the module.
-  llvm::LLVMContext llvm_context;
-  llvm::Module llvm_module("__compute_module", llvm_context);
+  mlir::MLIRContext mlir_context;
+  llvm::Module llvm_module(
+      "__compute_module",
+      mlir_context.getRegisteredDialect<mlir::LLVM::LLVMDialect>()
+          ->getLLVMContext());
   llvm_module.setDataLayout(target_machine->createDataLayout());
   llvm_module.setTargetTriple(triple.getTriple());
   if (pic_level != llvm::PICLevel::NotPIC) {
@@ -866,7 +875,7 @@
     }
 
     LLVMTargetMachineFeatures target_machine_features(target_machine.get());
-    IrEmitter ir_emitter(*module, *assignment, &llvm_module,
+    IrEmitter ir_emitter(&mlir_context, *module, *assignment, &llvm_module,
                          std::move(instruction_to_profile_idx),
                          std::move(computation_to_profile_idx),
                          &target_machine_features,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index ff654c8..c022201 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -25,6 +25,7 @@
 const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
 const char* const kXlaForceEnableExperimentalLlvmIrGemm =
     "xla_force_enable_experimental_llvm_ir_gemm";
+const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot";
 const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size";
 
 }  // namespace
@@ -63,6 +64,12 @@
   return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0;
 }
 
+bool UseLinalgForDot(const HloModuleConfig& config) {
+  const auto& extra_options_map =
+      config.debug_options().xla_backend_extra_options();
+  return extra_options_map.count(kXlaUseLinalgForDot) > 0;
+}
+
 static absl::string_view RemoveSuffix(absl::string_view str,
                                       absl::string_view suffix) {
   CHECK_GE(str.size(), suffix.size());
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 99e6702..5d25aef 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -27,6 +27,7 @@
 bool OptimizeForSizeRequested(const HloModuleConfig& config);
 bool VectorizedReduceDisabled(const HloModuleConfig& config);
 bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config);
+bool UseLinalgForDot(const HloModuleConfig& config);
 absl::optional<int64> LlvmIrGemvTilingFactor(const HloModuleConfig& config);
 absl::optional<std::tuple<int64, int64, int64>> LlvmIrGemmTileSize(
     const HloModuleConfig& config);
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 7dba826..e1ad146 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -23,8 +23,17 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Value.h"
+#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"  // from @llvm-project
+#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"  // from @llvm-project
+#include "mlir/EDSC/Builders.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
 #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h"
 #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
@@ -89,6 +98,9 @@
   // and the output have to be row major.
   kTiledLlvmIrGemm,
 
+  // The dot operation is lowered into a linalg.matmul op, which is then
+  kLinalgMatmul,
+
   // The dot operation is lowered into a call into an Eigen routine.  No fusions
   // are supported today.  The two inputs and the output have to be row major.
   // However, we do allow transposing either the LHS or the RHS as part of the
@@ -112,7 +124,7 @@
                         const llvm_ir::IrArray& rhs_array,
                         const llvm_ir::IrArray* addend_array,
                         llvm::Value* executable_run_options_value,
-                        llvm::IRBuilder<>* b,
+                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                         const HloModuleConfig& hlo_module_config,
                         const TargetMachineFeatures& target_machine_features);
 
@@ -163,6 +175,9 @@
   // Lowers the dot operation as a tiled Matrix*Matrix loop.
   void EmitTiledLlvmIrGemm();
 
+  // Lowers the dot operation through MLIR's linalg.matmul.
+  Status EmitLinalgMatmul();
+
   // Lowers the dot operation as a naive nested loop that computes the result
   // one element at a time.
   void EmitNaiveLlvmIrGemm();
@@ -194,20 +209,19 @@
   const llvm_ir::IrArray* addend_array_;
   llvm::Value* executable_run_options_value_;
   llvm::IRBuilder<>* b_;
+  mlir::MLIRContext* mlir_context_;
   const HloModuleConfig& hlo_module_config_;
   const TargetMachineFeatures& target_machine_features_;
 };
 }  // namespace
 
-DotOpEmitter::DotOpEmitter(DotInfo dot_info, string dot_hlo_name,
-                           const llvm_ir::IrArray& target_array,
-                           const llvm_ir::IrArray& lhs_array,
-                           const llvm_ir::IrArray& rhs_array,
-                           const llvm_ir::IrArray* addend_array,
-                           llvm::Value* executable_run_options_value,
-                           llvm::IRBuilder<>* b,
-                           const HloModuleConfig& hlo_module_config,
-                           const TargetMachineFeatures& target_machine_features)
+DotOpEmitter::DotOpEmitter(
+    DotInfo dot_info, string dot_hlo_name, const llvm_ir::IrArray& target_array,
+    const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
+    const llvm_ir::IrArray* addend_array,
+    llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
+    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
+    const TargetMachineFeatures& target_machine_features)
     : dot_info_(std::move(dot_info)),
       dot_hlo_name_(std::move(dot_hlo_name)),
       target_array_(target_array),
@@ -216,9 +230,36 @@
       addend_array_(addend_array),
       executable_run_options_value_(executable_run_options_value),
       b_(b),
+      mlir_context_(mlir_context),
       hlo_module_config_(hlo_module_config),
       target_machine_features_(target_machine_features) {}
 
+Status DotOpEmitter::EmitLinalgMatmul() {
+  Shape operand_shapes[] = {dot_info_.lhs_shape, dot_info_.rhs_shape};
+  llvm::Value* operand_ptrs[] = {lhs_array_.GetBasePointer(),
+                                 rhs_array_.GetBasePointer()};
+  llvm::Value* target_ptr = target_array_.GetBasePointer();
+
+  // Zero out the output buffer.
+  int64 size_bytes = ShapeUtil::ByteSizeOf(dot_info_.result_shape);
+  b_->CreateMemSet(target_ptr, b_->getInt8(0), /*Size=*/size_bytes,
+                   /*Align=*/llvm::MaybeAlign(1));
+
+  std::string name =
+      absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_",
+                   dot_info_.lhs_shape.ToString(true), "_",
+                   dot_info_.rhs_shape.ToString(true));
+  return EmitMlirFuncAndCall(
+      mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr,
+      operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) {
+        mlir::edsc::ScopedContext scope(*builder, function.getLoc());
+        mlir::Value a = function.getArgument(0), b = function.getArgument(1),
+                    c = function.getArgument(2);
+        mlir::edsc::intrinsics::linalg_matmul(b, c, a);
+        mlir::edsc::intrinsics::std_ret();
+      });
+}
+
 void DotOpEmitter::EmitTiledLlvmIrGemm() {
   PrimitiveType primitive_type = dot_info_.result_shape.element_type();
   MatMultDims mat_mult_dims = GetMatMultDims();
@@ -418,6 +459,9 @@
       EmitTiledLlvmIrGemm();
       return Status::OK();
 
+    case DotImplementationStrategy::kLinalgMatmul:
+      return EmitLinalgMatmul();
+
     case DotImplementationStrategy::kEigen:
       return EmitCallToRuntime();
   }
@@ -886,9 +930,12 @@
   }
 
   if (IsAlignedGemm(dot_info, target_machine_features)) {
-    return CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)
-               ? DotImplementationStrategy::kTiledLlvmIrGemm
-               : DotImplementationStrategy::kEigen;
+    if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) {
+      return options::UseLinalgForDot(config)
+                 ? DotImplementationStrategy::kLinalgMatmul
+                 : DotImplementationStrategy::kTiledLlvmIrGemm;
+    }
+    return DotImplementationStrategy::kEigen;
   }
 
   return DotImplementationStrategy::kNaiveLlvmIr;
@@ -899,15 +946,15 @@
     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
     const llvm_ir::IrArray* addend_array,
     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
-    const HloModuleConfig& hlo_module_config,
+    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
     const TargetMachineFeatures& target_machine_features) {
   PrimitiveType type = target_array.GetShape().element_type();
   TF_RET_CHECK(S32 == type || F16 == type || F32 == type || F64 == type ||
                C64 == type || C128 == type);
   DotOpEmitter dot_emitter(std::move(dot_info), std::move(hlo_name),
                            target_array, lhs_array, rhs_array, addend_array,
-                           executable_run_options_value, b, hlo_module_config,
-                           target_machine_features);
+                           executable_run_options_value, b, mlir_context,
+                           hlo_module_config, target_machine_features);
   return dot_emitter.Emit();
 }
 
@@ -981,7 +1028,7 @@
     const HloInstruction& dot, const llvm_ir::IrArray& target_array,
     const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array,
     llvm::Value* executable_run_options_value, llvm::IRBuilder<>* b,
-    const HloModuleConfig& hlo_module_config,
+    mlir::MLIRContext* mlir_context, const HloModuleConfig& hlo_module_config,
     const TargetMachineFeatures& target_machine_features) {
   TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(dot.dot_dimension_numbers()));
 
@@ -1039,7 +1086,7 @@
         // Emit the inner non-batch dot operation.
         return EmitNonBatchDotOperation(
             dot_info, dot.name(), target_slice, lhs_slice, rhs_slice, nullptr,
-            executable_run_options_value, b, hlo_module_config,
+            executable_run_options_value, b, mlir_context, hlo_module_config,
             target_machine_features);
       });
 }
@@ -1089,7 +1136,7 @@
                         const llvm_ir::IrArray& rhs_array,
                         const llvm_ir::IrArray* addend_array,
                         llvm::Value* executable_run_options_value,
-                        llvm::IRBuilder<>* b,
+                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                         const HloModuleConfig& hlo_module_config,
                         const TargetMachineFeatures& target_machine_features) {
   // This routine assumes that the dot operation is not in a parallelized
@@ -1099,13 +1146,13 @@
   if (IsBatchDot(dot)) {
     TF_RET_CHECK(addend_array == nullptr);
     return EmitBatchDotOperation(dot, target_array, lhs_array, rhs_array,
-                                 executable_run_options_value, b,
+                                 executable_run_options_value, b, mlir_context,
                                  hlo_module_config, target_machine_features);
   }
 
   return EmitNonBatchDotOperation(DotInfo(dot), dot.name(), target_array,
                                   lhs_array, rhs_array, addend_array,
-                                  executable_run_options_value, b,
+                                  executable_run_options_value, b, mlir_context,
                                   hlo_module_config, target_machine_features);
 }
 }  // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index 105bd30..d9cf8a2 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -18,6 +18,7 @@
 
 #include "absl/strings/string_view.h"
 #include "llvm/IR/IRBuilder.h"
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -63,7 +64,7 @@
                         const llvm_ir::IrArray& rhs_array,
                         const llvm_ir::IrArray* addend_array,
                         llvm::Value* executable_run_options_value,
-                        llvm::IRBuilder<>* b,
+                        llvm::IRBuilder<>* b, mlir::MLIRContext* mlir_context,
                         const HloModuleConfig& hlo_module_config,
                         const TargetMachineFeatures& target_machine_features);
 }  // namespace cpu
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 70dde91..043ad68 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -89,8 +89,8 @@
 namespace cpu {
 
 IrEmitter::IrEmitter(
-    const HloModule& hlo_module, const BufferAssignment& assignment,
-    llvm::Module* llvm_module,
+    mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
+    const BufferAssignment& assignment, llvm::Module* llvm_module,
     std::unordered_map<const HloInstruction*, int64> instruction_to_profile_idx,
     std::unordered_map<const HloComputation*, int64> computation_to_profile_idx,
     const TargetMachineFeatures* target_machine_features,
@@ -99,6 +99,7 @@
       module_(llvm_module),
       arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
       b_(llvm_module->getContext()),
+      mlir_context_(mlir_context),
       instruction_to_profile_idx_(std::move(instruction_to_profile_idx)),
       computation_to_profile_idx_(std::move(computation_to_profile_idx)),
       alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
@@ -898,7 +899,7 @@
   // Dot operation is complicated so we delegate to a helper class.
   return EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
                           /*addend_array=*/nullptr,
-                          GetExecutableRunOptionsArgument(), &b_,
+                          GetExecutableRunOptionsArgument(), &b_, mlir_context_,
                           hlo_module_config_, target_machine_features_);
 }
 
@@ -2305,10 +2306,10 @@
     llvm_ir::IrArray addend_array(
         GetIrArrayFor(fusion->operand(addend_param_number)));
 
-    TF_RETURN_IF_ERROR(
-        EmitDotOperation(*dot, target_array, lhs_array, rhs_array,
-                         &addend_array, GetExecutableRunOptionsArgument(), &b_,
-                         hlo_module_config_, target_machine_features_));
+    TF_RETURN_IF_ERROR(EmitDotOperation(
+        *dot, target_array, lhs_array, rhs_array, &addend_array,
+        GetExecutableRunOptionsArgument(), &b_, mlir_context_,
+        hlo_module_config_, target_machine_features_));
     return Status::OK();
   } else {
     return Unimplemented("Fusion kind not implemented on CPU");
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 9b0d11e..6617851 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -17,6 +17,7 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
 
 #include <stddef.h>
+
 #include <map>
 #include <memory>
 #include <string>
@@ -32,6 +33,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Value.h"
 #include "llvm/Target/TargetMachine.h"
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/cpu/ir_function.h"
 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
@@ -69,14 +71,16 @@
   // hlo_module: the HLO module we are emitting IR for.
   // assignment: a BufferAssignment from which we know which buffers are used by
   //             the HLO nodes.
-  // llvm_module: the LLVM module to emit IR into.
+  // mlir_context: the MLIR context used for IR emission.
+  // llvm_module: the LLVM module to emit IR into. It's built using the LLVM
+  //              context inside of mlir_context.
   // instruction_to_profile_idx: the mapping from HLO instructions to their
   //              index in the profiling array.
   // computation_to_profile_idx: the mapping from HLO computations to their
   //              index in the profiling array.
   // emit_code_for_msan: whether emitted code should be compatible with msan.
-  IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment,
-            llvm::Module* llvm_module,
+  IrEmitter(mlir::MLIRContext* mlir_context, const HloModule& hlo_module,
+            const BufferAssignment& assignment, llvm::Module* llvm_module,
             std::unordered_map<const HloInstruction*, int64>
                 instruction_to_profile_idx,
             std::unordered_map<const HloComputation*, int64>
@@ -442,6 +446,7 @@
   // module's function list).
   std::unique_ptr<IrFunction> compute_function_;
   llvm::IRBuilder<> b_;
+  mlir::MLIRContext* mlir_context_;
 
   // The buffer allocation slice for the root of the computation being compiled.
   // Only relevant for thread local computations.
diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
new file mode 100644
index 0000000..e7d52c2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc
@@ -0,0 +1,132 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h"
+
+#include "llvm/Linker/Linker.h"
+#include "llvm/Transforms/IPO/Internalize.h"
+#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"  // from @llvm-project
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"  // from @llvm-project
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"  // from @llvm-project
+#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
+#include "mlir/IR/Module.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Target/LLVMIR.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+// Lower an MLIR module to an LLVM module.
+std::unique_ptr<llvm::Module> MakeLLVMModule(mlir::OwningModuleRef module) {
+  mlir::PassManager manager(module->getContext());
+  manager.addPass(mlir::createConvertLinalgToLoopsPass());
+  manager.addPass(mlir::createConvertLinalgToLLVMPass());
+  manager.addPass(mlir::createConvertVectorToLLVMPass());
+  manager.addPass(mlir::createLowerToLLVMPass());
+  CHECK(succeeded(manager.run(*module)));
+  return mlir::translateModuleToLLVMIR(*module);
+}
+
+// Get arguments to pass a memref to an MLIR function.
+void BuildViewForBuffer(llvm::SmallVectorImpl<llvm::Value *> *args,
+                        llvm::IRBuilder<> *b, const Shape &opShape,
+                        llvm::Value *op_val) {
+  llvm::Type *ty = op_val->getType();
+  while (auto aty = llvm::dyn_cast<llvm::ArrayType>(
+             llvm::cast<llvm::PointerType>(ty)->getElementType())) {
+    ty = aty->getElementType()->getPointerTo();
+  }
+  op_val = b->CreateBitCast(op_val, ty);
+
+  args->push_back(op_val);          // Allocated pointer.
+  args->push_back(op_val);          // Aligned pointer.
+  args->push_back(b->getInt64(0));  // Offset.
+
+  // Sizes.
+  for (int64 dim : opShape.dimensions()) {
+    args->push_back(b->getInt64(dim));
+  }
+
+  int64_t accumulated_stride = 1;
+  llvm::SmallVector<int64_t, 4> strides(opShape.rank(), 1);
+  for (int64 dim : LayoutUtil::MinorToMajor(opShape)) {
+    strides[dim] = accumulated_stride;
+    accumulated_stride *= opShape.dimensions(dim);
+  }
+
+  // Strides.
+  for (int64 stride : strides) {
+    args->push_back(b->getInt64(stride));
+  }
+}
+}  // namespace
+
+Status EmitMlirFuncAndCall(
+    mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
+    llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
+    llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
+    llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter) {
+  llvm::Module *llvm_module = b->GetInsertBlock()->getParent()->getParent();
+  mlir::Builder mlir_builder(context);
+
+  // Get memref types for the inputs and output.
+  TF_ASSIGN_OR_RETURN(mlir::Type ret_memref, ConvertTensorShapeToMemRefType(
+                                                 result_shape, mlir_builder));
+  std::vector<mlir::Type> operand_types = {ret_memref};
+  for (int i = 0; i != operand_shapes.size(); ++i) {
+    TF_ASSIGN_OR_RETURN(
+        mlir::Type op_memref,
+        ConvertTensorShapeToMemRefType(operand_shapes[i], mlir_builder));
+    operand_types.push_back(op_memref);
+  }
+
+  // Create the function and call the emission callback.
+  mlir::Location loc = mlir::UnknownLoc::get(context);
+  auto function = mlir::FuncOp::create(
+      loc, func_name, mlir::FunctionType::get(operand_types, {}, context));
+  function.addEntryBlock();
+  mlir::OwningModuleRef mlir_module = mlir::ModuleOp::create(loc);
+  mlir_module->push_back(function);
+  mlir::OpBuilder op_builder(&function.getBody());
+  emitter(&op_builder, function);
+
+  // Now link it all into the main LLVM module.
+  auto mlir_llvm_module = MakeLLVMModule(std::move(mlir_module));
+  mlir_llvm_module->setDataLayout(llvm_module->getDataLayout());
+  llvm::Linker::linkModules(
+      *llvm_module, std::move(mlir_llvm_module), llvm::Linker::None,
+      [](llvm::Module &M, const llvm::StringSet<> &GVS) {
+        llvm::internalizeModule(M, [&GVS](const llvm::GlobalValue &GV) {
+          return !GV.hasName() || (GVS.count(GV.getName()) == 0);
+        });
+      });
+
+  // And leave behind a call to the function generated by MLIR.
+  llvm::Function *func = llvm_module->getFunction(func_name);
+  llvm::SmallVector<llvm::Value *, 4> op_vals;
+  BuildViewForBuffer(&op_vals, b, result_shape, result_ptr);
+  for (int i = 0; i != operand_shapes.size(); ++i) {
+    BuildViewForBuffer(&op_vals, b, operand_shapes[i], operand_ptrs[i]);
+  }
+  b->CreateCall(func, op_vals);
+
+  return Status::OK();
+}
+
+}  // namespace cpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.h b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h
new file mode 100644
index 0000000..bc0741e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.h
@@ -0,0 +1,43 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/Function.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/shape.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+namespace cpu {
+
+// Create a new MLIR function with the name `func_name`, populate it with
+// `emitter`, and create a call to it, passing the buffers defined by
+// result_shape/result_ptr and operand_shapes/operand_ptrs. The function is
+// added to the LLVM module at `b`'s insertion point.
+Status EmitMlirFuncAndCall(
+    mlir::MLIRContext *context, llvm::IRBuilder<> *b, const Shape &result_shape,
+    llvm::ArrayRef<Shape> operand_shapes, llvm::Value *result_ptr,
+    llvm::ArrayRef<llvm::Value *> operand_ptrs, llvm::StringRef func_name,
+    llvm::function_ref<void(mlir::OpBuilder *, mlir::FuncOp)> emitter);
+
+}  // namespace cpu
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_MLIR_EMITTER_H_